Skip to content

Commit

Permalink
🐛 fix using niquests in more than one loop (#191)
Browse files Browse the repository at this point in the history
resolve #190
  • Loading branch information
Ousret authored Dec 23, 2024
2 parents 6bf7791 + 5b0c9b9 commit 32d2716
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 70 deletions.
11 changes: 11 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
Release History
===============

3.11.4 (2024-12-23)
-------------------

**Fixed**
- Invoking ``niquests`` in more than one event loop, even if no loop concurrence occurs. (#190)
The faulty part was the shared OCSP cache that was automatically bound the first event loop and
could not be shared across more than one loop. Keep in mind that Niquests async is task safe within
a single event loop. Sharing a single AsyncSession across more than one event loop is unpredictable.
We've waived that limitation by binding the ocsp cache to a single `Session`. (both sync & async)
- Undesirable ``socket.timeout`` error coming from the ocsp checker when running Python < 3.9.

3.11.3 (2024-12-13)
-------------------

Expand Down
14 changes: 13 additions & 1 deletion src/niquests/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ def __init__(
else AsyncQuicSharedCache(max_size=12_288)
)

#: Don't try to manipulate this object.
#: It cannot be pickled and accessing this object may cause
#: unattended errors.
self._ocsp_cache: typing.Any | None = None

# Default connection adapters.
self.adapters: OrderedDict[str, AsyncBaseAdapter] = OrderedDict() # type: ignore[assignment]
self.mount(
Expand Down Expand Up @@ -387,17 +392,24 @@ async def on_post_connection(conn_info: ConnectionInfo) -> None:
)

try:
from .extensions._async_ocsp import verify as ocsp_verify
from .extensions._async_ocsp import (
verify as ocsp_verify,
InMemoryRevocationStatus,
)
except ImportError:
pass
else:
if self._ocsp_cache is None:
self._ocsp_cache = InMemoryRevocationStatus()

await ocsp_verify(
ptr_request,
strict_ocsp_enabled,
0.2 if not strict_ocsp_enabled else 1.0,
kwargs["proxies"],
resolver=self.resolver,
happy_eyeballs=self._happy_eyeballs,
cache=self._ocsp_cache,
)

# don't trigger pre_send for redirects
Expand Down
12 changes: 6 additions & 6 deletions src/niquests/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from urllib3 import __version__

HAS_LEGACY_URLLIB3: bool = int(__version__.split(".")[-1]) < 900
except (ValueError, ImportError):
except (ValueError, ImportError): # Defensive: tested in separate CI
# Means one of the two cases:
# 1) urllib3 does not exist -> fallback to urllib3_future
# 2) urllib3 exist but not fork -> fallback to urllib3_future
HAS_LEGACY_URLLIB3 = True

if HAS_LEGACY_URLLIB3:
if HAS_LEGACY_URLLIB3: # Defensive: tested in separate CI
import urllib3_future
else:
urllib3_future = None # type: ignore[assignment]
Expand All @@ -24,14 +24,14 @@
urllib3.Timeout # noqa
urllib3.Retry # noqa
urllib3.__version__ # noqa
except (ImportError, AttributeError):
except (ImportError, AttributeError): # Defensive: tested in separate CI
urllib3 = None # type: ignore[assignment]


if (urllib3 is None and urllib3_future is None) or (
HAS_LEGACY_URLLIB3 and urllib3_future is None
):
raise RuntimeError(
raise RuntimeError( # Defensive: tested in separate CI
"This is awkward but your environment is missing urllib3-future. "
"Your environment seems broken. "
"You may fix this issue by running `python -m pip install niquests -U` "
Expand All @@ -40,7 +40,7 @@

if urllib3 is not None:
T = typing.TypeVar("T", urllib3.Timeout, urllib3.Retry)
else:
else: # Defensive: tested in separate CI
T = typing.TypeVar("T", urllib3_future.Timeout, urllib3_future.Retry) # type: ignore


Expand All @@ -49,7 +49,7 @@ def urllib3_ensure_type(o: T) -> T:
if urllib3 is None:
return o

if HAS_LEGACY_URLLIB3:
if HAS_LEGACY_URLLIB3: # Defensive: tested in separate CI
if "urllib3_future" not in str(type(o)):
assert urllib3_future is not None

Expand Down
32 changes: 14 additions & 18 deletions src/niquests/extensions/_async_ocsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,14 @@ def save(
self._timings.pop(0)


_SharedRevocationStatusCache = InMemoryRevocationStatus()


async def verify(
r: PreparedRequest,
strict: bool = False,
timeout: float | int = 0.2,
proxies: ProxyType | None = None,
resolver: AsyncBaseResolver | None = None,
happy_eyeballs: bool | int = False,
cache: InMemoryRevocationStatus | None = None,
) -> None:
conn_info: ConnectionInfo | None = r.conn_info

Expand All @@ -328,26 +326,27 @@ async def verify(
if not endpoints:
return

if cache is None:
cache = InMemoryRevocationStatus()

peer_certificate = _parse_x509_der_cached(conn_info.certificate_der)

async with _SharedRevocationStatusCache.lock(peer_certificate):
async with cache.lock(peer_certificate):
# this feature, by default, is reserved for a reasonable usage.
if not strict:
mean_rate_sec = _SharedRevocationStatusCache.rate()
cache_count = len(_SharedRevocationStatusCache)
mean_rate_sec = cache.rate()
cache_count = len(cache)

if cache_count >= 10 and mean_rate_sec <= 1.0:
_SharedRevocationStatusCache.hold = True
cache.hold = True

if _SharedRevocationStatusCache.hold:
if cache.hold:
return

cached_response = _SharedRevocationStatusCache.check(peer_certificate)
cached_response = cache.check(peer_certificate)

if cached_response is not None:
issuer_certificate = _SharedRevocationStatusCache.get_issuer_of(
peer_certificate
)
issuer_certificate = cache.get_issuer_of(peer_certificate)

if issuer_certificate:
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
Expand Down Expand Up @@ -394,9 +393,7 @@ async def verify(
# - Downloading it using specified caIssuers from the peer certificate.
if conn_info.issuer_certificate_der is None:
# It could be a root (self-signed) certificate. Or a previously seen issuer.
issuer_certificate = _SharedRevocationStatusCache.get_issuer_of(
peer_certificate
)
issuer_certificate = cache.get_issuer_of(peer_certificate)

# If not, try to ask nicely the remote to give us the certificate chain, and extract
# from it the immediate issuer.
Expand Down Expand Up @@ -427,6 +424,7 @@ async def verify(

except (
socket.gaierror,
socket.timeout,
TimeoutError,
ConnectionError,
AttributeError,
Expand Down Expand Up @@ -548,9 +546,7 @@ async def verify(
)
return

_SharedRevocationStatusCache.save(
peer_certificate, issuer_certificate, ocsp_resp
)
cache.save(peer_certificate, issuer_certificate, ocsp_resp)

if ocsp_resp.response_status == OCSPResponseStatus.SUCCESSFUL:
if ocsp_resp.certificate_status == OCSPCertStatus.REVOKED:
Expand Down
46 changes: 25 additions & 21 deletions src/niquests/extensions/_ocsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ def _parse_x509_der_cached(der: bytes) -> Certificate:
return Certificate(der)


@lru_cache(maxsize=64)
def _fingerprint_raw_data(payload: bytes) -> str:
return "".join([format(i, "02x") for i in sha256(payload).digest()])


def _str_fingerprint_of(certificate: Certificate) -> str:
return ":".join(
[format(i, "02x") for i in sha256(certificate.public_bytes()).digest()]
)
return _fingerprint_raw_data(certificate.public_bytes())


def readable_revocation_reason(flag: ReasonFlags | None) -> str | None:
Expand Down Expand Up @@ -293,16 +296,14 @@ def save(
self._timings.pop(0)


_SharedRevocationStatusCache = InMemoryRevocationStatus()


def verify(
r: PreparedRequest,
strict: bool = False,
timeout: float | int = 0.2,
proxies: ProxyType | None = None,
resolver: BaseResolver | None = None,
happy_eyeballs: bool | int = False,
cache: InMemoryRevocationStatus | None = None,
) -> None:
conn_info: ConnectionInfo | None = r.conn_info

Expand All @@ -325,24 +326,25 @@ def verify(
if not endpoints:
return

if cache is None:
cache = InMemoryRevocationStatus()

# this feature, by default, is reserved for a reasonable usage.
if not strict:
mean_rate_sec = _SharedRevocationStatusCache.rate()
cache_count = len(_SharedRevocationStatusCache)
mean_rate_sec = cache.rate()
cache_count = len(cache)

if cache_count >= 10 and mean_rate_sec <= 1.0:
_SharedRevocationStatusCache.hold = True
cache.hold = True

if _SharedRevocationStatusCache.hold:
if cache.hold:
return

peer_certificate = _parse_x509_der_cached(conn_info.certificate_der)
cached_response = _SharedRevocationStatusCache.check(peer_certificate)
cached_response = cache.check(peer_certificate)

if cached_response is not None:
issuer_certificate = _SharedRevocationStatusCache.get_issuer_of(
peer_certificate
)
issuer_certificate = cache.get_issuer_of(peer_certificate)

if issuer_certificate:
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
Expand Down Expand Up @@ -387,9 +389,7 @@ def verify(
# - Downloading it using specified caIssuers from the peer certificate.
if conn_info.issuer_certificate_der is None:
# It could be a root (self-signed) certificate. Or a previously seen issuer.
issuer_certificate = _SharedRevocationStatusCache.get_issuer_of(
peer_certificate
)
issuer_certificate = cache.get_issuer_of(peer_certificate)

# If not, try to ask nicely the remote to give us the certificate chain, and extract
# from it the immediate issuer.
Expand Down Expand Up @@ -418,7 +418,13 @@ def verify(
else:
issuer_certificate = None

except (socket.gaierror, TimeoutError, ConnectionError, AttributeError):
except (
socket.gaierror,
socket.timeout,
TimeoutError,
ConnectionError,
AttributeError,
):
pass
except ValueError:
issuer_certificate = None
Expand Down Expand Up @@ -534,9 +540,7 @@ def verify(
)
return

_SharedRevocationStatusCache.save(
peer_certificate, issuer_certificate, ocsp_resp
)
cache.save(peer_certificate, issuer_certificate, ocsp_resp)

if ocsp_resp.response_status == OCSPResponseStatus.SUCCESSFUL:
if ocsp_resp.certificate_status == OCSPCertStatus.REVOKED:
Expand Down
13 changes: 12 additions & 1 deletion src/niquests/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,11 @@ def __init__(
else QuicSharedCache(max_size=12_288)
)

#: Don't try to manipulate this object.
#: It cannot be pickled and accessing this object may cause
#: unattended errors.
self._ocsp_cache: typing.Any | None = None

# Default connection adapters.
self.adapters: OrderedDict[str, BaseAdapter] = OrderedDict()
self.mount(
Expand Down Expand Up @@ -1136,17 +1141,23 @@ def on_post_connection(conn_info: ConnectionInfo) -> None:
)

try:
from .extensions._ocsp import verify as ocsp_verify
from .extensions._ocsp import (
verify as ocsp_verify,
InMemoryRevocationStatus,
)
except ImportError:
pass
else:
if self._ocsp_cache is None:
self._ocsp_cache = InMemoryRevocationStatus()
ocsp_verify(
ptr_request,
strict_ocsp_enabled,
0.2 if not strict_ocsp_enabled else 1.0,
kwargs["proxies"],
resolver=self.resolver,
happy_eyeballs=self._happy_eyeballs,
cache=self._ocsp_cache,
)

# don't trigger pre_send for redirects
Expand Down
34 changes: 34 additions & 0 deletions tests/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ def hook(value):
return value[1:]


async def ahook(value):
return value[1:]


@pytest.mark.parametrize(
"hooks_list, result",
(
Expand Down Expand Up @@ -36,6 +40,36 @@ def test_hooks_with_kwargs(hooks_list, result):
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"hooks_list, result",
(
(ahook, "ata"),
([ahook, lambda x: None, hook], "ta"),
),
)
async def test_ahooks(hooks_list, result):
assert (
await hooks.async_dispatch_hook("response", {"response": hooks_list}, "Data")
) == result


@pytest.mark.asyncio
@pytest.mark.parametrize(
"hooks_list, result",
(
(hook, "ata"),
([hook, lambda x: None, ahook], "ta"),
),
)
async def test_ahooks_with_kwargs(hooks_list, result):
assert (
await hooks.async_dispatch_hook(
"response", {"response": hooks_list}, "Data", should_not_crash=True
)
) == result


def test_default_hooks():
assert hooks.default_hooks() == {
"pre_request": [],
Expand Down
Loading

0 comments on commit 32d2716

Please sign in to comment.