Skip to content

Commit b344e60

Browse files
add tests
1 parent ac7e954 commit b344e60

File tree

2 files changed

+267
-2
lines changed

2 files changed

+267
-2
lines changed

Lib/asyncio/selector_events.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,11 @@ def _accept_connection(
209209
raise # The event loop will catch, log and ignore it.
210210
else:
211211
extra = {'peername': addr}
212+
conn_context = context.copy() if context is not None else None
212213
accept = self._accept_connection2(
213214
protocol_factory, conn, extra, sslcontext, server,
214-
ssl_handshake_timeout, ssl_shutdown_timeout, context=context)
215-
self.create_task(accept, context=context)
215+
ssl_handshake_timeout, ssl_shutdown_timeout, context=conn_context)
216+
self.create_task(accept, context=conn_context)
216217

217218
async def _accept_connection2(
218219
self, protocol_factory, conn, extra,
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
2+
import asyncio
3+
import contextvars
4+
import unittest
5+
6+
from unittest import TestCase
7+
8+
def tearDownModule():
9+
asyncio.events._set_event_loop_policy(None)
10+
11+
class ServerContextvarsTestCase:
12+
loop_factory = None # To be defined in subclasses
13+
14+
def run_coro(self, coro):
15+
return asyncio.run(coro, loop_factory=self.loop_factory)
16+
17+
def test_start_server1(self):
18+
# Test that asyncio.start_server captures the context at the time of server creation
19+
async def test():
20+
var = contextvars.ContextVar("var", default="default")
21+
22+
async def handle_client(reader, writer):
23+
value = var.get()
24+
writer.write(value.encode())
25+
await writer.drain()
26+
writer.close()
27+
28+
server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
29+
# change the value
30+
var.set("after_server")
31+
32+
async def client(addr):
33+
reader, writer = await asyncio.open_connection(*addr)
34+
data = await reader.read(100)
35+
writer.close()
36+
await writer.wait_closed()
37+
return data.decode()
38+
39+
async with server:
40+
addr = server.sockets[0].getsockname()
41+
self.assertEqual(await client(addr), "default")
42+
43+
self.assertEqual(var.get(), "after_server")
44+
45+
self.run_coro(test())
46+
47+
def test_start_server2(self):
48+
# Test that mutations to the context in one handler don't affect other handlers or the server's context
49+
async def test():
50+
var = contextvars.ContextVar("var", default="default")
51+
52+
async def handle_client(reader, writer):
53+
value = var.get()
54+
writer.write(value.encode())
55+
var.set("in_handler")
56+
await writer.drain()
57+
writer.close()
58+
59+
server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
60+
var.set("after_server")
61+
62+
async def client(addr):
63+
reader, writer = await asyncio.open_connection(*addr)
64+
data = await reader.read(100)
65+
writer.close()
66+
await writer.wait_closed()
67+
return data.decode()
68+
69+
async with server:
70+
addr = server.sockets[0].getsockname()
71+
self.assertEqual(await client(addr), "default")
72+
self.assertEqual(await client(addr), "default")
73+
self.assertEqual(await client(addr), "default")
74+
75+
self.assertEqual(var.get(), "after_server")
76+
77+
self.run_coro(test())
78+
79+
def test_start_server3(self):
80+
# Test that mutations to context in concurrent handlers don't affect each other or the server's context
81+
async def test():
82+
var = contextvars.ContextVar("var", default="default")
83+
var.set("before_server")
84+
85+
async def handle_client(reader, writer):
86+
writer.write(var.get().encode())
87+
await writer.drain()
88+
writer.close()
89+
90+
server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
91+
var.set("after_server")
92+
93+
async def client(addr):
94+
reader, writer = await asyncio.open_connection(*addr)
95+
data = await reader.read(100)
96+
self.assertEqual(data.decode(), "before_server")
97+
writer.close()
98+
await writer.wait_closed()
99+
100+
async with server:
101+
addr = server.sockets[0].getsockname()
102+
async with asyncio.TaskGroup() as tg:
103+
for _ in range(100):
104+
tg.create_task(client(addr))
105+
106+
self.assertEqual(var.get(), "after_server")
107+
108+
self.run_coro(test())
109+
110+
def test_create_server1(self):
111+
# Test that loop.create_server captures the context at the time of server creation
112+
# and that mutations to the context in protocol callbacks don't affect the server's context
113+
async def test():
114+
var = contextvars.ContextVar("var", default="default")
115+
116+
class EchoProtocol(asyncio.Protocol):
117+
def connection_made(self, transport):
118+
self.transport = transport
119+
value = var.get()
120+
var.set("in_handler")
121+
self.transport.write(value.encode())
122+
self.transport.close()
123+
124+
server = await asyncio.get_running_loop().create_server(
125+
lambda: EchoProtocol(), '127.0.0.1', 0)
126+
var.set("after_server")
127+
128+
async def client(addr):
129+
reader, writer = await asyncio.open_connection(*addr)
130+
data = await reader.read(100)
131+
self.assertEqual(data.decode(), "default")
132+
writer.close()
133+
await writer.wait_closed()
134+
135+
async with server:
136+
addr = server.sockets[0].getsockname()
137+
await client(addr)
138+
139+
self.assertEqual(var.get(), "after_server")
140+
141+
self.run_coro(test())
142+
143+
def test_create_server2(self):
144+
# Test that mutations to context in one protocol instance don't affect other instances or the server's context
145+
async def test():
146+
var = contextvars.ContextVar("var", default="default")
147+
148+
class EchoProtocol(asyncio.Protocol):
149+
def __init__(self):
150+
super().__init__()
151+
assert var.get() == "default", var.get()
152+
def connection_made(self, transport):
153+
self.transport = transport
154+
value = var.get()
155+
var.set("in_handler")
156+
self.transport.write(value.encode())
157+
self.transport.close()
158+
159+
server = await asyncio.get_running_loop().create_server(
160+
lambda: EchoProtocol(), '127.0.0.1', 0)
161+
162+
var.set("after_server")
163+
164+
async def client(addr, expected):
165+
reader, writer = await asyncio.open_connection(*addr)
166+
data = await reader.read(100)
167+
self.assertEqual(data.decode(), expected)
168+
writer.close()
169+
await writer.wait_closed()
170+
171+
async with server:
172+
addr = server.sockets[0].getsockname()
173+
await client(addr, "default")
174+
await client(addr, "default")
175+
176+
self.assertEqual(var.get(), "after_server")
177+
178+
self.run_coro(test())
179+
180+
def test_gh140947(self):
181+
# See https://github.com/python/cpython/issues/140947
182+
183+
cvar1 = contextvars.ContextVar("cvar1")
184+
cvar2 = contextvars.ContextVar("cvar2")
185+
cvar3 = contextvars.ContextVar("cvar3")
186+
results = {}
187+
188+
def capture_context(meth):
189+
result = []
190+
for k,v in contextvars.copy_context().items():
191+
result.append((k.name, v))
192+
results[meth] = sorted(result)
193+
194+
class DemoProtocol(asyncio.Protocol):
195+
def __init__(self, on_conn_lost):
196+
self.transport = None
197+
self.on_conn_lost = on_conn_lost
198+
self.tasks = set()
199+
200+
def connection_made(self, transport):
201+
capture_context("connection_made")
202+
self.transport = transport
203+
204+
def data_received(self, data):
205+
capture_context("data_received")
206+
207+
task = asyncio.create_task(self.asgi())
208+
self.tasks.add(task)
209+
task.add_done_callback(self.tasks.discard)
210+
211+
self.transport.pause_reading()
212+
213+
def connection_lost(self, exc):
214+
capture_context("connection_lost")
215+
if not self.on_conn_lost.done():
216+
self.on_conn_lost.set_result(True)
217+
218+
async def asgi(self):
219+
capture_context("asgi start")
220+
221+
cvar1.set(True)
222+
223+
# make sure that we only resume after the pause
224+
# otherwise the resume does nothing
225+
while not self.transport._paused:
226+
await asyncio.sleep(0.1)
227+
228+
cvar2.set(True)
229+
230+
self.transport.resume_reading()
231+
232+
cvar3.set(True)
233+
234+
capture_context("asgi end")
235+
236+
237+
async def main():
238+
loop = asyncio.get_running_loop()
239+
on_conn_lost = loop.create_future()
240+
241+
host, port = "127.0.0.1", 8888
242+
243+
async with await loop.create_server(lambda: DemoProtocol(on_conn_lost), host, port):
244+
reader, writer = await asyncio.open_connection(host, port)
245+
writer.write(b"anything")
246+
await writer.drain()
247+
writer.close()
248+
await writer.wait_closed()
249+
await on_conn_lost
250+
self.run_coro(main())
251+
self.assertDictEqual(results, {
252+
"connection_made": [],
253+
"data_received": [],
254+
"asgi start": [],
255+
"asgi end": [("cvar1", True), ("cvar2", True), ("cvar3", True)],
256+
"connection_lost": [],
257+
})
258+
259+
260+
class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
261+
loop_factory = staticmethod(asyncio.new_event_loop)
262+
263+
if __name__ == "__main__":
264+
unittest.main()

0 commit comments

Comments
 (0)