1+ import sys
2+ from collections .abc import Callable
13from typing import Any
4+
5+ if sys .version_info >= (3 , 11 ):
6+ from builtins import BaseExceptionGroup # pragma: lax no cover
7+ else :
8+ from exceptiongroup import BaseExceptionGroup # pragma: lax no cover
9+
210from unittest .mock import patch
311
412import anyio
513import pytest
14+ from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
615from pydantic import TypeAdapter
716
17+ from mcp .client .sse import sse_client
18+ from mcp .client .streamable_http import streamable_http_client
819from mcp .shared .message import SessionMessage
920from mcp .shared .session import BaseSession , RequestId , SendResultT
1021from mcp .types import ClientNotification , ClientRequest , ClientResult , EmptyResult , ErrorData , PingRequest
1122
23+ ClientTransport = tuple [
24+ str ,
25+ Callable [..., Any ],
26+ Callable [[Any ], tuple [MemoryObjectReceiveStream [Any ], MemoryObjectSendStream [Any ]]],
27+ ]
28+
1229
1330@pytest .mark .anyio
1431async def test_send_request_stream_cleanup ():
@@ -17,7 +34,6 @@ async def test_send_request_stream_cleanup():
1734 This test mocks out most of the session functionality to focus on stream cleanup.
1835 """
1936
20- # Create a mock session with the minimal required functionality
2137 class TestSession (BaseSession [ClientRequest , ClientNotification , ClientResult , Any , Any ]):
2238 async def _send_response (
2339 self , request_id : RequestId , response : SendResultT | ErrorData
@@ -32,35 +48,77 @@ def _receive_request_adapter(self) -> TypeAdapter[Any]:
3248 def _receive_notification_adapter (self ) -> TypeAdapter [Any ]:
3349 return TypeAdapter (object ) # pragma: no cover
3450
35- # Create streams
3651 write_stream_send , write_stream_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
3752 read_stream_send , read_stream_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
3853
39- # Create the session
4054 session = TestSession (read_stream_receive , write_stream_send )
4155
42- # Create a test request
4356 request = PingRequest ()
4457
45- # Patch the _write_stream.send method to raise an exception
4658 async def mock_send (* args : Any , ** kwargs : Any ):
4759 raise RuntimeError ("Simulated network error" )
4860
49- # Record the response streams before the test
5061 initial_stream_count = len (session ._response_streams )
5162
52- # Run the test with the patched method
5363 with patch .object (session ._write_stream , "send" , mock_send ):
5464 with pytest .raises (RuntimeError ):
5565 await session .send_request (request , EmptyResult )
5666
57- # Verify that no response streams were leaked
58- assert len (session ._response_streams ) == initial_stream_count , (
59- f"Expected { initial_stream_count } response streams after request, but found { len (session ._response_streams )} "
60- )
67+ assert len (session ._response_streams ) == initial_stream_count
6168
62- # Clean up
6369 await write_stream_send .aclose ()
6470 await write_stream_receive .aclose ()
6571 await read_stream_send .aclose ()
6672 await read_stream_receive .aclose ()
73+
74+
75+ @pytest .fixture (params = ["sse" , "streamable" ])
76+ def client_transport (
77+ request : pytest .FixtureRequest , sse_server_url : str , streamable_server_url : str
78+ ) -> ClientTransport :
79+ if request .param == "sse" :
80+ return (sse_server_url , sse_client , lambda x : (x [0 ], x [1 ]))
81+ else :
82+ return (streamable_server_url , streamable_http_client , lambda x : (x [0 ], x [1 ]))
83+
84+
85+ @pytest .mark .anyio
86+ async def test_generator_exit_on_gc_cleanup (client_transport : ClientTransport ) -> None :
87+ """Suppress GeneratorExit from aclose() during GC cleanup (python/cpython#95571)."""
88+ url , client_func , unpack = client_transport
89+ cm = client_func (url )
90+ result = await cm .__aenter__ ()
91+ read_stream , write_stream = unpack (result )
92+ await cm .gen .aclose ()
93+ await read_stream .aclose ()
94+ await write_stream .aclose ()
95+
96+
97+ @pytest .mark .anyio
98+ async def test_generator_exit_in_exception_group (client_transport : ClientTransport ) -> None :
99+ """Extract GeneratorExit from BaseExceptionGroup (python/cpython#135736)."""
100+ url , client_func , unpack = client_transport
101+ async with client_func (url ) as result :
102+ unpack (result )
103+ raise BaseExceptionGroup ("unhandled errors in a TaskGroup" , [GeneratorExit ()])
104+
105+
106+ @pytest .mark .anyio
107+ async def test_generator_exit_mixed_group (client_transport : ClientTransport ) -> None :
108+ """Extract GeneratorExit from BaseExceptionGroup, re-raise other exceptions (python/cpython#135736)."""
109+ url , client_func , unpack = client_transport
110+ with pytest .raises (BaseExceptionGroup ) as exc_info :
111+ async with client_func (url ) as result :
112+ unpack (result )
113+ raise BaseExceptionGroup ("errors" , [GeneratorExit (), ValueError ("real error" )])
114+
115+ def has_generator_exit (eg : BaseExceptionGroup [Any ]) -> bool :
116+ for e in eg .exceptions :
117+ if isinstance (e , GeneratorExit ):
118+ return True # pragma: no cover
119+ if isinstance (e , BaseExceptionGroup ):
120+ if has_generator_exit (eg = e ): # type: ignore[arg-type]
121+ return True # pragma: no cover
122+ return False
123+
124+ assert not has_generator_exit (exc_info .value )
0 commit comments