55
66from unittest import TestCase
77
8+ try :
9+ import ssl
10+ except ImportError :
11+ ssl = None
12+
13+ from test .test_asyncio import utils as test_utils
14+
815def tearDownModule ():
916 asyncio .events ._set_event_loop_policy (None )
1017
1118class ServerContextvarsTestCase :
1219 loop_factory = None # To be defined in subclasses
20+ server_ssl_context = None # To be defined in subclasses for SSL tests
21+ client_ssl_context = None # To be defined in subclasses for SSL tests
1322
1423 def run_coro (self , coro ):
1524 return asyncio .run (coro , loop_factory = self .loop_factory )
@@ -25,12 +34,14 @@ async def handle_client(reader, writer):
2534 await writer .drain ()
2635 writer .close ()
2736
28- server = await asyncio .start_server (handle_client , '127.0.0.1' , 0 )
37+ server = await asyncio .start_server (handle_client , '127.0.0.1' , 0 ,
38+ ssl = self .server_ssl_context )
2939 # change the value
3040 var .set ("after_server" )
3141
3242 async def client (addr ):
33- reader , writer = await asyncio .open_connection (* addr )
43+ reader , writer = await asyncio .open_connection (* addr ,
44+ ssl = self .client_ssl_context )
3445 data = await reader .read (100 )
3546 writer .close ()
3647 await writer .wait_closed ()
@@ -56,11 +67,13 @@ async def handle_client(reader, writer):
5667 await writer .drain ()
5768 writer .close ()
5869
59- server = await asyncio .start_server (handle_client , '127.0.0.1' , 0 )
70+ server = await asyncio .start_server (handle_client , '127.0.0.1' , 0 ,
71+ ssl = self .server_ssl_context )
6072 var .set ("after_server" )
6173
6274 async def client (addr ):
63- reader , writer = await asyncio .open_connection (* addr )
75+ reader , writer = await asyncio .open_connection (* addr ,
76+ ssl = self .client_ssl_context )
6477 data = await reader .read (100 )
6578 writer .close ()
6679 await writer .wait_closed ()
@@ -87,11 +100,13 @@ async def handle_client(reader, writer):
87100 await writer .drain ()
88101 writer .close ()
89102
90- server = await asyncio .start_server (handle_client , '127.0.0.1' , 0 )
103+ server = await asyncio .start_server (handle_client , '127.0.0.1' , 0 ,
104+ ssl = self .server_ssl_context )
91105 var .set ("after_server" )
92106
93107 async def client (addr ):
94- reader , writer = await asyncio .open_connection (* addr )
108+ reader , writer = await asyncio .open_connection (* addr ,
109+ ssl = self .client_ssl_context )
95110 data = await reader .read (100 )
96111 self .assertEqual (data .decode (), "before_server" )
97112 writer .close ()
@@ -122,11 +137,13 @@ def connection_made(self, transport):
122137 self .transport .close ()
123138
124139 server = await asyncio .get_running_loop ().create_server (
125- lambda : EchoProtocol (), '127.0.0.1' , 0 )
140+ lambda : EchoProtocol (), '127.0.0.1' , 0 ,
141+ ssl = self .server_ssl_context )
126142 var .set ("after_server" )
127143
128144 async def client (addr ):
129- reader , writer = await asyncio .open_connection (* addr )
145+ reader , writer = await asyncio .open_connection (* addr ,
146+ ssl = self .client_ssl_context )
130147 data = await reader .read (100 )
131148 self .assertEqual (data .decode (), "default" )
132149 writer .close ()
@@ -157,12 +174,14 @@ def connection_made(self, transport):
157174 self .transport .close ()
158175
159176 server = await asyncio .get_running_loop ().create_server (
160- lambda : EchoProtocol (), '127.0.0.1' , 0 )
177+ lambda : EchoProtocol (), '127.0.0.1' , 0 ,
178+ ssl = self .server_ssl_context )
161179
162180 var .set ("after_server" )
163181
164182 async def client (addr , expected ):
165- reader , writer = await asyncio .open_connection (* addr )
183+ reader , writer = await asyncio .open_connection (* addr ,
184+ ssl = self .client_ssl_context )
166185 data = await reader .read (100 )
167186 self .assertEqual (data .decode (), expected )
168187 writer .close ()
@@ -184,6 +203,7 @@ def test_gh140947(self):
184203 cvar2 = contextvars .ContextVar ("cvar2" )
185204 cvar3 = contextvars .ContextVar ("cvar3" )
186205 results = {}
206+ is_ssl = self .server_ssl_context is not None
187207
188208 def capture_context (meth ):
189209 result = []
@@ -218,36 +238,37 @@ def connection_lost(self, exc):
218238
219239 async def asgi (self ):
220240 capture_context ("asgi start" )
221-
222241 cvar1 .set (True )
223-
224242 # make sure that we only resume after the pause
225243 # otherwise the resume does nothing
226- while not self .transport ._paused :
227- await asyncio .sleep (0.1 )
228-
244+ if is_ssl :
245+ while not self .transport ._ssl_protocol ._app_reading_paused :
246+ await asyncio .sleep (0.01 )
247+ else :
248+ while not self .transport ._paused :
249+ await asyncio .sleep (0.01 )
229250 cvar2 .set (True )
230-
231251 self .transport .resume_reading ()
232-
233252 cvar3 .set (True )
234-
235253 capture_context ("asgi end" )
236254
237-
238255 async def main ():
239256 loop = asyncio .get_running_loop ()
240257 on_conn_lost = loop .create_future ()
241258
242- host , port = "127.0.0.1" , 8888
243-
244- async with await loop .create_server (lambda : DemoProtocol (on_conn_lost ), host , port ):
245- reader , writer = await asyncio .open_connection (host , port )
259+ server = await loop .create_server (
260+ lambda : DemoProtocol (on_conn_lost ), '127.0.0.1' , 0 ,
261+ ssl = self .server_ssl_context )
262+ async with server :
263+ addr = server .sockets [0 ].getsockname ()
264+ reader , writer = await asyncio .open_connection (* addr ,
265+ ssl = self .client_ssl_context )
246266 writer .write (b"anything" )
247267 await writer .drain ()
248268 writer .close ()
249269 await writer .wait_closed ()
250270 await on_conn_lost
271+
251272 self .run_coro (main ())
252273 self .assertDictEqual (results , {
253274 "connection_made" : [],
@@ -261,12 +282,27 @@ async def main():
261282class AsyncioEventLoopTests (TestCase , ServerContextvarsTestCase ):
262283 loop_factory = staticmethod (asyncio .new_event_loop )
263284
285+ @unittest .skipUnless (ssl , "SSL not available" )
286+ class AsyncioEventLoopSSLTests (AsyncioEventLoopTests ):
287+ server_ssl_context = test_utils .simple_server_sslcontext ()
288+ client_ssl_context = test_utils .simple_client_sslcontext ()
289+
264290if sys .platform == "win32" :
265291 class AsyncioProactorEventLoopTests (TestCase , ServerContextvarsTestCase ):
266292 loop_factory = asyncio .ProactorEventLoop
267293
268294 class AsyncioSelectorEventLoopTests (TestCase , ServerContextvarsTestCase ):
269295 loop_factory = asyncio .SelectorEventLoop
270296
297+ @unittest .skipUnless (ssl , "SSL not available" )
298+ class AsyncioProactorEventLoopSSLTests (AsyncioProactorEventLoopTests ):
299+ server_ssl_context = test_utils .simple_server_sslcontext ()
300+ client_ssl_context = test_utils .simple_client_sslcontext ()
301+
302+ @unittest .skipUnless (ssl , "SSL not available" )
303+ class AsyncioSelectorEventLoopSSLTests (AsyncioSelectorEventLoopTests ):
304+ server_ssl_context = test_utils .simple_server_sslcontext ()
305+ client_ssl_context = test_utils .simple_client_sslcontext ()
306+
271307if __name__ == "__main__" :
272308 unittest .main ()
0 commit comments