Skip to content

Commit 45dfd5d

Browse files
fix: suppress GeneratorExit during client cleanup
GeneratorExit can leak from sse_client and streamablehttp_client during cleanup, causing RuntimeError in downstream code. This handles both direct GeneratorExit and BaseExceptionGroup wrapping (cpython#95571). Fixes #1214. Signed-off-by: Adrian Cole <adrian@tetrate.io>
1 parent 62575ed commit 45dfd5d

File tree

6 files changed

+180
-13
lines changed

6 files changed

+180
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ dependencies = [
4040
"pyjwt[crypto]>=2.10.1",
4141
"typing-extensions>=4.13.0",
4242
"typing-inspection>=0.4.1",
43+
"exceptiongroup>=1.0.0; python_version < '3.11'",
4344
]
4445

4546
[project.optional-dependencies]

src/mcp/client/sse.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import logging
2+
import sys
23
from collections.abc import Callable
34
from contextlib import asynccontextmanager
45
from typing import Any
56
from urllib.parse import parse_qs, urljoin, urlparse
67

78
import anyio
9+
10+
if sys.version_info >= (3, 11):
11+
from builtins import BaseExceptionGroup # pragma: lax no cover
12+
else:
13+
from exceptiongroup import BaseExceptionGroup # pragma: lax no cover
814
import httpx
915
from anyio.abc import TaskStatus
1016
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -155,8 +161,19 @@ async def post_writer(endpoint_url: str):
155161

156162
try:
157163
yield read_stream, write_stream
164+
# Suppress GeneratorExit to prevent "generator didn't stop after athrow()"
165+
# when client code exits the context manager during cancellation.
166+
# See https://github.com/python/cpython/issues/95571
167+
except GeneratorExit:
168+
pass
169+
# anyio wraps GeneratorExit in BaseExceptionGroup; extract and re-raise other exceptions
170+
except BaseExceptionGroup as eg:
171+
_, rest = eg.split(GeneratorExit)
172+
if rest:
173+
raise rest from None
158174
finally:
159175
tg.cancel_scope.cancel()
160176
finally:
161177
await read_stream_writer.aclose()
178+
await read_stream.aclose()
162179
await write_stream.aclose()

src/mcp/client/streamable_http.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,17 @@
44

55
import contextlib
66
import logging
7+
import sys
78
from collections.abc import AsyncGenerator, Awaitable, Callable
89
from contextlib import asynccontextmanager
910
from dataclasses import dataclass
1011

1112
import anyio
13+
14+
if sys.version_info >= (3, 11):
15+
from builtins import BaseExceptionGroup # pragma: lax no cover
16+
else:
17+
from exceptiongroup import BaseExceptionGroup # pragma: lax no cover
1218
import httpx
1319
from anyio.abc import TaskGroup
1420
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -570,10 +576,21 @@ def start_get_stream() -> None:
570576

571577
try:
572578
yield read_stream, write_stream
579+
# Suppress GeneratorExit to prevent "generator didn't stop after athrow()"
580+
# when client code exits the context manager during cancellation.
581+
# See https://github.com/python/cpython/issues/95571
582+
except GeneratorExit:
583+
pass
584+
# anyio wraps GeneratorExit in BaseExceptionGroup; extract and re-raise other exceptions
585+
except BaseExceptionGroup as eg:
586+
_, rest = eg.split(GeneratorExit)
587+
if rest:
588+
raise rest from None
573589
finally:
574590
if transport.session_id and terminate_on_close:
575591
await transport.terminate_session(client)
576592
tg.cancel_scope.cancel()
577593
finally:
578594
await read_stream_writer.aclose()
595+
await read_stream.aclose()
579596
await write_stream.aclose()

tests/client/conftest.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,86 @@
1-
from collections.abc import Callable, Generator
1+
import multiprocessing
2+
import socket
3+
from collections.abc import AsyncGenerator, Callable, Generator
24
from contextlib import asynccontextmanager
35
from typing import Any
46
from unittest.mock import patch
57

68
import pytest
9+
import uvicorn
710
from anyio.streams.memory import MemoryObjectSendStream
11+
from starlette.applications import Starlette
12+
from starlette.requests import Request
13+
from starlette.responses import Response
14+
from starlette.routing import Mount, Route
815

916
import mcp.shared.memory
17+
from mcp.server import Server
18+
from mcp.server.sse import SseServerTransport
19+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
1020
from mcp.shared.message import SessionMessage
1121
from mcp.types import JSONRPCNotification, JSONRPCRequest
22+
from tests.test_helpers import wait_for_server
23+
24+
25+
def run_server(port: int) -> None: # pragma: no cover
26+
"""Run server with SSE and Streamable HTTP endpoints."""
27+
server = Server(name="cleanup_test_server")
28+
session_manager = StreamableHTTPSessionManager(app=server, json_response=False)
29+
sse_transport = SseServerTransport("/messages/")
30+
31+
async def handle_sse(request: Request) -> Response:
32+
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
33+
if streams:
34+
await server.run(streams[0], streams[1], server.create_initialization_options())
35+
return Response()
36+
37+
@asynccontextmanager
38+
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
39+
async with session_manager.run():
40+
yield
41+
42+
app = Starlette(
43+
routes=[
44+
Route("/sse", endpoint=handle_sse),
45+
Mount("/messages/", app=sse_transport.handle_post_message),
46+
Mount("/mcp", app=session_manager.handle_request),
47+
],
48+
lifespan=lifespan,
49+
)
50+
uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error")).run()
51+
52+
53+
@pytest.fixture
54+
def server_port() -> int:
55+
with socket.socket() as s:
56+
s.bind(("127.0.0.1", 0))
57+
return s.getsockname()[1]
58+
59+
60+
@pytest.fixture
61+
def test_server(server_port: int) -> Generator[str, None, None]:
62+
"""Start server with SSE and Streamable HTTP endpoints."""
63+
proc = multiprocessing.Process(target=run_server, kwargs={"port": server_port}, daemon=True)
64+
proc.start()
65+
wait_for_server(server_port)
66+
try:
67+
yield f"http://127.0.0.1:{server_port}"
68+
finally:
69+
proc.terminate()
70+
proc.join(timeout=2)
71+
if proc.is_alive(): # pragma: no cover
72+
proc.kill()
73+
proc.join(timeout=1)
74+
75+
76+
@pytest.fixture
77+
def sse_server_url(test_server: str) -> str:
78+
return f"{test_server}/sse"
79+
80+
81+
@pytest.fixture
82+
def streamable_server_url(test_server: str) -> str:
83+
return f"{test_server}/mcp"
1284

1385

1486
class SpyMemoryObjectSendStream:
Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,31 @@
1+
import sys
2+
from collections.abc import Callable
13
from typing import Any
4+
5+
if sys.version_info >= (3, 11):
6+
from builtins import BaseExceptionGroup # pragma: lax no cover
7+
else:
8+
from exceptiongroup import BaseExceptionGroup # pragma: lax no cover
9+
210
from unittest.mock import patch
311

412
import anyio
513
import pytest
14+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
615
from pydantic import TypeAdapter
716

17+
from mcp.client.sse import sse_client
18+
from mcp.client.streamable_http import streamable_http_client
819
from mcp.shared.message import SessionMessage
920
from mcp.shared.session import BaseSession, RequestId, SendResultT
1021
from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest
1122

23+
ClientTransport = tuple[
24+
str,
25+
Callable[..., Any],
26+
Callable[[Any], tuple[MemoryObjectReceiveStream[Any], MemoryObjectSendStream[Any]]],
27+
]
28+
1229

1330
@pytest.mark.anyio
1431
async def test_send_request_stream_cleanup():
@@ -17,7 +34,6 @@ async def test_send_request_stream_cleanup():
1734
This test mocks out most of the session functionality to focus on stream cleanup.
1835
"""
1936

20-
# Create a mock session with the minimal required functionality
2137
class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]):
2238
async def _send_response(
2339
self, request_id: RequestId, response: SendResultT | ErrorData
@@ -32,35 +48,77 @@ def _receive_request_adapter(self) -> TypeAdapter[Any]:
3248
def _receive_notification_adapter(self) -> TypeAdapter[Any]:
3349
return TypeAdapter(object) # pragma: no cover
3450

35-
# Create streams
3651
write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
3752
read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
3853

39-
# Create the session
4054
session = TestSession(read_stream_receive, write_stream_send)
4155

42-
# Create a test request
4356
request = PingRequest()
4457

45-
# Patch the _write_stream.send method to raise an exception
4658
async def mock_send(*args: Any, **kwargs: Any):
4759
raise RuntimeError("Simulated network error")
4860

49-
# Record the response streams before the test
5061
initial_stream_count = len(session._response_streams)
5162

52-
# Run the test with the patched method
5363
with patch.object(session._write_stream, "send", mock_send):
5464
with pytest.raises(RuntimeError):
5565
await session.send_request(request, EmptyResult)
5666

57-
# Verify that no response streams were leaked
58-
assert len(session._response_streams) == initial_stream_count, (
59-
f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}"
60-
)
67+
assert len(session._response_streams) == initial_stream_count
6168

62-
# Clean up
6369
await write_stream_send.aclose()
6470
await write_stream_receive.aclose()
6571
await read_stream_send.aclose()
6672
await read_stream_receive.aclose()
73+
74+
75+
@pytest.fixture(params=["sse", "streamable"])
76+
def client_transport(
77+
request: pytest.FixtureRequest, sse_server_url: str, streamable_server_url: str
78+
) -> ClientTransport:
79+
if request.param == "sse":
80+
return (sse_server_url, sse_client, lambda x: (x[0], x[1]))
81+
else:
82+
return (streamable_server_url, streamable_http_client, lambda x: (x[0], x[1]))
83+
84+
85+
@pytest.mark.anyio
86+
async def test_generator_exit_on_gc_cleanup(client_transport: ClientTransport) -> None:
87+
"""Suppress GeneratorExit from aclose() during GC cleanup (python/cpython#95571)."""
88+
url, client_func, unpack = client_transport
89+
cm = client_func(url)
90+
result = await cm.__aenter__()
91+
read_stream, write_stream = unpack(result)
92+
await cm.gen.aclose()
93+
await read_stream.aclose()
94+
await write_stream.aclose()
95+
96+
97+
@pytest.mark.anyio
98+
async def test_generator_exit_in_exception_group(client_transport: ClientTransport) -> None:
99+
"""Extract GeneratorExit from BaseExceptionGroup (python/cpython#135736)."""
100+
url, client_func, unpack = client_transport
101+
async with client_func(url) as result:
102+
unpack(result)
103+
raise BaseExceptionGroup("unhandled errors in a TaskGroup", [GeneratorExit()])
104+
105+
106+
@pytest.mark.anyio
107+
async def test_generator_exit_mixed_group(client_transport: ClientTransport) -> None:
108+
"""Extract GeneratorExit from BaseExceptionGroup, re-raise other exceptions (python/cpython#135736)."""
109+
url, client_func, unpack = client_transport
110+
with pytest.raises(BaseExceptionGroup) as exc_info:
111+
async with client_func(url) as result:
112+
unpack(result)
113+
raise BaseExceptionGroup("errors", [GeneratorExit(), ValueError("real error")])
114+
115+
def has_generator_exit(eg: BaseExceptionGroup[Any]) -> bool:
116+
for e in eg.exceptions:
117+
if isinstance(e, GeneratorExit):
118+
return True # pragma: no cover
119+
if isinstance(e, BaseExceptionGroup):
120+
if has_generator_exit(eg=e): # type: ignore[arg-type]
121+
return True # pragma: no cover
122+
return False
123+
124+
assert not has_generator_exit(exc_info.value)

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)