Skip to content

Commit dde19fb

Browse files
refactor: use Client class in tests
Refactored tests to use the ergonomic Client class instead of the verbose InMemoryTransport + ClientSession pattern: - tests/client/transports/test_memory.py: 3 tests - tests/shared/test_session.py: 2 tests - tests/server/test_cancel_handling.py: 1 test - tests/shared/test_progress_notifications.py: 1 test - tests/client/test_list_methods_cursor.py: 6 tests (including stream_spy test) The stream_spy fixture continues to work with Client since it patches the underlying memory streams used by InMemoryTransport. Github-Issue: #1891
1 parent d41d0c0 commit dde19fb

File tree

5 files changed

+163
-227
lines changed

5 files changed

+163
-227
lines changed

tests/client/test_list_methods_cursor.py

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import pytest
44

55
import mcp.types as types
6-
from mcp.client._memory import InMemoryTransport
7-
from mcp.client.session import ClientSession
6+
from mcp import Client
87
from mcp.server import Server
98
from mcp.server.fastmcp import FastMCP
109
from mcp.types import ListToolsRequest, ListToolsResult
@@ -66,49 +65,43 @@ async def test_list_methods_params_parameter(
6665
6766
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
6867
"""
69-
transport = InMemoryTransport(full_featured_server)
70-
async with transport.connect() as (read_stream, write_stream):
71-
async with ClientSession(read_stream, write_stream) as session:
72-
await session.initialize()
73-
spies = stream_spy()
74-
75-
# Test without params (omitted)
76-
method = getattr(session, method_name)
77-
_ = await method()
78-
requests = spies.get_client_requests(method=request_method)
79-
assert len(requests) == 1
80-
assert requests[0].params is None
81-
82-
spies.clear()
83-
84-
# Test with params containing cursor
85-
_ = await method(params=types.PaginatedRequestParams(cursor="from_params"))
86-
requests = spies.get_client_requests(method=request_method)
87-
assert len(requests) == 1
88-
assert requests[0].params is not None
89-
assert requests[0].params["cursor"] == "from_params"
90-
91-
spies.clear()
92-
93-
# Test with empty params
94-
_ = await method(params=types.PaginatedRequestParams())
95-
requests = spies.get_client_requests(method=request_method)
96-
assert len(requests) == 1
97-
# Empty params means no cursor
98-
assert requests[0].params is None or "cursor" not in requests[0].params
68+
async with Client(full_featured_server) as client:
69+
spies = stream_spy()
70+
71+
# Test without params (omitted)
72+
method = getattr(client, method_name)
73+
_ = await method()
74+
requests = spies.get_client_requests(method=request_method)
75+
assert len(requests) == 1
76+
assert requests[0].params is None
77+
78+
spies.clear()
79+
80+
# Test with params containing cursor
81+
_ = await method(params=types.PaginatedRequestParams(cursor="from_params"))
82+
requests = spies.get_client_requests(method=request_method)
83+
assert len(requests) == 1
84+
assert requests[0].params is not None
85+
assert requests[0].params["cursor"] == "from_params"
86+
87+
spies.clear()
88+
89+
# Test with empty params
90+
_ = await method(params=types.PaginatedRequestParams())
91+
requests = spies.get_client_requests(method=request_method)
92+
assert len(requests) == 1
93+
# Empty params means no cursor
94+
assert requests[0].params is None or "cursor" not in requests[0].params
9995

10096

10197
async def test_list_tools_with_strict_server_validation(
10298
full_featured_server: FastMCP,
10399
):
104100
"""Test pagination with a server that validates request format strictly."""
105-
transport = InMemoryTransport(full_featured_server)
106-
async with transport.connect() as (read_stream, write_stream):
107-
async with ClientSession(read_stream, write_stream) as session:
108-
await session.initialize()
109-
result = await session.list_tools(params=types.PaginatedRequestParams())
110-
assert isinstance(result, ListToolsResult)
111-
assert len(result.tools) > 0
101+
async with Client(full_featured_server) as client:
102+
result = await client.list_tools(params=types.PaginatedRequestParams())
103+
assert isinstance(result, ListToolsResult)
104+
assert len(result.tools) > 0
112105

113106

114107
async def test_list_tools_with_lowlevel_server():
@@ -129,13 +122,9 @@ async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult:
129122
]
130123
)
131124

132-
transport = InMemoryTransport(server)
133-
async with transport.connect() as (read_stream, write_stream):
134-
async with ClientSession(read_stream, write_stream) as session:
135-
await session.initialize()
136-
137-
result = await session.list_tools(params=types.PaginatedRequestParams())
138-
assert result.tools[0].description == "cursor=None"
125+
async with Client(server) as client:
126+
result = await client.list_tools(params=types.PaginatedRequestParams())
127+
assert result.tools[0].description == "cursor=None"
139128

140-
result = await session.list_tools(params=types.PaginatedRequestParams(cursor="page2"))
141-
assert result.tools[0].description == "cursor=page2"
129+
result = await client.list_tools(params=types.PaginatedRequestParams(cursor="page2"))
130+
assert result.tools[0].description == "cursor=page2"

tests/client/transports/test_memory.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from mcp import Client
56
from mcp.client._memory import InMemoryTransport
67
from mcp.server import Server
78
from mcp.server.fastmcp import FastMCP
@@ -69,42 +70,26 @@ async def test_with_fastmcp(fastmcp_server: FastMCP):
6970

7071
async def test_server_is_running(fastmcp_server: FastMCP):
7172
"""Test that the server is running and responding to requests."""
72-
from mcp.client.session import ClientSession
73-
74-
transport = InMemoryTransport(fastmcp_server)
75-
async with transport.connect() as (read_stream, write_stream):
76-
async with ClientSession(read_stream, write_stream) as session:
77-
result = await session.initialize()
78-
assert result is not None
79-
assert result.server_info.name == "test"
73+
async with Client(fastmcp_server) as client:
74+
assert client.server_capabilities is not None
8075

8176

8277
async def test_list_tools(fastmcp_server: FastMCP):
8378
"""Test listing tools through the transport."""
84-
from mcp.client.session import ClientSession
85-
86-
transport = InMemoryTransport(fastmcp_server)
87-
async with transport.connect() as (read_stream, write_stream):
88-
async with ClientSession(read_stream, write_stream) as session:
89-
await session.initialize()
90-
tools_result = await session.list_tools()
91-
assert len(tools_result.tools) > 0
92-
tool_names = [t.name for t in tools_result.tools]
93-
assert "greet" in tool_names
79+
async with Client(fastmcp_server) as client:
80+
tools_result = await client.list_tools()
81+
assert len(tools_result.tools) > 0
82+
tool_names = [t.name for t in tools_result.tools]
83+
assert "greet" in tool_names
9484

9585

9686
async def test_call_tool(fastmcp_server: FastMCP):
9787
"""Test calling a tool through the transport."""
98-
from mcp.client.session import ClientSession
99-
100-
transport = InMemoryTransport(fastmcp_server)
101-
async with transport.connect() as (read_stream, write_stream):
102-
async with ClientSession(read_stream, write_stream) as session:
103-
await session.initialize()
104-
result = await session.call_tool("greet", {"name": "World"})
105-
assert result is not None
106-
assert len(result.content) > 0
107-
assert "Hello, World!" in str(result.content[0])
88+
async with Client(fastmcp_server) as client:
89+
result = await client.call_tool("greet", {"name": "World"})
90+
assert result is not None
91+
assert len(result.content) > 0
92+
assert "Hello, World!" in str(result.content[0])
10893

10994

11095
async def test_raise_exceptions(fastmcp_server: FastMCP):

tests/server/test_cancel_handling.py

Lines changed: 43 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import pytest
77

88
import mcp.types as types
9-
from mcp.client._memory import InMemoryTransport
10-
from mcp.client.session import ClientSession
9+
from mcp import Client
1110
from mcp.server.lowlevel.server import Server
1211
from mcp.shared.exceptions import McpError
1312
from mcp.types import (
@@ -55,61 +54,50 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[
5554
return [types.TextContent(type="text", text=f"Call number: {call_count}")]
5655
raise ValueError(f"Unknown tool: {name}") # pragma: no cover
5756

58-
transport = InMemoryTransport(server)
59-
async with transport.connect() as (read_stream, write_stream):
60-
async with ClientSession(read_stream, write_stream) as client:
61-
await client.initialize()
62-
63-
# First request (will be cancelled)
64-
async def first_request():
65-
try:
66-
await client.send_request(
67-
ClientRequest(
68-
CallToolRequest(
69-
params=CallToolRequestParams(name="test_tool", arguments={}),
70-
)
71-
),
72-
CallToolResult,
73-
)
74-
pytest.fail("First request should have been cancelled") # pragma: no cover
75-
except McpError:
76-
pass # Expected
77-
78-
# Start first request
79-
async with anyio.create_task_group() as tg:
80-
tg.start_soon(first_request)
81-
82-
# Wait for it to start
83-
await ev_first_call.wait()
84-
85-
# Cancel it
86-
assert first_request_id is not None
87-
await client.send_notification(
88-
ClientNotification(
89-
CancelledNotification(
90-
params=CancelledNotificationParams(
91-
request_id=first_request_id,
92-
reason="Testing server recovery",
93-
),
57+
async with Client(server) as client:
58+
# First request (will be cancelled)
59+
async def first_request():
60+
try:
61+
await client.session.send_request(
62+
ClientRequest(
63+
CallToolRequest(
64+
params=CallToolRequestParams(name="test_tool", arguments={}),
9465
)
95-
)
66+
),
67+
CallToolResult,
9668
)
97-
98-
# Second request (should work normally)
99-
result = await client.send_request(
100-
ClientRequest(
101-
CallToolRequest(
102-
params=CallToolRequestParams(name="test_tool", arguments={}),
69+
pytest.fail("First request should have been cancelled") # pragma: no cover
70+
except McpError:
71+
pass # Expected
72+
73+
# Start first request
74+
async with anyio.create_task_group() as tg:
75+
tg.start_soon(first_request)
76+
77+
# Wait for it to start
78+
await ev_first_call.wait()
79+
80+
# Cancel it
81+
assert first_request_id is not None
82+
await client.session.send_notification(
83+
ClientNotification(
84+
CancelledNotification(
85+
params=CancelledNotificationParams(
86+
request_id=first_request_id,
87+
reason="Testing server recovery",
88+
),
10389
)
104-
),
105-
CallToolResult,
90+
)
10691
)
10792

108-
# Verify second request completed successfully
109-
assert len(result.content) == 1
110-
# Type narrowing for pyright
111-
content = result.content[0]
112-
assert content.type == "text"
113-
assert isinstance(content, types.TextContent)
114-
assert content.text == "Call number: 2"
115-
assert call_count == 2
93+
# Second request (should work normally)
94+
result = await client.call_tool("test_tool", {})
95+
96+
# Verify second request completed successfully
97+
assert len(result.content) == 1
98+
# Type narrowing for pyright
99+
content = result.content[0]
100+
assert content.type == "text"
101+
assert isinstance(content, types.TextContent)
102+
assert content.text == "Call number: 2"
103+
assert call_count == 2

tests/shared/test_progress_notifications.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
import mcp.types as types
8-
from mcp.client._memory import InMemoryTransport
8+
from mcp import Client
99
from mcp.client.session import ClientSession
1010
from mcp.server import Server
1111
from mcp.server.lowlevel import NotificationOptions
@@ -369,30 +369,20 @@ async def handle_list_tools() -> list[types.Tool]:
369369

370370
# Test with mocked logging
371371
with patch("mcp.shared.session.logging.error", side_effect=mock_log_error):
372-
transport = InMemoryTransport(server)
373-
async with transport.connect() as (read_stream, write_stream):
374-
async with ClientSession( # pragma: no branch
375-
read_stream=read_stream, write_stream=write_stream
376-
) as session:
377-
await session.initialize()
378-
# Send a request with a failing progress callback
379-
result = await session.send_request(
380-
types.ClientRequest(
381-
types.CallToolRequest(
382-
method="tools/call",
383-
params=types.CallToolRequestParams(name="progress_tool", arguments={}),
384-
)
385-
),
386-
types.CallToolResult,
387-
progress_callback=failing_progress_callback,
388-
)
372+
async with Client(server) as client:
373+
# Call tool with a failing progress callback
374+
result = await client.call_tool(
375+
"progress_tool",
376+
arguments={},
377+
progress_callback=failing_progress_callback,
378+
)
389379

390-
# Verify the request completed successfully despite the callback failure
391-
assert len(result.content) == 1
392-
content = result.content[0]
393-
assert isinstance(content, types.TextContent)
394-
assert content.text == "progress_result"
380+
# Verify the request completed successfully despite the callback failure
381+
assert len(result.content) == 1
382+
content = result.content[0]
383+
assert isinstance(content, types.TextContent)
384+
assert content.text == "progress_result"
395385

396-
# Check that a warning was logged for the progress callback exception
397-
assert len(logged_errors) > 0
398-
assert any("Progress callback raised an exception" in warning for warning in logged_errors)
386+
# Check that a warning was logged for the progress callback exception
387+
assert len(logged_errors) > 0
388+
assert any("Progress callback raised an exception" in warning for warning in logged_errors)

0 commit comments

Comments
 (0)