Skip to content

Commit ff0d6d4

Browse files
author
Jay Hemnani
committed
fix: auto-reinitialize client session on HTTP 404
Per MCP spec, when the server returns HTTP 404 indicating the session has expired, the client MUST start a new session by sending a new InitializeRequest without a session ID attached. This change implements automatic session recovery: - Add SESSION_EXPIRED error code (-32002) to types.py - Modify transport 404 handling to clear session_id and signal SESSION_EXPIRED for non-initialization requests - Override send_request in ClientSession to catch SESSION_EXPIRED, re-initialize the session, and retry the original request - Prevent infinite loops with _session_recovery_attempted flag - Add comprehensive tests for session recovery scenarios Github-Issue:#1676
1 parent a9cc822 commit ff0d6d4

File tree

4 files changed

+671
-9
lines changed

4 files changed

+671
-9
lines changed

src/mcp/client/session.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
import logging
2-
from typing import Any, Protocol, overload
2+
from typing import Any, Protocol, TypeVar, overload
33

44
import anyio.lowlevel
55
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
6-
from pydantic import AnyUrl, TypeAdapter
6+
from pydantic import AnyUrl, BaseModel, TypeAdapter
77
from typing_extensions import deprecated
88

99
import mcp.types as types
1010
from mcp.client.experimental import ExperimentalClientFeatures
1111
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
1212
from mcp.shared.context import RequestContext
13+
from mcp.shared.exceptions import McpError
1314
from mcp.shared.message import SessionMessage
14-
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
15+
from mcp.shared.session import BaseSession, MessageMetadata, ProgressFnT, RequestResponder
1516
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1617

1718
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
1819

1920
logger = logging.getLogger("client")
2021

22+
# TypeVar for generic result type in send_request (bound to BaseModel like in BaseSession)
23+
_ResultT = TypeVar("_ResultT", bound=BaseModel)
24+
2125

2226
class SamplingFnT(Protocol):
2327
async def __call__(
@@ -195,6 +199,49 @@ async def initialize(self) -> types.InitializeResult:
195199

196200
return result
197201

202+
async def send_request(
203+
self,
204+
request: types.ClientRequest,
205+
result_type: type[_ResultT],
206+
request_read_timeout_seconds: float | None = None,
207+
metadata: MessageMetadata = None,
208+
progress_callback: ProgressFnT | None = None,
209+
*,
210+
_session_recovery_attempted: bool = False,
211+
) -> _ResultT:
212+
"""Send a request with automatic session recovery on expiration.
213+
214+
Per MCP spec, when the server returns 404 indicating the session has
215+
expired, the client MUST re-initialize the session and retry the request.
216+
217+
This override adds that automatic recovery behavior to the base
218+
send_request method.
219+
"""
220+
try:
221+
return await super().send_request(
222+
request,
223+
result_type,
224+
request_read_timeout_seconds,
225+
metadata,
226+
progress_callback,
227+
)
228+
except McpError as e:
229+
# Check if this is a session expired error
230+
if e.error.code == types.SESSION_EXPIRED and not _session_recovery_attempted:
231+
logger.info("Session expired, re-initializing...")
232+
# Re-initialize the session
233+
await self.initialize()
234+
# Retry the original request (with flag to prevent infinite loops)
235+
return await self.send_request(
236+
request,
237+
result_type,
238+
request_read_timeout_seconds,
239+
metadata,
240+
progress_callback,
241+
_session_recovery_attempted=True,
242+
)
243+
raise
244+
198245
def get_server_capabilities(self) -> types.ServerCapabilities | None:
199246
"""Return the server capabilities received during initialization.
200247

src/mcp/client/streamable_http.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from mcp.shared.message import ClientMessageMetadata, SessionMessage
3030
from mcp.types import (
31+
SESSION_EXPIRED,
3132
ErrorData,
3233
InitializeResult,
3334
JSONRPCError,
@@ -347,13 +348,25 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
347348
logger.debug("Received 202 Accepted")
348349
return
349350

350-
if response.status_code == 404: # pragma: no branch
351+
if response.status_code == 404:
352+
# Clear invalid session per MCP spec
353+
self.session_id = None
354+
self.protocol_version = None
355+
351356
if isinstance(message.root, JSONRPCRequest):
352-
await self._send_session_terminated_error( # pragma: no cover
353-
ctx.read_stream_writer, # pragma: no cover
354-
message.root.id, # pragma: no cover
355-
) # pragma: no cover
356-
return # pragma: no cover
357+
if is_initialization:
358+
# For initialization requests, session truly doesn't exist
359+
await self._send_session_terminated_error(
360+
ctx.read_stream_writer,
361+
message.root.id,
362+
)
363+
else:
364+
# For other requests, signal session expired for auto-recovery
365+
await self._send_session_expired_error(
366+
ctx.read_stream_writer,
367+
message.root.id,
368+
)
369+
return
357370

358371
response.raise_for_status()
359372
if is_initialization:
@@ -521,6 +534,23 @@ async def _send_session_terminated_error(
521534
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
522535
await read_stream_writer.send(session_message)
523536

537+
async def _send_session_expired_error(
538+
self,
539+
read_stream_writer: StreamWriter,
540+
request_id: RequestId,
541+
) -> None:
542+
"""Send a session expired error response for auto-recovery."""
543+
jsonrpc_error = JSONRPCError(
544+
jsonrpc="2.0",
545+
id=request_id,
546+
error=ErrorData(
547+
code=SESSION_EXPIRED,
548+
message="Session expired, re-initialization required",
549+
),
550+
)
551+
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
552+
await read_stream_writer.send(session_message)
553+
524554
async def post_writer(
525555
self,
526556
client: httpx.AsyncClient,

src/mcp/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ class JSONRPCResponse(BaseModel):
181181
# SDK error codes
182182
CONNECTION_CLOSED = -32000
183183
# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this
184+
SESSION_EXPIRED = -32002
185+
"""Error code indicating the session has expired and needs re-initialization."""
184186

185187
# Standard JSON-RPC error codes
186188
PARSE_ERROR = -32700

0 commit comments

Comments
 (0)