feat: Add requests argument to EnqueueLinksFunction #1024

Open · wants to merge 5 commits into master
31 changes: 26 additions & 5 deletions src/crawlee/_types.py
@@ -324,20 +324,40 @@ def __call__(

@docs_group('Functions')
class EnqueueLinksFunction(Protocol):
"""A function for enqueueing new URLs to crawl based on elements selected by a given selector.
"""A function for enqueueing new URLs to crawl based on elements selected by a given selector or explicit requests.

It extracts URLs from the current page and enqueues them for further crawling. It allows filtering through
selectors and other options. You can also specify labels and user data to be associated with the newly
created `Request` objects.
It adds explicitly passed `requests` to the `RequestManager` or it extracts URLs from the current page and enqueues
them for further crawling. It allows filtering through selectors and other options. You can also specify labels and
user data to be associated with the newly created `Request` objects.

It should not be called with the `selector`, `label`, `user_data` or `transform_request_function` arguments together with the `requests` argument.
"""

@overload
def __call__(
self,
*,
selector: str | None = None,
label: str | None = None,
user_data: dict[str, Any] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> Coroutine[None, None, None]: ...

@overload
def __call__(
self, *, requests: Sequence[str | Request] | None = None, **kwargs: Unpack[EnqueueLinksKwargs]
) -> Coroutine[None, None, None]: ...

def __call__(
self,
*,
selector: str = 'a',
selector: str | None = None,
label: str | None = None,
user_data: dict[str, Any] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None,
requests: Sequence[str | Request] | None = None,
@janbuchar (Collaborator) commented on Feb 28, 2025:

I now noticed that the JS counterpart accepts just urls as an array of strings. We should either restrict this, or extend the JS version 🙂

If we choose to restrict this one, then most of the other parameters (barring selector) would actually start making sense in combination with urls.

Collaborator reply:

I would prefer to keep it as it is for consistency, since we use `request: str | Request` everywhere else.
**kwargs: Unpack[EnqueueLinksKwargs],
) -> Coroutine[None, None, None]:
"""A call dunder method.
@@ -354,6 +374,7 @@ def __call__(
- Modified `RequestOptions` to update the request configuration,
- `'skip'` to exclude the request from being enqueued,
- `'unchanged'` to use the original request options without modification.
requests: Requests to be added to the `RequestManager`.
**kwargs: Additional keyword arguments.
"""

43 changes: 37 additions & 6 deletions src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py
@@ -2,7 +2,7 @@

import logging
from abc import ABC
from typing import TYPE_CHECKING, Any, Callable, Generic
from typing import TYPE_CHECKING, Any, Callable, Generic, Union

from pydantic import ValidationError
from typing_extensions import TypeVar
@@ -17,7 +17,7 @@
from ._http_crawling_context import HttpCrawlingContext, ParsedHttpCrawlingContext, TParseResult, TSelectResult

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Awaitable
from collections.abc import AsyncGenerator, Awaitable, Sequence

from typing_extensions import Unpack

@@ -143,18 +143,18 @@ def _create_enqueue_links_function(
Awaitable that is used for extracting links from parsed content and enqueuing them to the crawl.
"""

async def enqueue_links(
async def extract_links(
*,
selector: str = 'a',
label: str | None = None,
user_data: dict[str, Any] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction]
| None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> None:
) -> list[str | Request]:
kwargs.setdefault('strategy', 'same-hostname')

requests = list[Request]()
requests = list[Union[str, Request]]()
base_user_data = user_data or {}

for link in self._parser.find_links(parsed_content, selector=selector):
@@ -183,8 +183,39 @@ async def enqueue_links(
continue

requests.append(request)
return requests

await context.add_requests(requests, **kwargs)
async def enqueue_links(
*,
selector: str | None = None,
label: str | None = None,
user_data: dict[str, Any] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction]
| None = None,
requests: Sequence[str | Request] | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> None:
kwargs.setdefault('strategy', 'same-hostname')

if requests:
if any((selector, label, user_data, transform_request_function)):
raise ValueError(
'You cannot provide `selector`, `label`, `user_data` or '
'`transform_request_function` arguments when `requests` is provided.'
)
# Add directly passed requests.
await context.add_requests(requests or list[Union[str, Request]](), **kwargs)
else:
# Add requests from extracted links.
await context.add_requests(
await extract_links(
selector=selector or 'a',
label=label,
user_data=user_data,
transform_request_function=transform_request_function,
),
**kwargs,
)

return enqueue_links
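
The guard above means `transform_request_function` only ever takes effect on the selector branch. A hedged sketch of such a function, assuming `RequestOptions` and `RequestTransformAction` are importable from the top-level `crawlee` package (as `RequestTransformAction` is in the tests below) and that `RequestOptions` behaves like a dict with `url` and `label` keys:

    from crawlee import RequestOptions, RequestTransformAction  # assumption: both exported at top level

    def add_label_and_skip_assets(options: RequestOptions) -> RequestOptions | RequestTransformAction:
        # Skip obvious binary assets; label everything else for a detail handler.
        if options['url'].endswith(('.jpg', '.png', '.zip')):
            return 'skip'
        options['label'] = 'DETAIL'
        return options

    # Usage inside a request handler (selector path only; combining it with `requests` raises ValueError):
    #     await context.enqueue_links(selector='a', transform_request_function=add_label_and_skip_assets)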

46 changes: 39 additions & 7 deletions src/crawlee/crawlers/_playwright/_playwright_crawler.py
@@ -2,7 +2,7 @@

import logging
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic
from typing import TYPE_CHECKING, Any, Callable, Generic, Union

from pydantic import ValidationError
from typing_extensions import NotRequired, TypedDict, TypeVar
@@ -25,7 +25,7 @@
TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Awaitable, Mapping
from collections.abc import AsyncGenerator, Awaitable, Mapping, Sequence
from pathlib import Path

from playwright.async_api import Page
@@ -230,19 +230,18 @@ async def _navigate(
pw_cookies = await self._get_cookies(context.page)
context.session.cookies.set_cookies_from_playwright_format(pw_cookies)

async def enqueue_links(
async def extract_links(
*,
selector: str = 'a',
label: str | None = None,
user_data: dict | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction]
| None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> None:
"""The `PlaywrightCrawler` implementation of the `EnqueueLinksFunction` function."""
) -> list[str | Request]:
kwargs.setdefault('strategy', 'same-hostname')

requests = list[Request]()
requests = list[Union[str, Request]]()
base_user_data = user_data or {}

elements = await context.page.query_selector_all(selector)
@@ -278,7 +277,40 @@ async def enqueue_links(

requests.append(request)

await context.add_requests(requests, **kwargs)
return requests

async def enqueue_links(
*,
selector: str | None = None,
label: str | None = None,
user_data: dict | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction]
| None = None,
requests: Sequence[str | Request] | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> None:
"""The `PlaywrightCrawler` implementation of the `EnqueueLinksFunction` function."""
kwargs.setdefault('strategy', 'same-hostname')

if requests:
if any((selector, label, user_data, transform_request_function)):
raise ValueError(
'You cannot provide `selector`, `label`, `user_data` or `transform_request_function` '
'arguments when `requests` is provided.'
)
# Add directly passed requests.
await context.add_requests(requests or list[Union[str, Request]](), **kwargs)
else:
# Add requests from extracted links.
await context.add_requests(
await extract_links(
selector=selector or 'a',
label=label,
user_data=user_data,
transform_request_function=transform_request_function,
),
**kwargs,
)

yield PlaywrightCrawlingContext(
request=context.request,
22 changes: 21 additions & 1 deletion tests/unit/crawlers/_parsel/test_parsel_crawler.py
@@ -8,7 +8,7 @@
import respx
from httpx import Response

from crawlee import ConcurrencySettings, HttpHeaders, RequestTransformAction
from crawlee import ConcurrencySettings, HttpHeaders, Request, RequestTransformAction
from crawlee.crawlers import ParselCrawler

if TYPE_CHECKING:
@@ -151,6 +151,26 @@ async def request_handler(context: ParselCrawlingContext) -> None:
}


async def test_enqueue_links_with_incompatible_kwargs_raises_error(server: respx.MockRouter) -> None:
"""Call `enqueue_links` with arguments that can't be used together."""
requests = ['https://www.test.io/']
crawler = ParselCrawler(max_request_retries=1)
exceptions = []

@crawler.router.default_handler
async def request_handler(context: ParselCrawlingContext) -> None:
try:
await context.enqueue_links(requests=[Request.from_url('https://test.io/asdf')], selector='a') # type:ignore[call-overload] # Testing runtime enforcement of the overloads.
except Exception as e:
exceptions.append(e)

await crawler.run(requests)

assert server['index_endpoint'].called
assert len(exceptions) == 1
assert type(exceptions[0]) is ValueError


async def test_enqueue_links_selector(server: respx.MockRouter) -> None:
crawler = ParselCrawler()
visit = mock.Mock()
23 changes: 23 additions & 0 deletions tests/unit/crawlers/_playwright/test_playwright_crawler.py
@@ -76,6 +76,29 @@ async def request_handler(context: PlaywrightCrawlingContext) -> None:
assert all(url.startswith('https://crawlee.dev/docs/examples') for url in visited)


async def test_enqueue_links_with_incompatible_kwargs_raises_error() -> None:
"""Call `enqueue_links` with arguments that can't be used together."""
requests = ['https://www.something.com']
crawler = PlaywrightCrawler(max_request_retries=1)
exceptions = []

@crawler.pre_navigation_hook
async def some_hook(context: PlaywrightPreNavCrawlingContext) -> None:
await context.page.route('**/*', lambda route: route.fulfill(status=200))

@crawler.router.default_handler
async def request_handler(context: PlaywrightCrawlingContext) -> None:
try:
await context.enqueue_links(requests=[Request.from_url('https://www.whatever.com')], selector='a') # type:ignore[call-overload] # Testing runtime enforcement of the overloads.
except Exception as e:
exceptions.append(e)

await crawler.run(requests)

assert len(exceptions) == 1
assert type(exceptions[0]) is ValueError


async def test_enqueue_links_with_transform_request_function() -> None:
crawler = PlaywrightCrawler()
visit = mock.Mock()