11from types import TracebackType
2+ from contextvars import ContextVar
23from contextlib import contextmanager , asynccontextmanager
34from typing import (
45 Any ,
@@ -129,14 +130,18 @@ def __init__(
129130 base_url , accept_format , previews , user_agent , follow_redirects , timeout
130131 )
131132
132- self .__sync_client : Optional [httpx .Client ] = None
133- self .__async_client : Optional [httpx .AsyncClient ] = None
133+ self .__sync_client : ContextVar [Optional [httpx .Client ]] = ContextVar (
134+ "sync_client" , default = None
135+ )
136+ self .__async_client : ContextVar [Optional [httpx .AsyncClient ]] = ContextVar (
137+ "async_client" , default = None
138+ )
134139
135140 # sync context
136141 def __enter__ (self ):
137- if self .__sync_client is not None :
142+ if self .__sync_client . get () is not None :
138143 raise RuntimeError ("Cannot enter sync context twice" )
139- self .__sync_client = self ._create_sync_client ()
144+ self .__sync_client . set ( self ._create_sync_client () )
140145 return self
141146
142147 def __exit__ (
@@ -145,14 +150,14 @@ def __exit__(
145150 exc_value : Optional [BaseException ] = None ,
146151 traceback : Optional [TracebackType ] = None ,
147152 ):
148- cast (httpx .Client , self .__sync_client ).close ()
149- self .__sync_client = None
153+ cast (httpx .Client , self .__sync_client . get () ).close ()
154+ self .__sync_client . set ( None )
150155
151156 # async context
152157 async def __aenter__ (self ):
153- if self .__async_client is not None :
158+ if self .__async_client . get () is not None :
154159 raise RuntimeError ("Cannot enter async context twice" )
155- self .__async_client = self ._create_async_client ()
160+ self .__async_client . set ( self ._create_async_client () )
156161 return self
157162
158163 async def __aexit__ (
@@ -161,8 +166,8 @@ async def __aexit__(
161166 exc_value : Optional [BaseException ] = None ,
162167 traceback : Optional [TracebackType ] = None ,
163168 ):
164- await cast (httpx .AsyncClient , self .__async_client ).aclose ()
165- self .__async_client = None
169+ await cast (httpx .AsyncClient , self .__async_client . get () ).aclose ()
170+ self .__async_client . set ( None )
166171
167172 # default args for creating client
168173 def _get_client_defaults (self ):
@@ -184,8 +189,8 @@ def _create_sync_client(self) -> httpx.Client:
184189 # get or create sync client
185190 @contextmanager
186191 def get_sync_client (self ) -> Generator [httpx .Client , None , None ]:
187- if self .__sync_client :
188- yield self . __sync_client
192+ if client := self .__sync_client . get () :
193+ yield client
189194 else :
190195 client = self ._create_sync_client ()
191196 try :
@@ -200,8 +205,8 @@ def _create_async_client(self) -> httpx.AsyncClient:
200205 # get or create async client
201206 @asynccontextmanager
202207 async def get_async_client (self ) -> AsyncGenerator [httpx .AsyncClient , None ]:
203- if self .__async_client :
204- yield self . __async_client
208+ if client := self .__async_client . get () :
209+ yield client
205210 else :
206211 client = self ._create_async_client ()
207212 try :
0 commit comments