Skip to content

Commit db9da1c

Browse files
author
PR-Contributor
committed
feat: Add protocol_version parameter to ClientSession
This change allows users to specify a custom protocol version when initializing a ClientSession, instead of always using LATEST_PROTOCOL_VERSION. This is needed for connecting to MCP servers that require a specific protocol version (e.g., Snowflake's managed MCP server requires 2025-06-18). Changes: - Added protocol_version parameter to ClientSession.__init__() - Stored protocol_version as instance variable _protocol_version - Updated initialize() to use self._protocol_version instead of hardcoded constant - Added tests for both custom and default protocol version scenarios Backwards compatible: defaults to LATEST_PROTOCOL_VERSION when not specified. Fixes #2307
1 parent 92c693b commit db9da1c

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

src/mcp/client/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__(
118118
logging_callback: LoggingFnT | None = None,
119119
message_handler: MessageHandlerFnT | None = None,
120120
client_info: types.Implementation | None = None,
121+
protocol_version: str | None = None,
121122
*,
122123
sampling_capabilities: types.SamplingCapability | None = None,
123124
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
@@ -133,6 +134,7 @@ def __init__(
133134
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
134135
self._initialize_result: types.InitializeResult | None = None
135136
self._experimental_features: ExperimentalClientFeatures | None = None
137+
self._protocol_version = protocol_version or types.LATEST_PROTOCOL_VERSION
136138

137139
# Experimental: Task handlers (use defaults if not provided)
138140
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
@@ -168,7 +170,7 @@ async def initialize(self) -> types.InitializeResult:
168170
result = await self.send_request(
169171
types.InitializeRequest(
170172
params=types.InitializeRequestParams(
171-
protocol_version=types.LATEST_PROTOCOL_VERSION,
173+
protocol_version=self._protocol_version,
172174
capabilities=types.ClientCapabilities(
173175
sampling=sampling,
174176
elicitation=elicitation,

tests/client/test_session.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,125 @@ async def mock_server():
606606
assert result.protocol_version == LATEST_PROTOCOL_VERSION
607607

608608

609+
@pytest.mark.anyio
610+
async def test_client_session_custom_protocol_version():
611+
"""Test that custom protocol_version is sent during initialization.
612+
613+
This allows connecting to servers that require a specific protocol version,
614+
such as Snowflake's managed MCP server which requires "2025-06-18".
615+
See: https://github.com/modelcontextprotocol/python-sdk/issues/2307
616+
"""
617+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
618+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
619+
620+
custom_protocol_version = "2025-06-18"
621+
received_protocol_version = None
622+
623+
async def mock_server():
624+
nonlocal received_protocol_version
625+
626+
session_message = await client_to_server_receive.receive()
627+
jsonrpc_request = session_message.message
628+
assert isinstance(jsonrpc_request, JSONRPCRequest)
629+
request = client_request_adapter.validate_python(
630+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
631+
)
632+
assert isinstance(request, InitializeRequest)
633+
received_protocol_version = request.params.protocol_version
634+
635+
result = InitializeResult(
636+
protocol_version=custom_protocol_version,
637+
capabilities=ServerCapabilities(),
638+
server_info=Implementation(name="mock-server", version="0.1.0"),
639+
)
640+
641+
async with server_to_client_send:
642+
await server_to_client_send.send(
643+
SessionMessage(
644+
JSONRPCResponse(
645+
jsonrpc="2.0",
646+
id=jsonrpc_request.id,
647+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
648+
)
649+
)
650+
)
651+
# Receive initialized notification
652+
await client_to_server_receive.receive()
653+
654+
async with (
655+
ClientSession(
656+
server_to_client_receive,
657+
client_to_server_send,
658+
protocol_version=custom_protocol_version,
659+
) as session,
660+
anyio.create_task_group() as tg,
661+
client_to_server_send,
662+
client_to_server_receive,
663+
server_to_client_send,
664+
server_to_client_receive,
665+
):
666+
tg.start_soon(mock_server)
667+
result = await session.initialize()
668+
669+
# Assert that the custom protocol version was sent and received
670+
assert received_protocol_version == custom_protocol_version
671+
assert result.protocol_version == custom_protocol_version
672+
673+
674+
@pytest.mark.anyio
675+
async def test_client_session_default_protocol_version():
676+
"""Test that LATEST_PROTOCOL_VERSION is used when protocol_version is not specified."""
677+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
678+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
679+
680+
received_protocol_version = None
681+
682+
async def mock_server():
683+
nonlocal received_protocol_version
684+
685+
session_message = await client_to_server_receive.receive()
686+
jsonrpc_request = session_message.message
687+
assert isinstance(jsonrpc_request, JSONRPCRequest)
688+
request = client_request_adapter.validate_python(
689+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
690+
)
691+
assert isinstance(request, InitializeRequest)
692+
received_protocol_version = request.params.protocol_version
693+
694+
result = InitializeResult(
695+
protocol_version=LATEST_PROTOCOL_VERSION,
696+
capabilities=ServerCapabilities(),
697+
server_info=Implementation(name="mock-server", version="0.1.0"),
698+
)
699+
700+
async with server_to_client_send:
701+
await server_to_client_send.send(
702+
SessionMessage(
703+
JSONRPCResponse(
704+
jsonrpc="2.0",
705+
id=jsonrpc_request.id,
706+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
707+
)
708+
)
709+
)
710+
# Receive initialized notification
711+
await client_to_server_receive.receive()
712+
713+
async with (
714+
ClientSession(server_to_client_receive, client_to_server_send) as session,
715+
anyio.create_task_group() as tg,
716+
client_to_server_send,
717+
client_to_server_receive,
718+
server_to_client_send,
719+
server_to_client_receive,
720+
):
721+
tg.start_soon(mock_server)
722+
await session.initialize()
723+
724+
# Assert that the default (latest) protocol version was sent
725+
assert received_protocol_version == LATEST_PROTOCOL_VERSION
726+
727+
609728
@pytest.mark.anyio
610729
@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}])
611730
async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None):

0 commit comments

Comments
 (0)