Skip to content

Commit f501dcc

Browse files
author
skyvanguard
committed
fix: handle ClientDisconnect gracefully instead of returning HTTP 500
When a client disconnects during a request (network timeout, user cancels, load balancer timeout, mobile network interruption), the server was catching the exception with a broad `except Exception` handler, logging it as ERROR with full traceback, and returning HTTP 500.

ClientDisconnect is a client-side event, not a server failure. This change catches it explicitly at the request dispatch level and in SSE stream handlers, logging at DEBUG level instead.

Changes:
- Import ClientDisconnect from starlette.requests
- Add except ClientDisconnect handler in handle_request() to catch disconnects across all HTTP methods (POST, GET, DELETE)
- Add handlers in _handle_get_request SSE streams and event replay to prevent ERROR logging on client disconnect
- Add regression tests verifying no ERROR logs are produced and server remains healthy after client disconnection

Github-Issue: #1648
Reported-by: FanisPapakonstantinou
1 parent a7ddfda commit f501dcc

File tree

2 files changed

+229
-9
lines changed

2 files changed

+229
-9
lines changed

src/mcp/server/streamable_http.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2121
from pydantic import ValidationError
2222
from sse_starlette import EventSourceResponse
23-
from starlette.requests import Request
23+
from starlette.requests import ClientDisconnect, Request
2424
from starlette.responses import Response
2525
from starlette.types import Receive, Scope, Send
2626

@@ -379,14 +379,17 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
379379
await response(scope, receive, send)
380380
return
381381

382-
if request.method == "POST":
383-
await self._handle_post_request(scope, request, receive, send)
384-
elif request.method == "GET": # pragma: no cover
385-
await self._handle_get_request(request, send)
386-
elif request.method == "DELETE": # pragma: no cover
387-
await self._handle_delete_request(request, send)
388-
else: # pragma: no cover
389-
await self._handle_unsupported_request(request, send)
382+
try:
383+
if request.method == "POST":
384+
await self._handle_post_request(scope, request, receive, send)
385+
elif request.method == "GET": # pragma: no cover
386+
await self._handle_get_request(request, send)
387+
elif request.method == "DELETE": # pragma: no cover
388+
await self._handle_delete_request(request, send)
389+
else: # pragma: no cover
390+
await self._handle_unsupported_request(request, send)
391+
except ClientDisconnect:
392+
logger.debug(f"Client disconnected during {request.method} request")
390393

391394
def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
392395
"""Check if the request accepts the required media types."""
@@ -704,6 +707,8 @@ async def standalone_sse_writer():
704707
# Send the message via SSE
705708
event_data = self._create_event_data(event_message)
706709
await sse_stream_writer.send(event_data)
710+
except ClientDisconnect:
711+
logger.debug("Client disconnected from standalone SSE stream")
707712
except Exception:
708713
logger.exception("Error in standalone SSE writer")
709714
finally:
@@ -720,6 +725,11 @@ async def standalone_sse_writer():
720725
try:
721726
# This will send headers immediately and establish the SSE connection
722727
await response(request.scope, request.receive, send)
728+
except ClientDisconnect:
729+
logger.debug("Client disconnected from GET SSE stream")
730+
await sse_stream_writer.aclose()
731+
await sse_stream_reader.aclose()
732+
await self._clean_up_memory_streams(GET_STREAM_KEY)
723733
except Exception:
724734
logger.exception("Error in standalone SSE response")
725735
await sse_stream_writer.aclose()
@@ -910,6 +920,8 @@ async def send_event(event_message: EventMessage) -> None:
910920
except anyio.ClosedResourceError:
911921
# Expected when close_sse_stream() is called
912922
logger.debug("Replay SSE stream closed by close_sse_stream()")
923+
except ClientDisconnect:
924+
logger.debug("Client disconnected during event replay")
913925
except Exception:
914926
logger.exception("Error in replay sender")
915927

@@ -922,12 +934,16 @@ async def send_event(event_message: EventMessage) -> None:
922934

923935
try:
924936
await response(request.scope, request.receive, send)
937+
except ClientDisconnect:
938+
logger.debug("Client disconnected during replay response")
925939
except Exception:
926940
logger.exception("Error in replay response")
927941
finally:
928942
await sse_stream_writer.aclose()
929943
await sse_stream_reader.aclose()
930944

945+
except ClientDisconnect:
946+
logger.debug("Client disconnected during event replay request")
931947
except Exception:
932948
logger.exception("Error replaying events")
933949
response = self._create_error_response(
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
"""Test for issue #1648 - ClientDisconnect returns HTTP 500.
2+
3+
When a client disconnects during a request (network timeout, user cancels, load
4+
balancer timeout, mobile network interruption), the server should handle this
5+
gracefully instead of returning HTTP 500 and logging as ERROR.
6+
7+
ClientDisconnect is a client-side event, not a server failure.
8+
"""
9+
10+
import logging
11+
import threading
12+
from collections.abc import AsyncGenerator
13+
from contextlib import asynccontextmanager
14+
15+
import anyio
16+
import httpx
17+
import pytest
18+
from starlette.applications import Starlette
19+
from starlette.routing import Mount
20+
21+
from mcp.server import Server
22+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
23+
from mcp.types import Tool
24+
25+
SERVER_NAME = "test_client_disconnect_server"
26+
27+
28+
class SlowServer(Server):
29+
"""Server with a slow tool to allow time for client disconnect."""
30+
31+
def __init__(self):
32+
super().__init__(SERVER_NAME)
33+
34+
@self.list_tools()
35+
async def handle_list_tools() -> list[Tool]:
36+
return [
37+
Tool(
38+
name="slow_tool",
39+
description="A tool that takes time to respond",
40+
input_schema={"type": "object", "properties": {}},
41+
),
42+
]
43+
44+
@self.call_tool()
45+
async def handle_call_tool(name: str, arguments: dict) -> list:
46+
if name == "slow_tool":
47+
await anyio.sleep(10)
48+
return [{"type": "text", "text": "done"}]
49+
raise ValueError(f"Unknown tool: {name}")
50+
51+
52+
def create_app() -> Starlette:
53+
"""Create a Starlette application for testing."""
54+
server = SlowServer()
55+
session_manager = StreamableHTTPSessionManager(
56+
app=server,
57+
json_response=True,
58+
stateless=True,
59+
)
60+
61+
@asynccontextmanager
62+
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
63+
async with session_manager.run():
64+
yield
65+
66+
routes = [Mount("/", app=session_manager.handle_request)]
67+
return Starlette(routes=routes, lifespan=lifespan)
68+
69+
70+
class ServerThread(threading.Thread):
71+
"""Thread that runs the ASGI application lifespan."""
72+
73+
def __init__(self, app: Starlette):
74+
super().__init__(daemon=True)
75+
self.app = app
76+
self._stop_event = threading.Event()
77+
78+
def run(self) -> None:
79+
async def run_lifespan():
80+
lifespan_context = getattr(self.app.router, "lifespan_context", None)
81+
assert lifespan_context is not None
82+
async with lifespan_context(self.app):
83+
while not self._stop_event.is_set():
84+
await anyio.sleep(0.1)
85+
86+
anyio.run(run_lifespan)
87+
88+
def stop(self) -> None:
89+
self._stop_event.set()
90+
91+
92+
@pytest.mark.anyio
93+
async def test_client_disconnect_does_not_produce_500(caplog: pytest.LogCaptureFixture):
94+
"""Client disconnect should not produce HTTP 500 or ERROR log entries.
95+
96+
Regression test for issue #1648: when a client disconnects mid-request,
97+
the server was catching the exception with a broad `except Exception` handler,
98+
logging it as ERROR, and returning HTTP 500.
99+
"""
100+
app = create_app()
101+
server_thread = ServerThread(app)
102+
server_thread.start()
103+
104+
try:
105+
await anyio.sleep(0.2)
106+
107+
with caplog.at_level(logging.DEBUG):
108+
async with httpx.AsyncClient(
109+
transport=httpx.ASGITransport(app=app),
110+
base_url="http://testserver",
111+
timeout=1.0,
112+
) as client:
113+
# Send a tool call that will take a long time, client will timeout
114+
try:
115+
await client.post(
116+
"/",
117+
json={
118+
"jsonrpc": "2.0",
119+
"method": "tools/call",
120+
"id": "call-1",
121+
"params": {"name": "slow_tool", "arguments": {}},
122+
},
123+
headers={
124+
"Accept": "application/json, text/event-stream",
125+
"Content-Type": "application/json",
126+
},
127+
)
128+
except (httpx.ReadTimeout, httpx.ReadError):
129+
pass # Expected - client timed out
130+
131+
# Wait briefly for any async error logging to complete
132+
await anyio.sleep(0.1)
133+
134+
# Verify no ERROR-level log entries about handling POST requests
135+
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR and "POST" in r.getMessage()]
136+
assert not error_records, (
137+
f"Server logged ERROR for client disconnect: {[r.getMessage() for r in error_records]}"
138+
)
139+
finally:
140+
server_thread.stop()
141+
server_thread.join(timeout=2)
142+
143+
144+
@pytest.mark.anyio
145+
async def test_server_healthy_after_client_disconnect():
146+
"""Server should remain healthy and accept new requests after a client disconnects."""
147+
app = create_app()
148+
server_thread = ServerThread(app)
149+
server_thread.start()
150+
151+
try:
152+
await anyio.sleep(0.2)
153+
154+
async with httpx.AsyncClient(
155+
transport=httpx.ASGITransport(app=app),
156+
base_url="http://testserver",
157+
timeout=1.0,
158+
) as client:
159+
# First request - will timeout (simulating client disconnect)
160+
try:
161+
await client.post(
162+
"/",
163+
json={
164+
"jsonrpc": "2.0",
165+
"method": "tools/call",
166+
"id": "call-timeout",
167+
"params": {"name": "slow_tool", "arguments": {}},
168+
},
169+
headers={
170+
"Accept": "application/json, text/event-stream",
171+
"Content-Type": "application/json",
172+
},
173+
)
174+
except (httpx.ReadTimeout, httpx.ReadError):
175+
pass # Expected - client timed out
176+
177+
# Create a new client for the second request
178+
async with httpx.AsyncClient(
179+
transport=httpx.ASGITransport(app=app),
180+
base_url="http://testserver",
181+
timeout=5.0,
182+
) as client:
183+
# Second request - should succeed (server still healthy)
184+
response = await client.post(
185+
"/",
186+
json={
187+
"jsonrpc": "2.0",
188+
"method": "initialize",
189+
"id": "init-after-disconnect",
190+
"params": {
191+
"clientInfo": {"name": "test-client", "version": "1.0"},
192+
"protocolVersion": "2025-03-26",
193+
"capabilities": {},
194+
},
195+
},
196+
headers={
197+
"Accept": "application/json, text/event-stream",
198+
"Content-Type": "application/json",
199+
},
200+
)
201+
assert response.status_code == 200
202+
finally:
203+
server_thread.stop()
204+
server_thread.join(timeout=2)

0 commit comments

Comments
 (0)