@@ -98,32 +98,33 @@ async def replay_events_after(
98
98
send_callback : EventCallback ,
99
99
) -> StreamId | None :
100
100
"""Replay events after the specified ID."""
101
- # Find the index of the last event ID
102
- start_index = None
103
- for i , ( _ , event_id , _ ) in enumerate ( self ._events ) :
101
+ # Find the stream ID of the last event
102
+ target_stream_id = None
103
+ for stream_id , event_id , _ in self ._events :
104
104
if event_id == last_event_id :
105
- start_index = i + 1
105
+ target_stream_id = stream_id
106
106
break
107
107
108
- if start_index is None :
109
- # If event ID not found, start from beginning
110
- start_index = 0
108
+ if target_stream_id is None :
109
+ # If event ID not found, return None
110
+ return None
111
111
112
- stream_id = None
113
- # Replay events
114
- for _ , event_id , message in self ._events [start_index :]:
115
- await send_callback (EventMessage (message , event_id ))
116
- # Capture the stream ID from the first replayed event
117
- if stream_id is None and len (self ._events ) > start_index :
118
- stream_id = self ._events [start_index ][0 ]
112
+ # Convert last_event_id to int for comparison
113
+ last_event_id_int = int (last_event_id )
119
114
120
- return stream_id
115
+ # Replay only events from the same stream with ID > last_event_id
116
+ for stream_id , event_id , message in self ._events :
117
+ if stream_id == target_stream_id and int (event_id ) > last_event_id_int :
118
+ await send_callback (EventMessage (message , event_id ))
119
+
120
+ return target_stream_id
121
121
122
122
123
123
# Test server implementation that follows MCP protocol
124
124
class ServerTest (Server ):
125
125
def __init__ (self ):
126
126
super ().__init__ (SERVER_NAME )
127
+ self ._lock = None # Will be initialized in async context
127
128
128
129
@self .read_resource ()
129
130
async def handle_read_resource (uri : AnyUrl ) -> str | bytes :
@@ -159,6 +160,16 @@ async def handle_list_tools() -> list[Tool]:
159
160
description = "A tool that triggers server-side sampling" ,
160
161
inputSchema = {"type" : "object" , "properties" : {}},
161
162
),
163
+ Tool (
164
+ name = "wait_for_lock_with_notification" ,
165
+ description = "A tool that sends a notification and waits for lock" ,
166
+ inputSchema = {"type" : "object" , "properties" : {}},
167
+ ),
168
+ Tool (
169
+ name = "release_lock" ,
170
+ description = "A tool that releases the lock" ,
171
+ inputSchema = {"type" : "object" , "properties" : {}},
172
+ ),
162
173
]
163
174
164
175
@self .call_tool ()
@@ -214,6 +225,39 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
214
225
)
215
226
]
216
227
228
+ elif name == "wait_for_lock_with_notification" :
229
+ # Initialize lock if not already done
230
+ if self ._lock is None :
231
+ self ._lock = anyio .Event ()
232
+
233
+ # First send a notification
234
+ await ctx .session .send_log_message (
235
+ level = "info" ,
236
+ data = "First notification before lock" ,
237
+ logger = "lock_tool" ,
238
+ related_request_id = ctx .request_id ,
239
+ )
240
+
241
+ # Now wait for the lock to be released
242
+ await self ._lock .wait ()
243
+
244
+ # Send second notification after lock is released
245
+ await ctx .session .send_log_message (
246
+ level = "info" ,
247
+ data = "Second notification after lock" ,
248
+ logger = "lock_tool" ,
249
+ related_request_id = ctx .request_id ,
250
+ )
251
+
252
+ return [TextContent (type = "text" , text = "Completed" )]
253
+
254
+ elif name == "release_lock" :
255
+ assert self ._lock is not None , "Lock must be initialized before releasing"
256
+
257
+ # Release the lock
258
+ self ._lock .set ()
259
+ return [TextContent (type = "text" , text = "Lock released" )]
260
+
217
261
return [TextContent (type = "text" , text = f"Called { name } " )]
218
262
219
263
@@ -825,7 +869,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
825
869
"""Test client tool invocation."""
826
870
# First list tools
827
871
tools = await initialized_client_session .list_tools ()
828
- assert len (tools .tools ) == 4
872
+ assert len (tools .tools ) == 6
829
873
assert tools .tools [0 ].name == "test_tool"
830
874
831
875
# Call the tool
@@ -862,7 +906,7 @@ async def test_streamablehttp_client_session_persistence(basic_server, basic_ser
862
906
863
907
# Make multiple requests to verify session persistence
864
908
tools = await session .list_tools ()
865
- assert len (tools .tools ) == 4
909
+ assert len (tools .tools ) == 6
866
910
867
911
# Read a resource
868
912
resource = await session .read_resource (uri = AnyUrl ("foobar://test-persist" ))
@@ -891,7 +935,7 @@ async def test_streamablehttp_client_json_response(json_response_server, json_se
891
935
892
936
# Check tool listing
893
937
tools = await session .list_tools ()
894
- assert len (tools .tools ) == 4
938
+ assert len (tools .tools ) == 6
895
939
896
940
# Call a tool and verify JSON response handling
897
941
result = await session .call_tool ("test_tool" , {})
@@ -962,7 +1006,7 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser
962
1006
963
1007
# Make a request to confirm session is working
964
1008
tools = await session .list_tools ()
965
- assert len (tools .tools ) == 4
1009
+ assert len (tools .tools ) == 6
966
1010
967
1011
headers = {}
968
1012
if captured_session_id :
@@ -1026,7 +1070,7 @@ async def mock_delete(self, *args, **kwargs):
1026
1070
1027
1071
# Make a request to confirm session is working
1028
1072
tools = await session .list_tools ()
1029
- assert len (tools .tools ) == 4
1073
+ assert len (tools .tools ) == 6
1030
1074
1031
1075
headers = {}
1032
1076
if captured_session_id :
@@ -1048,32 +1092,32 @@ async def mock_delete(self, *args, **kwargs):
1048
1092
1049
1093
@pytest .mark .anyio
1050
1094
async def test_streamablehttp_client_resumption (event_server ):
1051
- """Test client session to resume a long running tool ."""
1095
+ """Test client session resumption using sync primitives for reliable coordination ."""
1052
1096
_ , server_url = event_server
1053
1097
1054
1098
# Variables to track the state
1055
1099
captured_session_id = None
1056
1100
captured_resumption_token = None
1057
1101
captured_notifications = []
1058
- tool_started = False
1059
1102
captured_protocol_version = None
1103
+ first_notification_received = False
1060
1104
1061
1105
async def message_handler (
1062
1106
message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1063
1107
) -> None :
1064
1108
if isinstance (message , types .ServerNotification ):
1065
1109
captured_notifications .append (message )
1066
- # Look for our special notification that indicates the tool is running
1110
+ # Look for our first notification
1067
1111
if isinstance (message .root , types .LoggingMessageNotification ):
1068
- if message .root .params .data == "Tool started " :
1069
- nonlocal tool_started
1070
- tool_started = True
1112
+ if message .root .params .data == "First notification before lock " :
1113
+ nonlocal first_notification_received
1114
+ first_notification_received = True
1071
1115
1072
1116
async def on_resumption_token_update (token : str ) -> None :
1073
1117
nonlocal captured_resumption_token
1074
1118
captured_resumption_token = token
1075
1119
1076
- # First, start the client session and begin the long-running tool
1120
+ # First, start the client session and begin the tool that waits on lock
1077
1121
async with streamablehttp_client (f"{ server_url } /mcp" , terminate_on_close = False ) as (
1078
1122
read_stream ,
1079
1123
write_stream ,
@@ -1088,7 +1132,7 @@ async def on_resumption_token_update(token: str) -> None:
1088
1132
# Capture the negotiated protocol version
1089
1133
captured_protocol_version = result .protocolVersion
1090
1134
1091
- # Start a long-running tool in a task
1135
+ # Start the tool that will wait on lock in a task
1092
1136
async with anyio .create_task_group () as tg :
1093
1137
1094
1138
async def run_tool ():
@@ -1099,7 +1143,9 @@ async def run_tool():
1099
1143
types .ClientRequest (
1100
1144
types .CallToolRequest (
1101
1145
method = "tools/call" ,
1102
- params = types .CallToolRequestParams (name = "long_running_with_checkpoints" , arguments = {}),
1146
+ params = types .CallToolRequestParams (
1147
+ name = "wait_for_lock_with_notification" , arguments = {}
1148
+ ),
1103
1149
)
1104
1150
),
1105
1151
types .CallToolResult ,
@@ -1108,15 +1154,19 @@ async def run_tool():
1108
1154
1109
1155
tg .start_soon (run_tool )
1110
1156
1111
- # Wait for the tool to start and at least one notification
1112
- # and then kill the task group
1113
- while not tool_started or not captured_resumption_token :
1157
+ # Wait for the first notification and resumption token
1158
+ while not first_notification_received or not captured_resumption_token :
1114
1159
await anyio .sleep (0.1 )
1160
+
1161
+ # Kill the client session while tool is waiting on lock
1115
1162
tg .cancel_scope .cancel ()
1116
1163
1117
- # Store pre notifications and clear the captured notifications
1118
- # for the post-resumption check
1119
- captured_notifications_pre = captured_notifications .copy ()
1164
+ # Verify we received exactly one notification
1165
+ assert len (captured_notifications ) == 1
1166
+ assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification )
1167
+ assert captured_notifications [0 ].root .params .data == "First notification before lock"
1168
+
1169
+ # Clear notifications for the second phase
1120
1170
captured_notifications = []
1121
1171
1122
1172
# Now resume the session with the same mcp-session-id and protocol version
@@ -1125,54 +1175,48 @@ async def run_tool():
1125
1175
headers [MCP_SESSION_ID_HEADER ] = captured_session_id
1126
1176
if captured_protocol_version :
1127
1177
headers [MCP_PROTOCOL_VERSION_HEADER ] = captured_protocol_version
1128
-
1129
1178
async with streamablehttp_client (f"{ server_url } /mcp" , headers = headers ) as (
1130
1179
read_stream ,
1131
1180
write_stream ,
1132
1181
_ ,
1133
1182
):
1134
1183
async with ClientSession (read_stream , write_stream , message_handler = message_handler ) as session :
1135
- # Don't initialize - just use the existing session
1136
-
1137
- # Resume the tool with the resumption token
1138
- assert captured_resumption_token is not None
1139
-
1184
+ result = await session .send_request (
1185
+ types .ClientRequest (
1186
+ types .CallToolRequest (
1187
+ method = "tools/call" ,
1188
+ params = types .CallToolRequestParams (name = "release_lock" , arguments = {}),
1189
+ )
1190
+ ),
1191
+ types .CallToolResult ,
1192
+ )
1140
1193
metadata = ClientMessageMetadata (
1141
1194
resumption_token = captured_resumption_token ,
1142
1195
)
1196
+
1143
1197
result = await session .send_request (
1144
1198
types .ClientRequest (
1145
1199
types .CallToolRequest (
1146
1200
method = "tools/call" ,
1147
- params = types .CallToolRequestParams (name = "long_running_with_checkpoints " , arguments = {}),
1201
+ params = types .CallToolRequestParams (name = "wait_for_lock_with_notification " , arguments = {}),
1148
1202
)
1149
1203
),
1150
1204
types .CallToolResult ,
1151
1205
metadata = metadata ,
1152
1206
)
1153
-
1154
- # We should get a complete result
1155
1207
assert len (result .content ) == 1
1156
1208
assert result .content [0 ].type == "text"
1157
- assert "Completed" in result .content [0 ].text
1209
+ assert result .content [0 ].text == "Completed"
1158
1210
1159
1211
# We should have received the remaining notifications
1160
- assert len (captured_notifications ) > 0
1212
+ assert len (captured_notifications ) == 1
1161
1213
1162
- # Should not have the first notification
1163
- # Check that "Tool started" notification isn't repeated when resuming
1164
- assert not any (
1165
- isinstance (n .root , types .LoggingMessageNotification ) and n .root .params .data == "Tool started"
1166
- for n in captured_notifications
1167
- )
1168
- # there is no intersection between pre and post notifications
1169
- assert not any (n in captured_notifications_pre for n in captured_notifications )
1214
+ assert captured_notifications [0 ].root .params .data == "Second notification after lock"
1170
1215
1171
1216
1172
1217
@pytest .mark .anyio
1173
1218
async def test_streamablehttp_server_sampling (basic_server , basic_server_url ):
1174
1219
"""Test server-initiated sampling request through streamable HTTP transport."""
1175
- print ("Testing server sampling..." )
1176
1220
# Variable to track if sampling callback was invoked
1177
1221
sampling_callback_invoked = False
1178
1222
captured_message_params = None
0 commit comments