Skip to content

Commit 09df55e

Browse files
fix ssl part
1 parent ec0ed9b commit 09df55e

File tree

3 files changed

+68
-28
lines changed

3 files changed

+68
-28
lines changed

Lib/asyncio/base_events.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,8 @@ def _make_ssl_transport(
511511
extra=None, server=None,
512512
ssl_handshake_timeout=None,
513513
ssl_shutdown_timeout=None,
514-
call_connection_made=True):
514+
call_connection_made=True,
515+
context=None):
515516
"""Create SSL transport."""
516517
raise NotImplementedError
517518

@@ -1228,7 +1229,8 @@ async def _create_connection_transport(
12281229
sock, protocol, sslcontext, waiter,
12291230
server_side=server_side, server_hostname=server_hostname,
12301231
ssl_handshake_timeout=ssl_handshake_timeout,
1231-
ssl_shutdown_timeout=ssl_shutdown_timeout)
1232+
ssl_shutdown_timeout=ssl_shutdown_timeout,
1233+
context=context)
12321234
else:
12331235
transport = self._make_socket_transport(sock, protocol, waiter, context=context)
12341236

Lib/asyncio/selector_events.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,17 @@ def _make_ssl_transport(
7878
extra=None, server=None,
7979
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
8080
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
81+
context=None,
8182
):
8283
self._ensure_fd_no_transport(rawsock)
8384
ssl_protocol = sslproto.SSLProtocol(
8485
self, protocol, sslcontext, waiter,
8586
server_side, server_hostname,
8687
ssl_handshake_timeout=ssl_handshake_timeout,
87-
ssl_shutdown_timeout=ssl_shutdown_timeout
88+
ssl_shutdown_timeout=ssl_shutdown_timeout,
8889
)
8990
_SelectorSocketTransport(self, rawsock, ssl_protocol,
90-
extra=extra, server=server)
91+
extra=extra, server=server, context=context)
9192
return ssl_protocol._app_transport
9293

9394
def _make_datagram_transport(self, sock, protocol,
@@ -230,7 +231,8 @@ async def _accept_connection2(
230231
conn, protocol, sslcontext, waiter=waiter,
231232
server_side=True, extra=extra, server=server,
232233
ssl_handshake_timeout=ssl_handshake_timeout,
233-
ssl_shutdown_timeout=ssl_shutdown_timeout)
234+
ssl_shutdown_timeout=ssl_shutdown_timeout,
235+
context=context)
234236
else:
235237
transport = self._make_socket_transport(
236238
conn, protocol, waiter=waiter, extra=extra,

Lib/test/test_asyncio/test_server_context.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,20 @@
55

66
from unittest import TestCase
77

8+
try:
9+
import ssl
10+
except ImportError:
11+
ssl = None
12+
13+
from test.test_asyncio import utils as test_utils
14+
815
def tearDownModule():
916
asyncio.events._set_event_loop_policy(None)
1017

1118
class ServerContextvarsTestCase:
1219
loop_factory = None # To be defined in subclasses
20+
server_ssl_context = None # To be defined in subclasses for SSL tests
21+
client_ssl_context = None # To be defined in subclasses for SSL tests
1322

1423
def run_coro(self, coro):
1524
return asyncio.run(coro, loop_factory=self.loop_factory)
@@ -25,12 +34,14 @@ async def handle_client(reader, writer):
2534
await writer.drain()
2635
writer.close()
2736

28-
server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
37+
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
38+
ssl=self.server_ssl_context)
2939
# change the value
3040
var.set("after_server")
3141

3242
async def client(addr):
33-
reader, writer = await asyncio.open_connection(*addr)
43+
reader, writer = await asyncio.open_connection(*addr,
44+
ssl=self.client_ssl_context)
3445
data = await reader.read(100)
3546
writer.close()
3647
await writer.wait_closed()
@@ -56,11 +67,13 @@ async def handle_client(reader, writer):
5667
await writer.drain()
5768
writer.close()
5869

59-
server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
70+
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
71+
ssl=self.server_ssl_context)
6072
var.set("after_server")
6173

6274
async def client(addr):
63-
reader, writer = await asyncio.open_connection(*addr)
75+
reader, writer = await asyncio.open_connection(*addr,
76+
ssl=self.client_ssl_context)
6477
data = await reader.read(100)
6578
writer.close()
6679
await writer.wait_closed()
@@ -87,11 +100,13 @@ async def handle_client(reader, writer):
87100
await writer.drain()
88101
writer.close()
89102

90-
server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
103+
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
104+
ssl=self.server_ssl_context)
91105
var.set("after_server")
92106

93107
async def client(addr):
94-
reader, writer = await asyncio.open_connection(*addr)
108+
reader, writer = await asyncio.open_connection(*addr,
109+
ssl=self.client_ssl_context)
95110
data = await reader.read(100)
96111
self.assertEqual(data.decode(), "before_server")
97112
writer.close()
@@ -122,11 +137,13 @@ def connection_made(self, transport):
122137
self.transport.close()
123138

124139
server = await asyncio.get_running_loop().create_server(
125-
lambda: EchoProtocol(), '127.0.0.1', 0)
140+
lambda: EchoProtocol(), '127.0.0.1', 0,
141+
ssl=self.server_ssl_context)
126142
var.set("after_server")
127143

128144
async def client(addr):
129-
reader, writer = await asyncio.open_connection(*addr)
145+
reader, writer = await asyncio.open_connection(*addr,
146+
ssl=self.client_ssl_context)
130147
data = await reader.read(100)
131148
self.assertEqual(data.decode(), "default")
132149
writer.close()
@@ -157,12 +174,14 @@ def connection_made(self, transport):
157174
self.transport.close()
158175

159176
server = await asyncio.get_running_loop().create_server(
160-
lambda: EchoProtocol(), '127.0.0.1', 0)
177+
lambda: EchoProtocol(), '127.0.0.1', 0,
178+
ssl=self.server_ssl_context)
161179

162180
var.set("after_server")
163181

164182
async def client(addr, expected):
165-
reader, writer = await asyncio.open_connection(*addr)
183+
reader, writer = await asyncio.open_connection(*addr,
184+
ssl=self.client_ssl_context)
166185
data = await reader.read(100)
167186
self.assertEqual(data.decode(), expected)
168187
writer.close()
@@ -184,6 +203,7 @@ def test_gh140947(self):
184203
cvar2 = contextvars.ContextVar("cvar2")
185204
cvar3 = contextvars.ContextVar("cvar3")
186205
results = {}
206+
is_ssl = self.server_ssl_context is not None
187207

188208
def capture_context(meth):
189209
result = []
@@ -218,36 +238,37 @@ def connection_lost(self, exc):
218238

219239
async def asgi(self):
220240
capture_context("asgi start")
221-
222241
cvar1.set(True)
223-
224242
# make sure that we only resume after the pause
225243
# otherwise the resume does nothing
226-
while not self.transport._paused:
227-
await asyncio.sleep(0.1)
228-
244+
if is_ssl:
245+
while not self.transport._ssl_protocol._app_reading_paused:
246+
await asyncio.sleep(0.01)
247+
else:
248+
while not self.transport._paused:
249+
await asyncio.sleep(0.01)
229250
cvar2.set(True)
230-
231251
self.transport.resume_reading()
232-
233252
cvar3.set(True)
234-
235253
capture_context("asgi end")
236254

237-
238255
async def main():
239256
loop = asyncio.get_running_loop()
240257
on_conn_lost = loop.create_future()
241258

242-
host, port = "127.0.0.1", 8888
243-
244-
async with await loop.create_server(lambda: DemoProtocol(on_conn_lost), host, port):
245-
reader, writer = await asyncio.open_connection(host, port)
259+
server = await loop.create_server(
260+
lambda: DemoProtocol(on_conn_lost), '127.0.0.1', 0,
261+
ssl=self.server_ssl_context)
262+
async with server:
263+
addr = server.sockets[0].getsockname()
264+
reader, writer = await asyncio.open_connection(*addr,
265+
ssl=self.client_ssl_context)
246266
writer.write(b"anything")
247267
await writer.drain()
248268
writer.close()
249269
await writer.wait_closed()
250270
await on_conn_lost
271+
251272
self.run_coro(main())
252273
self.assertDictEqual(results, {
253274
"connection_made": [],
@@ -261,12 +282,27 @@ async def main():
261282
class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
262283
loop_factory = staticmethod(asyncio.new_event_loop)
263284

285+
@unittest.skipUnless(ssl, "SSL not available")
286+
class AsyncioEventLoopSSLTests(AsyncioEventLoopTests):
287+
server_ssl_context = test_utils.simple_server_sslcontext()
288+
client_ssl_context = test_utils.simple_client_sslcontext()
289+
264290
if sys.platform == "win32":
265291
class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase):
266292
loop_factory = asyncio.ProactorEventLoop
267293

268294
class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase):
269295
loop_factory = asyncio.SelectorEventLoop
270296

297+
@unittest.skipUnless(ssl, "SSL not available")
298+
class AsyncioProactorEventLoopSSLTests(AsyncioProactorEventLoopTests):
299+
server_ssl_context = test_utils.simple_server_sslcontext()
300+
client_ssl_context = test_utils.simple_client_sslcontext()
301+
302+
@unittest.skipUnless(ssl, "SSL not available")
303+
class AsyncioSelectorEventLoopSSLTests(AsyncioSelectorEventLoopTests):
304+
server_ssl_context = test_utils.simple_server_sslcontext()
305+
client_ssl_context = test_utils.simple_client_sslcontext()
306+
271307
if __name__ == "__main__":
272308
unittest.main()

0 commit comments

Comments
 (0)