Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 24 additions & 25 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
except ImportError as e:

if sys.version_info < (3, 10):
raise ImportError(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
Expand All @@ -49,7 +48,6 @@

logger = logging.getLogger('google_adk.' + __name__)


class StdioConnectionParams(BaseModel):
"""Parameters for the MCP Stdio connection.

Expand All @@ -58,16 +56,17 @@ class StdioConnectionParams(BaseModel):
timeout: Timeout in seconds for establishing the connection to the MCP
stdio server.
"""

server_params: StdioServerParameters
timeout: float = 5.0

class Config:
arbitrary_types_allowed = True

class SseConnectionParams(BaseModel):
"""Parameters for the MCP SSE connection.

See MCP SSE Client documentation for more details.
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
[https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py)

Attributes:
url: URL for the MCP SSE server.
Expand All @@ -77,18 +76,16 @@ class SseConnectionParams(BaseModel):
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
server.
"""

url: str
headers: dict[str, Any] | None = None
timeout: float = 5.0
sse_read_timeout: float = 60 * 5.0


class StreamableHTTPConnectionParams(BaseModel):
"""Parameters for the MCP Streamable HTTP connection.

See MCP Streamable HTTP Client documentation for more details.
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py
[https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py)

Attributes:
url: URL for the MCP Streamable HTTP server.
Expand All @@ -100,41 +97,45 @@ class StreamableHTTPConnectionParams(BaseModel):
terminate_on_close: Whether to terminate the MCP Streamable HTTP server
when the connection is closed.
"""

url: str
headers: dict[str, Any] | None = None
timeout: float = 5.0
sse_read_timeout: float = 60 * 5.0
terminate_on_close: bool = True


def retry_on_closed_resource(func):
"""Decorator to automatically retry action when MCP session is closed.

CRITICAL WARNING: This decorator is UNSAFE for non-idempotent operations.
Do NOT use with tool calls that create, update, or delete resources as
retrying can cause duplicate operations or data corruption.

Only use with read-only, idempotent operations like list_tools,
list_resources, or read_resource.

Do NOT apply to generic tool execution methods like _run_async_impl.

When MCP session was closed, the decorator will automatically retry the
action once. The create_session method will handle creating a new session
if the old one was disconnected.

Args:
func: The function to decorate.
func: The function to decorate. Must be idempotent and safe to retry.

Returns:
The decorated function.
"""

@functools.wraps(func) # Preserves original function metadata
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except anyio.ClosedResourceError:
except (anyio.ClosedResourceError, anyio.BrokenResourceError):
# Simply retry the function - create_session will handle
# detecting and replacing disconnected sessions
logger.info('Retrying %s due to closed resource', func.__name__)
logger.info('Retrying %s due to closed/broken resource', func.__name__)
return await func(self, *args, **kwargs)

return wrapper


class MCPSessionManager:
"""Manages MCP client sessions.

Expand Down Expand Up @@ -176,11 +177,10 @@ def __init__(
)
else:
self._connection_params = connection_params
self._errlog = errlog

self._errlog = errlog
# Session pool: maps session keys to (session, exit_stack) tuples
self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {}

# Lock to prevent race conditions in session creation
self._session_lock = asyncio.Lock()

Expand Down Expand Up @@ -292,6 +292,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
' StdioServerParameters or SseServerParams, but got'
f' {self._connection_params}'
)

return client

async def create_session(
Expand All @@ -313,7 +314,6 @@ async def create_session(
"""
# Merge headers once at the beginning
merged_headers = self._merge_headers(headers)

# Generate session key using merged headers
session_key = self._generate_session_key(merged_headers)

Expand All @@ -322,7 +322,6 @@ async def create_session(
# Check if we have an existing session
if session_key in self._sessions:
session, exit_stack = self._sessions[session_key]

# Check if the existing session is still connected
if not self._is_session_disconnected(session):
# Session is still good, return it
Expand All @@ -339,11 +338,10 @@ async def create_session(

# Create a new session (either first time or replacing disconnected one)
exit_stack = AsyncExitStack()

try:
client = self._create_client(merged_headers)

transports = await exit_stack.enter_async_context(client)

# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
# needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
if isinstance(self._connection_params, StdioConnectionParams):
Expand All @@ -359,8 +357,8 @@ async def create_session(
session = await exit_stack.enter_async_context(
ClientSession(*transports[:2])
)
await session.initialize()

await session.initialize()
# Store session and exit stack in the pool
self._sessions[session_key] = (session, exit_stack)
logger.debug('Created new session: %s', session_key)
Expand All @@ -369,7 +367,10 @@ async def create_session(
except Exception:
# If session creation fails, clean up the exit stack
if exit_stack:
await exit_stack.aclose()
try:
await exit_stack.aclose()
except (anyio.BrokenResourceError, anyio.ClosedResourceError) as e:
logger.warning('Error during exit stack cleanup: %s', e)
raise

async def close(self):
Expand All @@ -389,7 +390,5 @@ async def close(self):
finally:
del self._sessions[session_key]


SseServerParams = SseConnectionParams

StreamableHTTPServerParams = StreamableHTTPConnectionParams