From 9044609be3ea78c650420533e7f6f40b83cedd99 Mon Sep 17 00:00:00 2001 From: jrandolf <101637635+jrandolf@users.noreply.github.com> Date: Mon, 13 Mar 2023 16:11:16 +0100 Subject: [PATCH] fix: sort elements based on selector matching algorithm (#9836) --- .../src/injected/PQuerySelector.ts | 79 ++++++++++++++++--- .../src/util/AsyncIterableUtil.ts | 6 +- test/src/queryhandler.spec.ts | 56 ++++++++++++- 3 files changed, 127 insertions(+), 14 deletions(-) diff --git a/packages/puppeteer-core/src/injected/PQuerySelector.ts b/packages/puppeteer-core/src/injected/PQuerySelector.ts index 07b30245a25..8355c327a9f 100644 --- a/packages/puppeteer-core/src/injected/PQuerySelector.ts +++ b/packages/puppeteer-core/src/injected/PQuerySelector.ts @@ -170,6 +170,67 @@ class PQueryEngine { } } +class DepthCalculator { + #cache = new Map(); + + calculate(node: Node, depth: number[] = []): number[] { + if (node instanceof Document) { + return depth; + } + if (node instanceof ShadowRoot) { + node = node.host; + } + + const cachedDepth = this.#cache.get(node); + if (cachedDepth) { + return [...cachedDepth, ...depth]; + } + + let index = 0; + for ( + let prevSibling = node.previousSibling; + prevSibling; + prevSibling = prevSibling.previousSibling + ) { + ++index; + } + + const value = this.calculate(node.parentNode as Node, [index]); + this.#cache.set(node, value); + return [...value, ...depth]; + } +} + +const compareDepths = (a: number[], b: number[]): -1 | 0 | 1 => { + if (a.length + b.length === 0) { + return 0; + } + const [i = Infinity, ...otherA] = a; + const [j = Infinity, ...otherB] = b; + if (i === j) { + return compareDepths(otherA, otherB); + } + return i < j ? 1 : -1; +}; + +const domSort = async function* (elements: AwaitableIterable) { + const results = new Set(); + for await (const element of elements) { + results.add(element); + } + const calculator = new DepthCalculator(); + yield* [...results.values()] + .map(result => { + return [result, calculator.calculate(result)] as const; + }) + .sort(([, a], [, b]) => { + return compareDepths(a, b); + }) + .map(([result]) => { + return result; + }); +}; + type QueryableNode = { querySelectorAll: typeof Document.prototype.querySelectorAll; }; @@ -179,7 +240,7 @@ type QueryableNode = { * * @internal */ -export const pQuerySelectorAll = async function* ( +export const pQuerySelectorAll = function ( root: Node, selector: string ): AwaitableIterable { @@ -195,10 +256,8 @@ export const pQuerySelectorAll = async function* ( } if (isPureCSS) { - yield* (root as unknown as QueryableNode).querySelectorAll(selector); - return; + return (root as unknown as QueryableNode).querySelectorAll(selector); } - // If there are any empty elements, then this implies the selector has // contiguous combinators (e.g. `>>> >>>>`) or starts/ends with one which we // treat as illegal, similar to existing behavior. @@ -221,11 +280,13 @@ export const pQuerySelectorAll = async function* ( ); } - for (const selectorParts of selectors) { - const query = new PQueryEngine(root, selector, selectorParts); - query.run(); - yield* query.elements; - } + return domSort( + AsyncIterableUtil.flatMap(selectors, selectorParts => { + const query = new PQueryEngine(root, selector, selectorParts); + query.run(); + return query.elements; + }) + ); }; /** diff --git a/packages/puppeteer-core/src/util/AsyncIterableUtil.ts b/packages/puppeteer-core/src/util/AsyncIterableUtil.ts index a35e4794535..5b06b3ab30c 100644 --- a/packages/puppeteer-core/src/util/AsyncIterableUtil.ts +++ b/packages/puppeteer-core/src/util/AsyncIterableUtil.ts @@ -28,10 +28,10 @@ export class AsyncIterableUtil { } } - static async *flatMap( + static async *flatMap( iterable: AwaitableIterable, - map: (item: T) => AwaitableIterable - ): AsyncIterable { + map: (item: T) => AwaitableIterable + ): AsyncIterable { for await (const value of iterable) { yield* map(value); } diff --git a/test/src/queryhandler.spec.ts b/test/src/queryhandler.spec.ts index e4568af945b..0af241d1719 100644 --- a/test/src/queryhandler.spec.ts +++ b/test/src/queryhandler.spec.ts @@ -359,7 +359,9 @@ describe('Query handler tests', function () { describe('P selectors', () => { beforeEach(async () => { const {page} = getTestState(); - await page.setContent('
hello
'); + await page.setContent( + '
hello
' + ); Puppeteer.clearCustomQueryHandlers(); }); @@ -489,10 +491,60 @@ describe('Query handler tests', function () { expect(value).toMatchObject({textContent: 'world', tagName: 'BUTTON'}); }); - it('should work with commas', async () => { + it('should work with selector lists', async () => { const {page} = getTestState(); const elements = await page.$$('div, ::-p-text(world)'); expect(elements.length).toStrictEqual(2); }); + + const permute = (inputs: T[]): T[][] => { + const results: T[][] = []; + for (let i = 0; i < inputs.length; ++i) { + const permutation = permute( + inputs.slice(0, i).concat(inputs.slice(i + 1)) + ); + const value = inputs[i] as T; + if (permutation.length === 0) { + results.push([value]); + continue; + } + for (const part of permutation) { + results.push([value].concat(part)); + } + } + return results; + }; + + it('should match querySelector* ordering', async () => { + const {page} = getTestState(); + for (const list of permute(['div', 'button', 'span'])) { + const expected = await page.evaluate(selector => { + return [...document.querySelectorAll(selector)].map(element => { + return element.tagName; + }); + }, list.join(',')); + const elements = await page.$$( + list + .map(selector => { + return selector === 'button' ? '::-p-text(world)' : selector; + }) + .join(',') + ); + const actual = await Promise.all( + elements.map(element => { + return element.evaluate(element => { + return element.tagName; + }); + }) + ); + expect(actual.join()).toStrictEqual(expected.join()); + } + }); + + it('should not have duplicate elements from selector lists', async () => { + const {page} = getTestState(); + const elements = await page.$$('::-p-text(world), button'); + expect(elements.length).toStrictEqual(1); + }); }); });