Skip to content

Commit

Permalink
🎨 Improve overall static typing experience (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ousret authored Oct 6, 2023
1 parent d21fbee commit b064e47
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 71 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ repos:
hooks:
- id: mypy
args: [--check-untyped-defs]
exclude: 'tests/|setup.py'
exclude: 'tests/'
additional_dependencies: ['charset_normalizer', 'urllib3.future>=2.0.934', 'wassima>=1.0.1', 'idna', 'kiss_headers']
6 changes: 6 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Release History
===============

3.0.3 (2023-10-??)
------------------

**Misc**
- Static typing has been improved to provide a better development experience.

3.0.2 (2023-10-01)
------------------

Expand Down
4 changes: 2 additions & 2 deletions src/niquests/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@
[
_HV,
],
_HV,
typing.Optional[_HV],
]

HookType: typing.TypeAlias = typing.Dict[str, typing.List[HookCallableType]]
HookType: typing.TypeAlias = typing.Dict[str, typing.List[HookCallableType[_HV]]]

CacheLayerAltSvcType: typing.TypeAlias = typing.MutableMapping[
typing.Tuple[str, int], typing.Optional[typing.Tuple[str, int]]
Expand Down
2 changes: 1 addition & 1 deletion src/niquests/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def __init__(
quic_cache_layer=quic_cache_layer,
)

def __getstate__(self):
def __getstate__(self) -> dict[str, typing.Any | None]:
return {attr: getattr(self, attr, None) for attr in self.__attrs__}

def __setstate__(self, state):
Expand Down
18 changes: 9 additions & 9 deletions src/niquests/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
TLSClientCertType,
TLSVerifyType,
)
from .models import Response
from .models import PreparedRequest, Response

#: This is a non-thread safe in-memory cache for the AltSvc / h3
_SHARED_QUIC_CACHE: CacheLayerAltSvcType = {}
Expand All @@ -53,7 +53,7 @@ def request(
verify: TLSVerifyType = True,
stream: bool = False,
cert: TLSClientCertType | None = None,
hooks: HookType | None = None,
hooks: HookType[PreparedRequest | Response] | None = None,
retries: RetryType = DEFAULT_RETRIES,
) -> Response:
"""Constructs and sends a :class:`Request <Request>`.
Expand Down Expand Up @@ -136,7 +136,7 @@ def get(
verify: TLSVerifyType = True,
stream: bool = False,
cert: TLSClientCertType | None = None,
hooks: HookType | None = None,
hooks: HookType[PreparedRequest | Response] | None = None,
retries: RetryType = DEFAULT_RETRIES,
) -> Response:
r"""Sends a GET request.
Expand Down Expand Up @@ -197,7 +197,7 @@ def options(
verify: TLSVerifyType = True,
stream: bool = False,
cert: TLSClientCertType | None = None,
hooks: HookType | None = None,
hooks: HookType[PreparedRequest | Response] | None = None,
retries: RetryType = DEFAULT_RETRIES,
) -> Response:
r"""Sends an OPTIONS request.
Expand Down Expand Up @@ -257,7 +257,7 @@ def head(
verify: TLSVerifyType = True,
stream: bool = False,
cert: TLSClientCertType | None = None,
hooks: HookType | None = None,
hooks: HookType[PreparedRequest | Response] | None = None,
retries: RetryType = DEFAULT_RETRIES,
) -> Response:
r"""Sends a HEAD request.
Expand Down Expand Up @@ -320,7 +320,7 @@ def post(
verify: TLSVerifyType = True,
stream: bool = False,
cert: TLSClientCertType | None = None,
hooks: HookType | None = None,
hooks: HookType[PreparedRequest | Response] | None = None,
retries: RetryType = DEFAULT_RETRIES,
) -> Response:
r"""Sends a POST request.
Expand Down Expand Up @@ -394,7 +394,7 @@ def put(
verify: TLSVerifyType = True,
stream: bool = False,
cert: TLSClientCertType | None = None,
hooks: HookType | None = None,
hooks: HookType[PreparedRequest | Response] | None = None,
retries: RetryType = DEFAULT_RETRIES,
) -> Response:
r"""Sends a PUT request.
Expand Down Expand Up @@ -468,7 +468,7 @@ def patch(
verify: TLSVerifyType = True,
stream: bool = False,
cert: TLSClientCertType | None = None,
hooks: HookType | None = None,
hooks: HookType[PreparedRequest | Response] | None = None,
retries: RetryType = DEFAULT_RETRIES,
) -> Response:
r"""Sends a PATCH request.
Expand Down Expand Up @@ -539,7 +539,7 @@ def delete(
verify: TLSVerifyType = True,
stream: bool = False,
cert: TLSClientCertType | None = None,
hooks: HookType | None = None,
hooks: HookType[PreparedRequest | Response] | None = None,
retries: RetryType = DEFAULT_RETRIES,
) -> Response:
r"""Sends a DELETE request.
Expand Down
17 changes: 14 additions & 3 deletions src/niquests/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,33 @@
"""
from __future__ import annotations

import typing
from json import JSONDecodeError as CompatJSONDecodeError

from urllib3.exceptions import HTTPError as BaseHTTPError

if typing.TYPE_CHECKING:
from .models import PreparedRequest, Response


class RequestException(IOError):
"""There was an ambiguous exception that occurred while handling your
request.
"""

def __init__(self, *args, **kwargs):
response: Response | None
request: PreparedRequest | None

def __init__(self, *args, **kwargs) -> None:
"""Initialize RequestException with `request` and `response` objects."""
response = kwargs.pop("response", None)
self.response = response
self.request = kwargs.pop("request", None)
if response is not None and not self.request and hasattr(response, "request"):
if (
self.response is not None
and not self.request
and hasattr(self.response, "request")
):
self.request = self.response.request
super().__init__(*args, **kwargs)

Expand All @@ -33,7 +44,7 @@ class InvalidJSONError(RequestException):
class JSONDecodeError(InvalidJSONError, CompatJSONDecodeError):
"""Couldn't decode the text into json"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
"""
Construct the JSONDecodeError instance first with all
args. Then use it's args to construct the IOError so that
Expand Down
8 changes: 5 additions & 3 deletions src/niquests/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@
]


def default_hooks() -> HookType:
def default_hooks() -> HookType[_HV]:
return {event: [] for event in HOOKS}


def dispatch_hook(
key: str, hooks: HookType | None, hook_data: _HV, **kwargs: typing.Any
key: str, hooks: HookType[_HV] | None, hook_data: _HV, **kwargs: typing.Any
) -> _HV:
"""Dispatches a hook dictionary on a given piece of data."""
if hooks is None:
return hook_data

callables: list[HookCallableType] | HookCallableType | None = hooks.get(key)
callables: list[HookCallableType[_HV]] | HookCallableType[_HV] | None = hooks.get(
key
)

if callables:
if callable(callables):
Expand Down
44 changes: 30 additions & 14 deletions src/niquests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
CookiesType,
HeadersType,
HookCallableType,
HookType,
HttpAuthenticationType,
HttpMethodType,
MultiPartFilesAltType,
Expand Down Expand Up @@ -140,7 +141,7 @@ def __init__(
params: QueryParameterType | None = None,
auth: HttpAuthenticationType | None = None,
cookies: CookiesType | None = None,
hooks=None,
hooks: HookType | None = None,
json: typing.Any | None = None,
):
# Default empty dicts for dict params.
Expand All @@ -150,7 +151,7 @@ def __init__(
params = {} if params is None else params
hooks = {} if hooks is None else hooks

self.hooks = default_hooks()
self.hooks: HookType[Response | PreparedRequest] = default_hooks()
for k, v in list(hooks.items()):
self.register_hook(event=k, hook=v)

Expand All @@ -168,7 +169,10 @@ def __repr__(self) -> str:
return f"<Request [{self.method}]>"

def register_hook(
self, event: str, hook: HookCallableType | list[HookCallableType]
self,
event: str,
hook: HookCallableType[Response | PreparedRequest]
| list[HookCallableType[Response | PreparedRequest]],
) -> None:
"""Properly register a hook."""

Expand All @@ -180,7 +184,9 @@ def register_hook(
elif isinstance(hook, list):
self.hooks[event].extend(h for h in hook if callable(h))

def deregister_hook(self, event: str, hook: HookCallableType) -> bool:
def deregister_hook(
self, event: str, hook: HookCallableType[Response | PreparedRequest]
) -> bool:
"""Deregister a previously registered hook.
Returns True if the hook existed, False if not.
"""
Expand Down Expand Up @@ -243,7 +249,7 @@ def __init__(self) -> None:
#: request body to send to the server.
self.body: BodyType | None = None
#: dictionary of callback hooks, for internal usage.
self.hooks = default_hooks()
self.hooks: HookType[Response | PreparedRequest] = default_hooks()
#: integer denoting starting position of a readable file-like body.
self._body_position: int | object | None = None
#: valuable intel about the opened connection.
Expand All @@ -259,7 +265,7 @@ def prepare(
params: QueryParameterType | None = None,
auth: HttpAuthenticationType | None = None,
cookies: CookiesType | None = None,
hooks=None,
hooks: HookType[Response | PreparedRequest] | None = None,
json: typing.Any | None = None,
) -> None:
"""Prepares the entire request with the given parameters."""
Expand Down Expand Up @@ -876,7 +882,7 @@ def __repr__(self) -> str:

return f"<Response HTTP/{http_revision} [{self.status_code}]>"

def __bool__(self):
def __bool__(self) -> bool:
"""Returns True if :attr:`status_code` is less than 400.
This attribute checks if the status code of the response is between
Expand All @@ -886,9 +892,9 @@ def __bool__(self):
"""
return self.ok

def __iter__(self):
def __iter__(self) -> typing.Generator[bytes, None, None]:
"""Allows you to use a response as an iterator."""
return self.iter_content(128)
return self.iter_content(128) # type: ignore[return-value]

@property
def ok(self) -> bool:
Expand Down Expand Up @@ -931,7 +937,9 @@ def conn_info(self) -> ConnectionInfo | None:
return self.request.conn_info
return None

def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False):
def iter_content(
self, chunk_size: int = 1, decode_unicode: bool = False
) -> typing.Generator[bytes | str, None, None]:
"""Iterates over the response data. When stream=True is set on the
request, this avoids reading the content at once into memory for
large responses. The chunk size is the number of bytes it should
Expand All @@ -948,7 +956,7 @@ def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False):
available encoding based on the response.
"""

def generate():
def generate() -> typing.Generator[bytes, None, None]:
assert self.raw is not None
# Special case for urllib3.
if hasattr(self.raw, "stream"):
Expand Down Expand Up @@ -986,7 +994,7 @@ def generate():
chunks = reused_chunks if self._content_consumed else stream_chunks

if decode_unicode:
chunks = stream_decode_response_unicode(chunks, self)
return stream_decode_response_unicode(chunks, self)

return chunks

Expand All @@ -1002,6 +1010,14 @@ def iter_lines(
.. note:: This method is not reentrant safe.
"""
if (
delimiter is not None
and decode_unicode is False
and isinstance(delimiter, str)
):
raise ValueError(
"delimiter MUST match the desired output type. e.g. if decode_unicode is set to True, delimiter MUST be a str, otherwise we expect a bytes-like variable."
)

pending = None

Expand All @@ -1012,7 +1028,7 @@ def iter_lines(
chunk = pending + chunk

if delimiter:
lines = chunk.split(delimiter)
lines = chunk.split(delimiter) # type: ignore[arg-type]
else:
lines = chunk.splitlines()

Expand Down Expand Up @@ -1057,7 +1073,7 @@ def content(self) -> bytes | None:
if self.status_code == 0 or self.raw is None:
self._content = None
else:
self._content = b"".join(self.iter_content(CONTENT_CHUNK_SIZE)) or b""
self._content = b"".join(self.iter_content(CONTENT_CHUNK_SIZE)) or b"" # type: ignore[arg-type]

self._content_consumed = True
# don't need to release the connection; that's been handled by urllib3
Expand Down
Loading

0 comments on commit b064e47

Please sign in to comment.