From 26e8b12ad746f9c4bc54389c6f0e8bd6be9868f6 Mon Sep 17 00:00:00 2001
From: TAHRI Ahmed R
Date: Thu, 11 Jan 2024 18:55:59 +0100
Subject: [PATCH] :sparkle: Add AsyncResponse to properly handle streams with
 AsyncSession (#64)

---
 HISTORY.md                  |   9 +
 README.md                   |  10 +
 docs/user/advanced.rst      |   2 +
 docs/user/quickstart.rst    |  85 +++++
 src/niquests/__init__.py    |   3 +-
 src/niquests/__version__.py |   4 +-
 src/niquests/_async.py      | 624 ++++++++++++++++++++++++++++++++++--
 src/niquests/adapters.py    |  13 +-
 src/niquests/models.py      |   7 +
 src/niquests/sessions.py    |   4 +
 src/niquests/utils.py       |  21 ++
 tests/test_async.py         |  70 +++-
 12 files changed, 829 insertions(+), 23 deletions(-)

diff --git a/HISTORY.md b/HISTORY.md
index 4e71d9b3ce..c287d1c778 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,6 +1,15 @@
 Release History
 ===============

+3.4.2 (2024-01-11)
+------------------
+
+**Fixed**
+- Connection information kept pointing at the live connection object, so it always carried the latest timings instead of the historical ones recorded for each response.
+
+**Added**
+- `AsyncSession` now returns an `AsyncResponse` when `stream` is set to True, in order to properly handle streams in an async context.
+
 3.4.1 (2024-01-07)
 ------------------

diff --git a/README.md b/README.md
index 1bd56390cc..cf783f3c10 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,16 @@ True
 >>> r.conn_info.established_latency
 datetime.timedelta(microseconds=38)
 ```
+Or use async/await! You will need to enclose the code within a proper async function; see the docs for more.
+```python
+>>> import niquests
+>>> s = niquests.AsyncSession(resolver="doh+google://")
+>>> r = await s.get('https://pie.dev/basic-auth/user/pass', auth=('user', 'pass'), stream=True)
+>>> r
+
+>>> await r.json()
+{'authenticated': True, ...}
+```

 Niquests allows you to send HTTP requests extremely easily. There’s no need to manually add
 query strings to your URLs, or to form-encode your `PUT` & `POST` data — but nowadays, just
 use the `json` method!
diff --git a/docs/user/advanced.rst b/docs/user/advanced.rst
index aec5aa1a3a..2516c78b1d 100644
--- a/docs/user/advanced.rst
+++ b/docs/user/advanced.rst
@@ -1276,6 +1276,8 @@ over QUIC.

 .. warning:: You cannot specify another hostname for security reasons.

+.. note:: Using a custom DNS resolver can solve this, as we can then probe the HTTPS record for the given hostname and connect directly using HTTP/3 over QUIC.
+
 Increase the default Alt-Svc cache size
 ---------------------------------------

diff --git a/docs/user/quickstart.rst b/docs/user/quickstart.rst
index 13d4bf9b2c..781679c79e 100644
--- a/docs/user/quickstart.rst
+++ b/docs/user/quickstart.rst
@@ -787,6 +787,91 @@ Look at this basic sample::

     asyncio.run(main())

+.. warning:: Combining ``AsyncSession`` with ``multiplexed=True`` and ``stream=True`` produces lazy ``AsyncResponse`` instances; make sure to call ``await session.gather()`` before accessing such a lazy response directly.
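To make the warning above concrete, here is a minimal sketch of the intended call order. The endpoint is a placeholder, and the blocking/`ResourceWarning` behaviour it relies on comes from the `models.py` change later in this patch; the code itself is illustrative only, not part of the patched docs.

```python
import asyncio

import niquests


async def main() -> None:
    async with niquests.AsyncSession(multiplexed=True) as s:
        # stream=True inside an AsyncSession yields a lazy AsyncResponse here.
        r = await s.get("https://pie.dev/get", stream=True)

        # Touching r.status_code at this point would resolve the exchange
        # synchronously and emit a ResourceWarning (see models.py below).
        await s.gather()  # resolve pending multiplexed exchanges first

        print(r.status_code)   # safe: the response is no longer lazy
        print(await r.json())  # AsyncResponse accessors remain awaitable


if __name__ == "__main__":
    asyncio.run(main())
```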
+
+AsyncResponse for streams
+-------------------------
+
+Delaying content consumption in an async context can easily be achieved using::
+
+    import niquests
+    import asyncio
+
+    async def main() -> None:
+
+        async with niquests.AsyncSession() as s:
+            r = await s.get("https://pie.dev/get", stream=True)
+
+            async for chunk in await r.iter_content(16):
+                print(chunk)
+
+
+    if __name__ == "__main__":
+
+        asyncio.run(main())
+
+Or simply by doing::
+
+    import niquests
+    import asyncio
+
+    async def main() -> None:
+
+        async with niquests.AsyncSession() as s:
+            r = await s.get("https://pie.dev/get", stream=True)
+            payload = await r.json()
+
+    if __name__ == "__main__":
+
+        asyncio.run(main())
+
+When you specify ``stream=True`` within an ``AsyncSession``, the returned object will be of type ``AsyncResponse``,
+and the following methods and properties become awaitable:
+
+- iter_content(...)
+- iter_lines(...)
+- content
+- json(...)
+- text
+
+When multiplexing is enabled in an async context, you will also have to call ``await s.gather()``
+to avoid blocking your event loop.
+
+Here is a basic example of how you would do it::
+
+    import niquests
+    import asyncio
+
+
+    async def main() -> None:
+
+        responses = []
+
+        async with niquests.AsyncSession(multiplexed=True) as s:
+            responses.append(
+                await s.get("https://pie.dev/get", stream=True)
+            )
+            responses.append(
+                await s.get("https://pie.dev/get", stream=True)
+            )
+
+            print(responses)
+
+            await s.gather()
+
+            print(responses)
+
+            for response in responses:
+                async for chunk in await response.iter_content(16):
+                    print(chunk)
+
+
+    if __name__ == "__main__":
+
+        asyncio.run(main())
+
+.. warning:: Accessing a lazy ``AsyncResponse`` without first calling ``s.gather()`` will emit a ``ResourceWarning``, because the exchange then has to be resolved synchronously.
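The quickstart additions above demonstrate ``iter_content`` and ``json``; for completeness, here is a brief sketch of the awaitable ``content`` and ``text`` properties. Same placeholder endpoint as the docs, illustrative only and not part of the patched files.

```python
import asyncio

import niquests


async def main() -> None:
    async with niquests.AsyncSession() as s:
        r = await s.get("https://pie.dev/get", stream=True)

        # content and text are awaitable properties on AsyncResponse;
        # content is cached after the first read, so text reuses it.
        raw = await r.content
        print(len(raw) if raw is not None else 0)
        print(await r.text)


if __name__ == "__main__":
    asyncio.run(main())
```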
+ DNS Resolution -------------- diff --git a/src/niquests/__init__.py b/src/niquests/__init__.py index a3ed5097ad..3b50e4850f 100644 --- a/src/niquests/__init__.py +++ b/src/niquests/__init__.py @@ -73,7 +73,7 @@ __url__, __version__, ) -from ._async import AsyncSession +from ._async import AsyncSession, AsyncResponse from .api import delete, get, head, options, patch, post, put, request from .exceptions import ( ConnectionError, @@ -131,4 +131,5 @@ "Session", "codes", "AsyncSession", + "AsyncResponse", ) diff --git a/src/niquests/__version__.py b/src/niquests/__version__.py index 7d84d756d4..8474cdeda6 100644 --- a/src/niquests/__version__.py +++ b/src/niquests/__version__.py @@ -9,9 +9,9 @@ __url__: str = "https://niquests.readthedocs.io" __version__: str -__version__ = "3.4.1" +__version__ = "3.4.2" -__build__: int = 0x030401 +__build__: int = 0x030402 __author__: str = "Kenneth Reitz" __author_email__: str = "me@kennethreitz.org" __license__: str = "Apache-2.0" diff --git a/src/niquests/_async.py b/src/niquests/_async.py index 22db239dcc..8035e2d112 100644 --- a/src/niquests/_async.py +++ b/src/niquests/_async.py @@ -1,6 +1,29 @@ from __future__ import annotations import typing +import json as _json +from charset_normalizer import from_bytes +import codecs + +if typing.TYPE_CHECKING: + from typing_extensions import Literal + +from ._compat import HAS_LEGACY_URLLIB3 + +if HAS_LEGACY_URLLIB3 is False: + from urllib3.exceptions import ( + DecodeError, + ProtocolError, + ReadTimeoutError, + SSLError, + ) +else: + from urllib3_future.exceptions import ( # type: ignore[assignment] + DecodeError, + ProtocolError, + ReadTimeoutError, + SSLError, + ) from ._constant import READ_DEFAULT_TIMEOUT, WRITE_DEFAULT_TIMEOUT from ._typing import ( @@ -19,9 +42,18 @@ TLSVerifyType, ) from .extensions._sync_to_async import sync_to_async +from .exceptions import ( + ChunkedEncodingError, + ConnectionError, + ContentDecodingError, + StreamConsumedError, +) +from .exceptions import JSONDecodeError as RequestsJSONDecodeError +from .exceptions import SSLError as RequestsSSLError from .hooks import dispatch_hook -from .models import PreparedRequest, Request, Response +from .models import PreparedRequest, Request, Response, ITER_CHUNK_SIZE from .sessions import Session +from .utils import astream_decode_response_unicode class AsyncSession(Session): @@ -40,12 +72,61 @@ async def __aexit__(self, exc, value, tb): super().__exit__, thread_sensitive=AsyncSession.disable_thread )() - async def send(self, request: PreparedRequest, **kwargs: typing.Any) -> Response: # type: ignore[override] + async def send( # type: ignore[override] + self, request: PreparedRequest, **kwargs: typing.Any + ) -> Response | AsyncResponse: # type: ignore[override] + if "stream" in kwargs and kwargs["stream"]: + kwargs["mutate_response_class"] = AsyncResponse return await sync_to_async( super().send, thread_sensitive=AsyncSession.disable_thread, )(request=request, **kwargs) + @typing.overload # type: ignore[override] + async def request( + self, + method: HttpMethodType, + url: str, + params: QueryParameterType | None = ..., + data: BodyType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + files: MultiPartFilesType | MultiPartFilesAltType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + stream: Literal[False] = ..., + verify: 
TLSVerifyType | None = ..., + cert: TLSClientCertType | None = ..., + json: typing.Any | None = ..., + ) -> Response: + ... + + @typing.overload # type: ignore[override] + async def request( + self, + method: HttpMethodType, + url: str, + params: QueryParameterType | None = ..., + data: BodyType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + files: MultiPartFilesType | MultiPartFilesAltType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + *, + stream: Literal[True], + verify: TLSVerifyType | None = ..., + cert: TLSClientCertType | None = ..., + json: typing.Any | None = ..., + ) -> AsyncResponse: + ... + async def request( # type: ignore[override] self, method: HttpMethodType, @@ -60,11 +141,11 @@ async def request( # type: ignore[override] allow_redirects: bool = True, proxies: ProxyType | None = None, hooks: HookType[PreparedRequest | Response] | None = None, - stream: bool | None = None, + stream: bool = False, verify: TLSVerifyType | None = None, cert: TLSClientCertType | None = None, json: typing.Any | None = None, - ) -> Response: + ) -> Response | AsyncResponse: if method.isupper() is False: method = method.upper() @@ -105,6 +186,44 @@ async def request( # type: ignore[override] return await self.send(prep, **send_kwargs) + @typing.overload # type: ignore[override] + async def get( + self, + url: str, + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[False] = ..., + cert: TLSClientCertType | None = ..., + ) -> Response: + ... + + @typing.overload # type: ignore[override] + async def get( + self, + url: str, + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[True], + cert: TLSClientCertType | None = ..., + ) -> AsyncResponse: + ... + async def get( # type: ignore[override] self, url: str, @@ -120,8 +239,8 @@ async def get( # type: ignore[override] verify: TLSVerifyType = True, stream: bool = False, cert: TLSClientCertType | None = None, - ) -> Response: - return await self.request( + ) -> Response | AsyncResponse: + return await self.request( # type: ignore[call-overload,misc] "GET", url, params=params, @@ -137,6 +256,44 @@ async def get( # type: ignore[override] cert=cert, ) + @typing.overload # type: ignore[override] + async def options( + self, + url: str, + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[False] = ..., + cert: TLSClientCertType | None = ..., + ) -> Response: + ... 
+ + @typing.overload # type: ignore[override] + async def options( + self, + url: str, + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[True], + cert: TLSClientCertType | None = ..., + ) -> AsyncResponse: + ... + async def options( # type: ignore[override] self, url: str, @@ -152,8 +309,8 @@ async def options( # type: ignore[override] verify: TLSVerifyType = True, stream: bool = False, cert: TLSClientCertType | None = None, - ) -> Response: - return await self.request( + ) -> Response | AsyncResponse: + return await self.request( # type: ignore[call-overload,misc] "OPTIONS", url, params=params, @@ -169,6 +326,44 @@ async def options( # type: ignore[override] cert=cert, ) + @typing.overload # type: ignore[override] + async def head( + self, + url: str, + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[False] = ..., + cert: TLSClientCertType | None = ..., + ) -> Response: + ... + + @typing.overload # type: ignore[override] + async def head( + self, + url: str, + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[True], + cert: TLSClientCertType | None = ..., + ) -> AsyncResponse: + ... + async def head( # type: ignore[override] self, url: str, @@ -184,8 +379,8 @@ async def head( # type: ignore[override] verify: TLSVerifyType = True, stream: bool = False, cert: TLSClientCertType | None = None, - ) -> Response: - return await self.request( + ) -> Response | AsyncResponse: + return await self.request( # type: ignore[call-overload,misc] "HEAD", url, params=params, @@ -201,6 +396,50 @@ async def head( # type: ignore[override] cert=cert, ) + @typing.overload # type: ignore[override] + async def post( + self, + url: str, + data: BodyType | None = ..., + json: typing.Any | None = ..., + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + files: MultiPartFilesType | MultiPartFilesAltType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[False] = ..., + cert: TLSClientCertType | None = ..., + ) -> Response: + ... 
+ + @typing.overload # type: ignore[override] + async def post( + self, + url: str, + data: BodyType | None = ..., + json: typing.Any | None = ..., + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + files: MultiPartFilesType | MultiPartFilesAltType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[True], + cert: TLSClientCertType | None = ..., + ) -> AsyncResponse: + ... + async def post( # type: ignore[override] self, url: str, @@ -219,8 +458,8 @@ async def post( # type: ignore[override] verify: TLSVerifyType = True, stream: bool = False, cert: TLSClientCertType | None = None, - ) -> Response: - return await self.request( + ) -> Response | AsyncResponse: + return await self.request( # type: ignore[call-overload,misc] "POST", url, data=data, @@ -239,6 +478,50 @@ async def post( # type: ignore[override] cert=cert, ) + @typing.overload # type: ignore[override] + async def put( + self, + url: str, + data: BodyType | None = ..., + *, + json: typing.Any | None = ..., + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + files: MultiPartFilesType | MultiPartFilesAltType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[False] = ..., + cert: TLSClientCertType | None = ..., + ) -> Response: + ... + + @typing.overload # type: ignore[override] + async def put( + self, + url: str, + data: BodyType | None = ..., + *, + json: typing.Any | None = ..., + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + files: MultiPartFilesType | MultiPartFilesAltType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[True], + cert: TLSClientCertType | None = ..., + ) -> AsyncResponse: + ... 
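These paired ``@typing.overload`` declarations exist so a type checker can narrow the return type from the literal value of ``stream``. A small illustrative sketch follows; the ``reveal_type`` comments show the expected mypy result under that assumption, they are not verified output.

```python
import asyncio

import niquests


async def typed_demo() -> None:
    async with niquests.AsyncSession() as s:
        plain = await s.put("https://pie.dev/put", data=b"payload")
        # reveal_type(plain)    -> niquests.Response

        streamed = await s.put("https://pie.dev/put", data=b"payload", stream=True)
        # reveal_type(streamed) -> niquests.AsyncResponse
        print(await streamed.content)


if __name__ == "__main__":
    asyncio.run(typed_demo())
```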
+ async def put( # type: ignore[override] self, url: str, @@ -257,8 +540,8 @@ async def put( # type: ignore[override] verify: TLSVerifyType = True, stream: bool = False, cert: TLSClientCertType | None = None, - ) -> Response: - return await self.request( + ) -> Response | AsyncResponse: + return await self.request( # type: ignore[call-overload,misc] "PUT", url, data=data, @@ -277,6 +560,50 @@ async def put( # type: ignore[override] cert=cert, ) + @typing.overload # type: ignore[override] + async def patch( + self, + url: str, + data: BodyType | None = ..., + *, + json: typing.Any | None = ..., + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + files: MultiPartFilesType | MultiPartFilesAltType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[False] = ..., + cert: TLSClientCertType | None = ..., + ) -> Response: + ... + + @typing.overload # type: ignore[override] + async def patch( + self, + url: str, + data: BodyType | None = ..., + *, + json: typing.Any | None = ..., + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + files: MultiPartFilesType | MultiPartFilesAltType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[True], + cert: TLSClientCertType | None = ..., + ) -> AsyncResponse: + ... + async def patch( # type: ignore[override] self, url: str, @@ -295,8 +622,8 @@ async def patch( # type: ignore[override] verify: TLSVerifyType = True, stream: bool = False, cert: TLSClientCertType | None = None, - ) -> Response: - return await self.request( + ) -> Response | AsyncResponse: + return await self.request( # type: ignore[call-overload,misc] "PATCH", url, data=data, @@ -315,6 +642,44 @@ async def patch( # type: ignore[override] cert=cert, ) + @typing.overload # type: ignore[override] + async def delete( + self, + url: str, + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[False] = ..., + cert: TLSClientCertType | None = ..., + ) -> Response: + ... + + @typing.overload # type: ignore[override] + async def delete( + self, + url: str, + *, + params: QueryParameterType | None = ..., + headers: HeadersType | None = ..., + cookies: CookiesType | None = ..., + auth: HttpAuthenticationType | None = ..., + timeout: TimeoutType | None = ..., + allow_redirects: bool = ..., + proxies: ProxyType | None = ..., + hooks: HookType[PreparedRequest | Response] | None = ..., + verify: TLSVerifyType = ..., + stream: Literal[True], + cert: TLSClientCertType | None = ..., + ) -> AsyncResponse: + ... 
+ async def delete( # type: ignore[override] self, url: str, @@ -330,8 +695,8 @@ async def delete( # type: ignore[override] verify: TLSVerifyType = True, stream: bool = False, cert: TLSClientCertType | None = None, - ) -> Response: - return await self.request( + ) -> Response | AsyncResponse: + return await self.request( # type: ignore[call-overload,misc] "DELETE", url, params=params, @@ -352,3 +717,226 @@ async def gather(self, *responses: Response, max_fetch: int | None = None) -> No super().gather, thread_sensitive=AsyncSession.disable_thread, )(*responses, max_fetch=max_fetch) + + +class AsyncResponse(Response): + def __aenter__(self) -> AsyncResponse: + return self + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async for chunk in await self.iter_content(128): + yield chunk + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + @typing.overload # type: ignore[override] + async def iter_content( + self, chunk_size: int = ..., decode_unicode: Literal[False] = ... + ) -> typing.AsyncGenerator[bytes, None]: + ... + + @typing.overload # type: ignore[override] + async def iter_content( + self, chunk_size: int = ..., *, decode_unicode: Literal[True] + ) -> typing.AsyncGenerator[str, None]: + ... + + async def iter_content( # type: ignore[override] + self, chunk_size: int = 1, decode_unicode: bool = False + ) -> typing.AsyncGenerator[bytes | str, None]: + async def generate() -> ( + typing.AsyncGenerator[ + bytes, + None, + ] + ): + assert self.raw is not None + + while True: + try: + chunk = await sync_to_async( + self.raw.read, thread_sensitive=AsyncSession.disable_thread + )( + chunk_size, + decode_content=True, + ) + except ProtocolError as e: + raise ChunkedEncodingError(e) + except DecodeError as e: + raise ContentDecodingError(e) + except ReadTimeoutError as e: + raise ConnectionError(e) + except SSLError as e: + raise RequestsSSLError(e) + + if not chunk: + break + + yield chunk + + self._content_consumed = True + + if self._content_consumed and isinstance(self._content, bool): + raise StreamConsumedError() + elif chunk_size is not None and not isinstance(chunk_size, int): + raise TypeError( + f"chunk_size must be an int, it is instead a {type(chunk_size)}." + ) + + stream_chunks = generate() + + if decode_unicode: + return astream_decode_response_unicode(stream_chunks, self) + + return stream_chunks + + @typing.overload # type: ignore[override] + async def iter_lines( + self, + chunk_size: int = ..., + decode_unicode: Literal[False] = ..., + delimiter: str | bytes | None = ..., + ) -> typing.AsyncGenerator[bytes, None]: + ... + + @typing.overload # type: ignore[override] + async def iter_lines( + self, + chunk_size: int = ..., + *, + decode_unicode: Literal[True], + delimiter: str | bytes | None = ..., + ) -> typing.AsyncGenerator[str, None]: + ... + + async def iter_lines( # type: ignore[misc] + self, + chunk_size: int = ITER_CHUNK_SIZE, + decode_unicode: bool = False, + delimiter: str | bytes | None = None, + ) -> typing.AsyncGenerator[bytes | str, None]: + if ( + delimiter is not None + and decode_unicode is False + and isinstance(delimiter, str) + ): + raise ValueError( + "delimiter MUST match the desired output type. e.g. if decode_unicode is set to True, delimiter MUST be a str, otherwise we expect a bytes-like variable." 
+ ) + + pending = None + + async for chunk in self.iter_content( # type: ignore[call-overload] + chunk_size=chunk_size, decode_unicode=decode_unicode + ): + if pending is not None: + chunk = pending + chunk + + if delimiter: + lines = chunk.split(delimiter) # type: ignore[arg-type] + else: + lines = chunk.splitlines() + + if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: + pending = lines.pop() + else: + pending = None + + async for line in lines: + yield line + + if pending is not None: + yield pending + + @property + async def content(self) -> bytes | None: # type: ignore[override] + return await sync_to_async( + getattr, thread_sensitive=AsyncSession.disable_thread + )(super(), "content") + + @property + async def text(self) -> str | None: # type: ignore[override] + content = await self.content + + if not content: + return "" + + if self.encoding is not None: + try: + info = codecs.lookup(self.encoding) + + if ( + hasattr(info, "_is_text_encoding") + and info._is_text_encoding is False + ): + return None + except LookupError: + #: We cannot accept unsupported or nonexistent encoding. Override. + self.encoding = None + + # Fallback to auto-detected encoding. + if self.encoding is None: + encoding_guess = from_bytes(content).best() + + if encoding_guess: + #: We shall cache this inference. + self.encoding = encoding_guess.encoding + return str(encoding_guess) + + if self.encoding is None: + return None + + return str(content, self.encoding, errors="replace") + + async def json(self, **kwargs: typing.Any) -> typing.Any: # type: ignore[override] + content = await self.content + + if not content or "json" not in self.headers.get("content-type", "").lower(): + raise RequestsJSONDecodeError( + "response content is not JSON", await self.text or "", 0 + ) + + if not self.encoding: + # No encoding set. JSON RFC 4627 section 3 states we should expect + # UTF-8, -16 or -32. Detect which one to use; If the detection or + # decoding fails, fall back to `self.text` (using charset_normalizer to make + # a best guess). + encoding_guess = from_bytes( + content, + cp_isolation=[ + "ascii", + "utf-8", + "utf-16", + "utf-32", + "utf-16-le", + "utf-16-be", + "utf-32-le", + "utf-32-be", + ], + ).best() + + if encoding_guess is not None: + try: + return _json.loads(str(encoding_guess), **kwargs) + except _json.JSONDecodeError as e: + raise RequestsJSONDecodeError(e.msg, e.doc, e.pos) + + plain_content = await self.text + + if plain_content is None: + raise RequestsJSONDecodeError( + "response cannot lead to decodable JSON", "", 0 + ) + + try: + return _json.loads(plain_content, **kwargs) + except _json.JSONDecodeError as e: + # Catch JSON-related errors and raise as requests.JSONDecodeError + # This aliases json.JSONDecodeError and simplejson.JSONDecodeError + raise RequestsJSONDecodeError(e.msg, e.doc, e.pos) + + async def close(self) -> None: # type: ignore[override] + await sync_to_async( + super().close, thread_sensitive=AsyncSession.disable_thread + )() diff --git a/src/niquests/adapters.py b/src/niquests/adapters.py index 62f4032ce6..5d3b72e5ac 100644 --- a/src/niquests/adapters.py +++ b/src/niquests/adapters.py @@ -150,6 +150,7 @@ def send( on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None, multiplexed: bool = False, + mutate_response_class: type | None = None, ) -> Response: """Sends PreparedRequest object. Returns Response object. 
@@ -653,6 +654,7 @@ def send( on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None, multiplexed: bool = False, + mutate_response_class: type | None = None, ) -> Response: """Sends PreparedRequest object. Returns Response object. @@ -790,7 +792,16 @@ def send( else: raise - return self.build_response(request, resp_or_promise) + r = self.build_response(request, resp_or_promise) + + if mutate_response_class: + if not issubclass(mutate_response_class, Response): + raise TypeError( + f"Unable to mutate Response to {mutate_response_class} as it does not inherit from Response." + ) + r.__class__ = mutate_response_class + + return r def _future_handler(self, response: Response, low_resp: BaseHTTPResponse) -> None: stream = typing.cast( diff --git a/src/niquests/models.py b/src/niquests/models.py index 4b7fe971c7..47d3652931 100644 --- a/src/niquests/models.py +++ b/src/niquests/models.py @@ -8,6 +8,7 @@ import codecs import datetime +import warnings # Import encoding now, to avoid implicit import later. # Implicit import within threads may cause LookupError when standard library is in a ZIP, @@ -984,6 +985,12 @@ def lazy(self) -> bool: def __getattribute__(self, item): if item in Response.__lazy_attrs__ and self.lazy: + if self.__class__ is not Response and "Async" in str(self.__class__): + warnings.warn( + "Accessing a lazy response in an asynchronous context is going to block the event loop. " + "Use await session.gather() instead before accessing the response.", + ResourceWarning, + ) self._gather() return super().__getattribute__(item) diff --git a/src/niquests/sessions.py b/src/niquests/sessions.py index cff6247d4e..10a701ef0e 100644 --- a/src/niquests/sessions.py +++ b/src/niquests/sessions.py @@ -14,6 +14,7 @@ import warnings from collections import OrderedDict from collections.abc import Mapping +from copy import deepcopy from datetime import timedelta from http import cookiejar as cookielib from http.cookiejar import CookieJar @@ -1126,6 +1127,9 @@ def handle_upload_progress( # Send the request r = adapter.send(request, **kwargs) + # Make sure the timings data are kept as is, conn_info is a reference to + # urllib3-future conn_info. 
+ request.conn_info = deepcopy(request.conn_info) # We are leveraging a multiplexed connection if r.raw is None: diff --git a/src/niquests/utils.py b/src/niquests/utils.py index d44a763c3e..9c6c4ba6a3 100644 --- a/src/niquests/utils.py +++ b/src/niquests/utils.py @@ -514,6 +514,27 @@ def stream_decode_response_unicode( yield rv +async def astream_decode_response_unicode( + iterator: typing.AsyncGenerator[bytes, None], r: Response +) -> typing.AsyncGenerator[bytes | str, None]: + """Stream decodes an iterator.""" + + if r.encoding is None: + async for chunk in iterator: + yield chunk + return + + decoder = codecs.getincrementaldecoder(r.encoding)(errors="replace") + + async for chunk in iterator: + rv = decoder.decode(chunk) + if rv: + yield rv + rv = decoder.decode(b"", final=True) + if rv: + yield rv + + _SV = typing.TypeVar("_SV", str, bytes) diff --git a/tests/test_async.py b/tests/test_async.py index 9641551a6e..37d31885ca 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,10 +1,11 @@ from __future__ import annotations import asyncio +import json import pytest -from niquests import AsyncSession +from niquests import AsyncSession, AsyncResponse, Response @pytest.mark.usefixtures("requires_wan") @@ -63,6 +64,7 @@ async def test_awaitable_get_direct_access_lazy(self): resp = await s.get("https://pie.dev/get") assert resp.lazy is True + assert isinstance(resp, Response) assert resp.status_code == 200 async def test_concurrent_task_get(self): @@ -87,3 +89,69 @@ async def emit(): assert len(responses_bar) == 2 assert all(r.status_code == 200 for r in responses_foo + responses_bar) + + async def test_with_stream_json(self): + async with AsyncSession() as s: + r = await s.get("https://pie.dev/get", stream=True) + assert isinstance(r, AsyncResponse) + assert r.ok + payload = await r.json() + assert payload + + async def test_with_stream_text(self): + async with AsyncSession() as s: + r = await s.get("https://pie.dev/get", stream=True) + assert isinstance(r, AsyncResponse) + assert r.ok + payload = await r.text + assert payload is not None + + async def test_with_stream_iter_decode(self): + async with AsyncSession() as s: + r = await s.get("https://pie.dev/get", stream=True) + assert isinstance(r, AsyncResponse) + assert r.ok + payload = "" + + async for chunk in await r.iter_content(16, decode_unicode=True): + payload += chunk + + assert json.loads(payload) + + async def test_with_stream_iter_raw(self): + async with AsyncSession() as s: + r = await s.get("https://pie.dev/get", stream=True) + assert isinstance(r, AsyncResponse) + assert r.ok + payload = b"" + + async for chunk in await r.iter_content(16): + payload += chunk + + assert json.loads(payload.decode()) + + async def test_concurrent_task_get_with_stream(self): + async def emit(): + responses = [] + + async with AsyncSession(multiplexed=True) as s: + responses.append(await s.get("https://pie.dev/get", stream=True)) + responses.append(await s.get("https://pie.dev/delay/5", stream=True)) + + await s.gather() + + for response in responses: + await response.content + + return responses + + foo = asyncio.create_task(emit()) + bar = asyncio.create_task(emit()) + + responses_foo = await foo + responses_bar = await bar + + assert len(responses_foo) == 2 + assert len(responses_bar) == 2 + + assert all(r.status_code == 200 for r in responses_foo + responses_bar)
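As a closing aside, the incremental decoding that ``astream_decode_response_unicode`` (added in ``utils.py`` above) relies on can be exercised in isolation. A self-contained sketch with a multi-byte character deliberately split across chunks; the toy async generator stands in for a streamed response body.

```python
import asyncio
import codecs
import typing


async def chunks() -> typing.AsyncGenerator[bytes, None]:
    # 'é' (0xC3 0xA9) is deliberately split across two chunks.
    yield b"caf\xc3"
    yield b"\xa9 au lait"


async def decode(encoding: str) -> str:
    decoder = codecs.getincrementaldecoder(encoding)(errors="replace")
    out = ""
    async for chunk in chunks():
        piece = decoder.decode(chunk)
        if piece:
            out += piece
    # Flush any buffered trailing bytes, as the helper does.
    out += decoder.decode(b"", final=True)
    return out


assert asyncio.run(decode("utf-8")) == "café au lait"
```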