Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError

from mcp import types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
Expand Down Expand Up @@ -69,6 +68,12 @@ async def sse_client(
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
# Before task_status.started() fires, the caller is blocked inside
# tg.start() and nobody reads from read_stream. Sending to the
# zero-buffer stream in that phase would deadlock, so errors must
# be raised instead. After started(), the caller has the streams
# and errors are delivered through read_stream.
started = False
try:
async for sse in event_source.aiter_sse(): # pragma: no branch
logger.debug(f"Received SSE event: {sse.event}")
Expand All @@ -79,27 +84,28 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):

url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if ( # pragma: no cover
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = ( # pragma: no cover
raise ValueError(
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg) # pragma: no cover
raise ValueError(error_msg) # pragma: no cover

if on_session_created:
session_id = _extract_session_id_from_endpoint(endpoint_url)
if session_id:
on_session_created(session_id)

task_status.started(endpoint_url)
started = True

case "message":
# Skip empty data (keep-alive pings)
if not sse.data:
continue
if not started:
raise RuntimeError("Received message event before endpoint event")
try:
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
logger.debug(f"Received server message: {message}")
Expand All @@ -112,11 +118,10 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
await read_stream_writer.send(session_message)
case _: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
except SSEError as sse_exc: # pragma: lax no cover
logger.exception("Encountered SSE exception")
raise sse_exc
except Exception as exc: # pragma: lax no cover
except Exception as exc:
logger.exception("Error in sse_reader")
if not started:
raise
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()
Expand Down
105 changes: 105 additions & 0 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import json
import multiprocessing
import socket
import sys
from collections.abc import AsyncGenerator, Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from urllib.parse import urlparse

# BaseExceptionGroup is builtin on 3.11+. On 3.10 it comes from the
# exceptiongroup backport, which anyio pulls in as a dependency.
if sys.version_info < (3, 11): # pragma: lax no cover
from exceptiongroup import BaseExceptionGroup

import anyio
import httpx
import pytest
Expand Down Expand Up @@ -604,6 +610,105 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
assert msg.message.id == 1


def _mock_sse_connection(aiter_sse: AsyncGenerator[ServerSentEvent, None]) -> Any:
"""Patch sse_client's HTTP layer to yield the given SSE event stream."""
mock_event_source = MagicMock()
mock_event_source.aiter_sse.return_value = aiter_sse
mock_event_source.response.raise_for_status = MagicMock()

mock_aconnect_sse = MagicMock()
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)

mock_client = MagicMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock()))

return patch.multiple(
"mcp.client.sse",
create_mcp_http_client=Mock(return_value=mock_client),
aconnect_sse=Mock(return_value=mock_aconnect_sse),
)


@pytest.mark.anyio
async def test_sse_client_raises_on_endpoint_origin_mismatch() -> None:
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447

When the server sends an endpoint URL with a different origin than the
connection URL, sse_client must raise promptly instead of deadlocking.
Before the fix, the ValueError was caught and sent to a zero-buffer stream
with no reader, hanging forever.
"""

async def events() -> AsyncGenerator[ServerSentEvent, None]:
yield ServerSentEvent(event="endpoint", data="http://wrong-host:9999/messages?sessionId=abc")
await anyio.sleep_forever() # pragma: no cover

with _mock_sse_connection(events()), anyio.fail_after(5):
with pytest.raises(BaseExceptionGroup) as exc_info:
async with sse_client("http://test/sse"): # pragma: no branch
pytest.fail("sse_client should not yield on origin mismatch") # pragma: no cover
assert exc_info.group_contains(ValueError, match="Endpoint origin does not match")


@pytest.mark.anyio
async def test_sse_client_raises_on_error_before_endpoint() -> None:
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447

Any exception raised while waiting for the endpoint event must propagate
instead of deadlocking on the zero-buffer read stream.
"""

async def events() -> AsyncGenerator[ServerSentEvent, None]:
raise ConnectionError("connection reset by peer")
yield # pragma: no cover

with _mock_sse_connection(events()), anyio.fail_after(5):
with pytest.raises(BaseExceptionGroup) as exc_info:
async with sse_client("http://test/sse"): # pragma: no branch
pytest.fail("sse_client should not yield on pre-endpoint error") # pragma: no cover
assert exc_info.group_contains(ConnectionError, match="connection reset")


@pytest.mark.anyio
async def test_sse_client_raises_on_message_before_endpoint() -> None:
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447

If the server sends a message event before the endpoint event (protocol
violation), sse_client must raise rather than deadlock trying to send the
message to a stream nobody is reading yet.
"""

async def events() -> AsyncGenerator[ServerSentEvent, None]:
yield ServerSentEvent(event="message", data='{"jsonrpc":"2.0","id":1,"result":{}}')
await anyio.sleep_forever() # pragma: no cover

with _mock_sse_connection(events()), anyio.fail_after(5):
with pytest.raises(BaseExceptionGroup) as exc_info:
async with sse_client("http://test/sse"): # pragma: no branch
pytest.fail("sse_client should not yield on protocol violation") # pragma: no cover
assert exc_info.group_contains(RuntimeError, match="before endpoint event")


@pytest.mark.anyio
async def test_sse_client_delivers_post_endpoint_errors_via_stream() -> None:
"""After the endpoint is received, errors in sse_reader are delivered on the
read stream so the session can handle them, rather than crashing the task group.
"""

async def events() -> AsyncGenerator[ServerSentEvent, None]:
yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc")
raise ConnectionError("mid-stream failure")

with _mock_sse_connection(events()), anyio.fail_after(5):
async with sse_client("http://test/sse") as (read_stream, _):
received = await read_stream.receive()
assert isinstance(received, ConnectionError)
assert "mid-stream failure" in str(received)


@pytest.mark.anyio
async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None:
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227
Expand Down
Loading