Skip to content

Commit 1fd557a

Browse files
authored
Add type checker to examples/client (#1837)
1 parent 3863f20 commit 1fd557a

File tree

7 files changed

+65
-37
lines changed
  • examples/clients
    • conformance-auth-client/mcp_conformance_auth_client
    • simple-auth-client/mcp_simple_auth_client
    • simple-chatbot/mcp_simple_chatbot
    • simple-task-client/mcp_simple_task_client
    • simple-task-interactive-client/mcp_simple_task_interactive_client
    • sse-polling-client/mcp_sse_polling_client

7 files changed

+65
-37
lines changed

examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
import logging
3030
import os
3131
import sys
32-
from urllib.parse import ParseResult, parse_qs, urlparse
32+
from typing import Any, cast
33+
from urllib.parse import parse_qs, urlparse
3334

3435
import httpx
3536
from mcp import ClientSession
@@ -39,12 +40,12 @@
3940
PrivateKeyJWTOAuthProvider,
4041
SignedJWTParameters,
4142
)
42-
from mcp.client.streamable_http import streamablehttp_client
43+
from mcp.client.streamable_http import streamable_http_client
4344
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
4445
from pydantic import AnyUrl
4546

4647

47-
def get_conformance_context() -> dict:
48+
def get_conformance_context() -> dict[str, Any]:
4849
"""Load conformance test context from MCP_CONFORMANCE_CONTEXT environment variable."""
4950
context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT")
5051
if not context_json:
@@ -116,9 +117,9 @@ async def handle_redirect(self, authorization_url: str) -> None:
116117

117118
# Check for redirect response
118119
if response.status_code in (301, 302, 303, 307, 308):
119-
location = response.headers.get("location")
120+
location = cast(str, response.headers.get("location"))
120121
if location:
121-
redirect_url: ParseResult = urlparse(location)
122+
redirect_url = urlparse(location)
122123
query_params: dict[str, list[str]] = parse_qs(redirect_url.query)
123124

124125
if "code" in query_params:
@@ -259,12 +260,8 @@ async def run_client_credentials_basic_client(server_url: str) -> None:
259260
async def _run_session(server_url: str, oauth_auth: OAuthClientProvider) -> None:
260261
"""Common session logic for all OAuth flows."""
261262
# Connect using streamable HTTP transport with OAuth
262-
async with streamablehttp_client(
263-
url=server_url,
264-
auth=oauth_auth,
265-
timeout=30.0,
266-
sse_read_timeout=60.0,
267-
) as (read_stream, write_stream, _):
263+
client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0)
264+
async with streamable_http_client(url=server_url, http_client=client) as (read_stream, write_stream, _):
268265
async with ClientSession(read_stream, write_stream) as session:
269266
# Initialize the session
270267
await session.initialize()

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,26 @@
66
77
"""
88

9+
from __future__ import annotations as _annotations
10+
911
import asyncio
1012
import os
13+
import socketserver
1114
import threading
1215
import time
1316
import webbrowser
1417
from http.server import BaseHTTPRequestHandler, HTTPServer
15-
from typing import Any
18+
from typing import Any, Callable
1619
from urllib.parse import parse_qs, urlparse
1720

1821
import httpx
22+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1923
from mcp.client.auth import OAuthClientProvider, TokenStorage
2024
from mcp.client.session import ClientSession
2125
from mcp.client.sse import sse_client
2226
from mcp.client.streamable_http import streamable_http_client
2327
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
28+
from mcp.shared.message import SessionMessage
2429

2530

2631
class InMemoryTokenStorage(TokenStorage):
@@ -46,7 +51,13 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
4651
class CallbackHandler(BaseHTTPRequestHandler):
4752
"""Simple HTTP handler to capture OAuth callback."""
4853

49-
def __init__(self, request, client_address, server, callback_data):
54+
def __init__(
55+
self,
56+
request: Any,
57+
client_address: tuple[str, int],
58+
server: socketserver.BaseServer,
59+
callback_data: dict[str, Any],
60+
):
5061
"""Initialize with callback data storage."""
5162
self.callback_data = callback_data
5263
super().__init__(request, client_address, server)
@@ -91,15 +102,14 @@ def do_GET(self):
91102
self.send_response(404)
92103
self.end_headers()
93104

94-
def log_message(self, format, *args):
105+
def log_message(self, format: str, *args: Any):
95106
"""Suppress default logging."""
96-
pass
97107

98108

99109
class CallbackServer:
100110
"""Simple server to handle OAuth callbacks."""
101111

102-
def __init__(self, port=3000):
112+
def __init__(self, port: int = 3000):
103113
self.port = port
104114
self.server = None
105115
self.thread = None
@@ -110,7 +120,12 @@ def _create_handler_with_data(self):
110120
callback_data = self.callback_data
111121

112122
class DataCallbackHandler(CallbackHandler):
113-
def __init__(self, request, client_address, server):
123+
def __init__(
124+
self,
125+
request: BaseHTTPRequestHandler,
126+
client_address: tuple[str, int],
127+
server: socketserver.BaseServer,
128+
):
114129
super().__init__(request, client_address, server, callback_data)
115130

116131
return DataCallbackHandler
@@ -131,7 +146,7 @@ def stop(self):
131146
if self.thread:
132147
self.thread.join(timeout=1)
133148

134-
def wait_for_callback(self, timeout=300):
149+
def wait_for_callback(self, timeout: int = 300):
135150
"""Wait for OAuth callback with timeout."""
136151
start_time = time.time()
137152
while time.time() - start_time < timeout:
@@ -225,7 +240,12 @@ async def _default_redirect_handler(authorization_url: str) -> None:
225240

226241
traceback.print_exc()
227242

228-
async def _run_session(self, read_stream, write_stream, get_session_id):
243+
async def _run_session(
244+
self,
245+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
246+
write_stream: MemoryObjectSendStream[SessionMessage],
247+
get_session_id: Callable[[], str | None] | None = None,
248+
):
229249
"""Run the MCP session with the given streams."""
230250
print("🤝 Initializing MCP session...")
231251
async with ClientSession(read_stream, write_stream) as session:
@@ -314,7 +334,7 @@ async def interactive_loop(self):
314334
continue
315335

316336
# Parse arguments (simple JSON-like format)
317-
arguments = {}
337+
arguments: dict[str, Any] = {}
318338
if len(parts) > 2:
319339
import json
320340

examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import json
35
import logging
@@ -93,7 +95,7 @@ async def initialize(self) -> None:
9395
await self.cleanup()
9496
raise
9597

96-
async def list_tools(self) -> list[Any]:
98+
async def list_tools(self) -> list[Tool]:
9799
"""List available tools from the server.
98100
99101
Returns:
@@ -106,10 +108,10 @@ async def list_tools(self) -> list[Any]:
106108
raise RuntimeError(f"Server {self.name} not initialized")
107109

108110
tools_response = await self.session.list_tools()
109-
tools = []
111+
tools: list[Tool] = []
110112

111113
for item in tools_response:
112-
if isinstance(item, tuple) and item[0] == "tools":
114+
if item[0] == "tools":
113115
tools.extend(Tool(tool.name, tool.description, tool.inputSchema, tool.title) for tool in item[1])
114116

115117
return tools
@@ -189,7 +191,7 @@ def format_for_llm(self) -> str:
189191
Returns:
190192
A formatted string describing the tool.
191193
"""
192-
args_desc = []
194+
args_desc: list[str] = []
193195
if "properties" in self.input_schema:
194196
for param_name, param_info in self.input_schema["properties"].items():
195197
arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}"
@@ -311,9 +313,9 @@ def _clean_json_string(json_string: str) -> str:
311313
result = await server.execute_tool(tool_call["tool"], tool_call["arguments"])
312314

313315
if isinstance(result, dict) and "progress" in result:
314-
progress = result["progress"]
315-
total = result["total"]
316-
percentage = (progress / total) * 100
316+
progress = result["progress"] # type: ignore
317+
total = result["total"] # type: ignore
318+
percentage = (progress / total) * 100 # type: ignore
317319
logging.info(f"Progress: {progress}/{total} ({percentage:.1f}%)")
318320

319321
return f"Tool execution result: {result}"
@@ -338,7 +340,7 @@ async def start(self) -> None:
338340
await self.cleanup_servers()
339341
return
340342

341-
all_tools = []
343+
all_tools: list[Tool] = []
342344
for server in self.servers:
343345
tools = await server.list_tools()
344346
all_tools.extend(tools)

examples/clients/simple-task-client/mcp_simple_task_client/main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
import click
66
from mcp import ClientSession
7-
from mcp.client.streamable_http import streamablehttp_client
7+
from mcp.client.streamable_http import streamable_http_client
88
from mcp.types import CallToolResult, TextContent
99

1010

1111
async def run(url: str) -> None:
12-
async with streamablehttp_client(url) as (read, write, _):
12+
async with streamable_http_client(url) as (read, write, _):
1313
async with ClientSession(read, write) as session:
1414
await session.initialize()
1515

@@ -28,12 +28,13 @@ async def run(url: str) -> None:
2828
task_id = result.task.taskId
2929
print(f"Task created: {task_id}")
3030

31+
status = None
3132
# Poll until done (respects server's pollInterval hint)
3233
async for status in session.experimental.poll_task(task_id):
3334
print(f" Status: {status.status} - {status.statusMessage or ''}")
3435

3536
# Check final status
36-
if status.status != "completed":
37+
if status and status.status != "completed":
3738
print(f"Task ended with status: {status.status}")
3839
return
3940

examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import click
1313
from mcp import ClientSession
14-
from mcp.client.streamable_http import streamablehttp_client
14+
from mcp.client.streamable_http import streamable_http_client
1515
from mcp.shared.context import RequestContext
1616
from mcp.types import (
1717
CallToolResult,
@@ -73,7 +73,7 @@ def get_text(result: CallToolResult) -> str:
7373

7474

7575
async def run(url: str) -> None:
76-
async with streamablehttp_client(url) as (read, write, _):
76+
async with streamable_http_client(url) as (read, write, _):
7777
async with ClientSession(
7878
read,
7979
write,

examples/clients/sse-polling-client/mcp_sse_polling_client/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import click
2222
from mcp import ClientSession
23-
from mcp.client.streamable_http import streamablehttp_client
23+
from mcp.client.streamable_http import streamable_http_client
2424

2525
logger = logging.getLogger(__name__)
2626

@@ -34,7 +34,7 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None:
3434
print(f"Processing {items} items with checkpoints every {checkpoint_every}")
3535
print(f"{'=' * 60}\n")
3636

37-
async with streamablehttp_client(url) as (read_stream, write_stream, _):
37+
async with streamable_http_client(url) as (read_stream, write_stream, _):
3838
async with ClientSession(read_stream, write_stream) as session:
3939
# Initialize the connection
4040
print("Initializing connection...")

pyproject.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,13 @@ packages = ["src/mcp"]
9393

9494
[tool.pyright]
9595
typeCheckingMode = "strict"
96-
include = ["src/mcp", "tests", "examples/servers", "examples/snippets"]
96+
include = [
97+
"src/mcp",
98+
"tests",
99+
"examples/servers",
100+
"examples/snippets",
101+
"examples/clients",
102+
]
97103
venvPath = "."
98104
venv = ".venv"
99105
# The FastAPI style of using decorators in tests gives a `reportUnusedFunction` error.
@@ -102,7 +108,9 @@ venv = ".venv"
102108
# those private functions instead of testing the private functions directly. It makes it easier to maintain the code source
103109
# and refactor code that is not public.
104110
executionEnvironments = [
105-
{ root = "tests", extraPaths = ["."], reportUnusedFunction = false, reportPrivateUsage = false },
111+
{ root = "tests", extraPaths = [
112+
".",
113+
], reportUnusedFunction = false, reportPrivateUsage = false },
106114
{ root = "examples/servers", reportUnusedFunction = false },
107115
]
108116

0 commit comments

Comments
 (0)