Skip to content

Commit 813da6a

Browse files
authored
Fix flaky test_streamablehttp_client_resumption test (#1166)
1 parent 11162d7 commit 813da6a

File tree

2 files changed

+100
-58
lines changed

2 files changed

+100
-58
lines changed

src/mcp/server/streamable_http.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -837,9 +837,7 @@ async def message_router():
837837
response_id = str(message.root.id)
838838
# If this response is for an existing request stream,
839839
# send it there
840-
if response_id in self._request_streams:
841-
target_request_id = response_id
842-
840+
target_request_id = response_id
843841
else:
844842
# Extract related_request_id from meta if it exists
845843
if (

tests/shared/test_streamable_http.py

Lines changed: 99 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -98,32 +98,33 @@ async def replay_events_after(
9898
send_callback: EventCallback,
9999
) -> StreamId | None:
100100
"""Replay events after the specified ID."""
101-
# Find the index of the last event ID
102-
start_index = None
103-
for i, (_, event_id, _) in enumerate(self._events):
101+
# Find the stream ID of the last event
102+
target_stream_id = None
103+
for stream_id, event_id, _ in self._events:
104104
if event_id == last_event_id:
105-
start_index = i + 1
105+
target_stream_id = stream_id
106106
break
107107

108-
if start_index is None:
109-
# If event ID not found, start from beginning
110-
start_index = 0
108+
if target_stream_id is None:
109+
# If event ID not found, return None
110+
return None
111111

112-
stream_id = None
113-
# Replay events
114-
for _, event_id, message in self._events[start_index:]:
115-
await send_callback(EventMessage(message, event_id))
116-
# Capture the stream ID from the first replayed event
117-
if stream_id is None and len(self._events) > start_index:
118-
stream_id = self._events[start_index][0]
112+
# Convert last_event_id to int for comparison
113+
last_event_id_int = int(last_event_id)
119114

120-
return stream_id
115+
# Replay only events from the same stream with ID > last_event_id
116+
for stream_id, event_id, message in self._events:
117+
if stream_id == target_stream_id and int(event_id) > last_event_id_int:
118+
await send_callback(EventMessage(message, event_id))
119+
120+
return target_stream_id
121121

122122

123123
# Test server implementation that follows MCP protocol
124124
class ServerTest(Server):
125125
def __init__(self):
126126
super().__init__(SERVER_NAME)
127+
self._lock = None # Will be initialized in async context
127128

128129
@self.read_resource()
129130
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
@@ -159,6 +160,16 @@ async def handle_list_tools() -> list[Tool]:
159160
description="A tool that triggers server-side sampling",
160161
inputSchema={"type": "object", "properties": {}},
161162
),
163+
Tool(
164+
name="wait_for_lock_with_notification",
165+
description="A tool that sends a notification and waits for lock",
166+
inputSchema={"type": "object", "properties": {}},
167+
),
168+
Tool(
169+
name="release_lock",
170+
description="A tool that releases the lock",
171+
inputSchema={"type": "object", "properties": {}},
172+
),
162173
]
163174

164175
@self.call_tool()
@@ -214,6 +225,39 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
214225
)
215226
]
216227

228+
elif name == "wait_for_lock_with_notification":
229+
# Initialize lock if not already done
230+
if self._lock is None:
231+
self._lock = anyio.Event()
232+
233+
# First send a notification
234+
await ctx.session.send_log_message(
235+
level="info",
236+
data="First notification before lock",
237+
logger="lock_tool",
238+
related_request_id=ctx.request_id,
239+
)
240+
241+
# Now wait for the lock to be released
242+
await self._lock.wait()
243+
244+
# Send second notification after lock is released
245+
await ctx.session.send_log_message(
246+
level="info",
247+
data="Second notification after lock",
248+
logger="lock_tool",
249+
related_request_id=ctx.request_id,
250+
)
251+
252+
return [TextContent(type="text", text="Completed")]
253+
254+
elif name == "release_lock":
255+
assert self._lock is not None, "Lock must be initialized before releasing"
256+
257+
# Release the lock
258+
self._lock.set()
259+
return [TextContent(type="text", text="Lock released")]
260+
217261
return [TextContent(type="text", text=f"Called {name}")]
218262

219263

@@ -825,7 +869,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
825869
"""Test client tool invocation."""
826870
# First list tools
827871
tools = await initialized_client_session.list_tools()
828-
assert len(tools.tools) == 4
872+
assert len(tools.tools) == 6
829873
assert tools.tools[0].name == "test_tool"
830874

831875
# Call the tool
@@ -862,7 +906,7 @@ async def test_streamablehttp_client_session_persistence(basic_server, basic_ser
862906

863907
# Make multiple requests to verify session persistence
864908
tools = await session.list_tools()
865-
assert len(tools.tools) == 4
909+
assert len(tools.tools) == 6
866910

867911
# Read a resource
868912
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -891,7 +935,7 @@ async def test_streamablehttp_client_json_response(json_response_server, json_se
891935

892936
# Check tool listing
893937
tools = await session.list_tools()
894-
assert len(tools.tools) == 4
938+
assert len(tools.tools) == 6
895939

896940
# Call a tool and verify JSON response handling
897941
result = await session.call_tool("test_tool", {})
@@ -962,7 +1006,7 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser
9621006

9631007
# Make a request to confirm session is working
9641008
tools = await session.list_tools()
965-
assert len(tools.tools) == 4
1009+
assert len(tools.tools) == 6
9661010

9671011
headers = {}
9681012
if captured_session_id:
@@ -1026,7 +1070,7 @@ async def mock_delete(self, *args, **kwargs):
10261070

10271071
# Make a request to confirm session is working
10281072
tools = await session.list_tools()
1029-
assert len(tools.tools) == 4
1073+
assert len(tools.tools) == 6
10301074

10311075
headers = {}
10321076
if captured_session_id:
@@ -1048,32 +1092,32 @@ async def mock_delete(self, *args, **kwargs):
10481092

10491093
@pytest.mark.anyio
10501094
async def test_streamablehttp_client_resumption(event_server):
1051-
"""Test client session to resume a long running tool."""
1095+
"""Test client session resumption using sync primitives for reliable coordination."""
10521096
_, server_url = event_server
10531097

10541098
# Variables to track the state
10551099
captured_session_id = None
10561100
captured_resumption_token = None
10571101
captured_notifications = []
1058-
tool_started = False
10591102
captured_protocol_version = None
1103+
first_notification_received = False
10601104

10611105
async def message_handler(
10621106
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
10631107
) -> None:
10641108
if isinstance(message, types.ServerNotification):
10651109
captured_notifications.append(message)
1066-
# Look for our special notification that indicates the tool is running
1110+
# Look for our first notification
10671111
if isinstance(message.root, types.LoggingMessageNotification):
1068-
if message.root.params.data == "Tool started":
1069-
nonlocal tool_started
1070-
tool_started = True
1112+
if message.root.params.data == "First notification before lock":
1113+
nonlocal first_notification_received
1114+
first_notification_received = True
10711115

10721116
async def on_resumption_token_update(token: str) -> None:
10731117
nonlocal captured_resumption_token
10741118
captured_resumption_token = token
10751119

1076-
# First, start the client session and begin the long-running tool
1120+
# First, start the client session and begin the tool that waits on lock
10771121
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
10781122
read_stream,
10791123
write_stream,
@@ -1088,7 +1132,7 @@ async def on_resumption_token_update(token: str) -> None:
10881132
# Capture the negotiated protocol version
10891133
captured_protocol_version = result.protocolVersion
10901134

1091-
# Start a long-running tool in a task
1135+
# Start the tool that will wait on lock in a task
10921136
async with anyio.create_task_group() as tg:
10931137

10941138
async def run_tool():
@@ -1099,7 +1143,9 @@ async def run_tool():
10991143
types.ClientRequest(
11001144
types.CallToolRequest(
11011145
method="tools/call",
1102-
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
1146+
params=types.CallToolRequestParams(
1147+
name="wait_for_lock_with_notification", arguments={}
1148+
),
11031149
)
11041150
),
11051151
types.CallToolResult,
@@ -1108,15 +1154,19 @@ async def run_tool():
11081154

11091155
tg.start_soon(run_tool)
11101156

1111-
# Wait for the tool to start and at least one notification
1112-
# and then kill the task group
1113-
while not tool_started or not captured_resumption_token:
1157+
# Wait for the first notification and resumption token
1158+
while not first_notification_received or not captured_resumption_token:
11141159
await anyio.sleep(0.1)
1160+
1161+
# Kill the client session while tool is waiting on lock
11151162
tg.cancel_scope.cancel()
11161163

1117-
# Store pre notifications and clear the captured notifications
1118-
# for the post-resumption check
1119-
captured_notifications_pre = captured_notifications.copy()
1164+
# Verify we received exactly one notification
1165+
assert len(captured_notifications) == 1
1166+
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification)
1167+
assert captured_notifications[0].root.params.data == "First notification before lock"
1168+
1169+
# Clear notifications for the second phase
11201170
captured_notifications = []
11211171

11221172
# Now resume the session with the same mcp-session-id and protocol version
@@ -1125,54 +1175,48 @@ async def run_tool():
11251175
headers[MCP_SESSION_ID_HEADER] = captured_session_id
11261176
if captured_protocol_version:
11271177
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
1128-
11291178
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
11301179
read_stream,
11311180
write_stream,
11321181
_,
11331182
):
11341183
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
1135-
# Don't initialize - just use the existing session
1136-
1137-
# Resume the tool with the resumption token
1138-
assert captured_resumption_token is not None
1139-
1184+
result = await session.send_request(
1185+
types.ClientRequest(
1186+
types.CallToolRequest(
1187+
method="tools/call",
1188+
params=types.CallToolRequestParams(name="release_lock", arguments={}),
1189+
)
1190+
),
1191+
types.CallToolResult,
1192+
)
11401193
metadata = ClientMessageMetadata(
11411194
resumption_token=captured_resumption_token,
11421195
)
1196+
11431197
result = await session.send_request(
11441198
types.ClientRequest(
11451199
types.CallToolRequest(
11461200
method="tools/call",
1147-
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
1201+
params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}),
11481202
)
11491203
),
11501204
types.CallToolResult,
11511205
metadata=metadata,
11521206
)
1153-
1154-
# We should get a complete result
11551207
assert len(result.content) == 1
11561208
assert result.content[0].type == "text"
1157-
assert "Completed" in result.content[0].text
1209+
assert result.content[0].text == "Completed"
11581210

11591211
# We should have received the remaining notifications
1160-
assert len(captured_notifications) > 0
1212+
assert len(captured_notifications) == 1
11611213

1162-
# Should not have the first notification
1163-
# Check that "Tool started" notification isn't repeated when resuming
1164-
assert not any(
1165-
isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
1166-
for n in captured_notifications
1167-
)
1168-
# there is no intersection between pre and post notifications
1169-
assert not any(n in captured_notifications_pre for n in captured_notifications)
1214+
assert captured_notifications[0].root.params.data == "Second notification after lock"
11701215

11711216

11721217
@pytest.mark.anyio
11731218
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
11741219
"""Test server-initiated sampling request through streamable HTTP transport."""
1175-
print("Testing server sampling...")
11761220
# Variable to track if sampling callback was invoked
11771221
sampling_callback_invoked = False
11781222
captured_message_params = None

0 commit comments

Comments
 (0)