Skip to content

Commit aa8f5a4

Browse files
vblagoje and sjrl authored
Add MCPTool/MCPToolset warm_up (#2384)
* Add MCPTool/MCPToolset warm_up
* Update haystack dependency
* Make eager_connect False by default
* Add real integration tests for warm_up
* Fix one test
* Remove test
* Update from_dict to use eager_connect False by default
* Comment update
* Lint
* Make warm_up single point of entry for ensuring connection
* PR feedback

---------

Co-authored-by: Sebastian Husch Lee <[email protected]>
1 parent f43e1f7 commit aa8f5a4

File tree

7 files changed

+211
-79
lines changed

7 files changed

+211
-79
lines changed

integrations/mcp/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ classifiers = [
2929
]
3030
dependencies = [
3131
"mcp>=1.8.0",
32-
"haystack-ai>=2.18.0",
32+
"haystack-ai>=2.19.0",
3333
"exceptiongroup", # Backport of ExceptionGroup for Python < 3.11
3434
"httpx" # HTTP client library used for SSE connections
3535
]

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ def __init__(
814814
description: str | None = None,
815815
connection_timeout: int = 30,
816816
invocation_timeout: int = 30,
817+
eager_connect: bool = False,
817818
):
818819
"""
819820
Initialize the MCP tool.
@@ -823,6 +824,9 @@ def __init__(
823824
:param description: Custom description (if None, server description will be used)
824825
:param connection_timeout: Timeout in seconds for server connection
825826
:param invocation_timeout: Default timeout in seconds for tool invocations
827+
:param eager_connect: If True, connect to server during initialization.
828+
If False (default), defer connection until warm_up or first tool use,
829+
whichever comes first.
826830
:raises MCPConnectionError: If connection to the server fails
827831
:raises MCPToolNotFoundError: If no tools are available or the requested tool is not found
828832
:raises TimeoutError: If connection times out
@@ -832,39 +836,27 @@ def __init__(
832836
self._server_info = server_info
833837
self._connection_timeout = connection_timeout
834838
self._invocation_timeout = invocation_timeout
839+
self._eager_connect = eager_connect
840+
self._client: MCPClient | None = None
841+
self._worker: _MCPClientSessionManager | None = None
842+
self._lock = threading.RLock()
843+
844+
# don't connect now; initialize permissively
845+
if not eager_connect:
846+
# Permissive placeholder JSON Schema so the Tool is valid
847+
# without discovering the remote schema during validation.
848+
# Tool parameters/schema will be replaced with the correct schema (from the MCP server) on first use.
849+
params = {"type": "object", "properties": {}, "additionalProperties": True}
850+
super().__init__(name=name, description=description or "", parameters=params, function=self._invoke_tool)
851+
return
835852

836853
logger.debug(f"TOOL: Initializing MCPTool '{name}'")
837854

838855
try:
839-
# Create client and spin up a long-lived worker that keeps the
840-
# connect/close lifecycle inside one coroutine.
841-
self._client = server_info.create_client()
842-
logger.debug(f"TOOL: Created client for MCPTool '{name}'")
843-
844-
# The worker starts immediately and blocks here until the connection
845-
# is established (or fails), returning the tool list.
846-
self._worker = _MCPClientSessionManager(self._client, timeout=connection_timeout)
847-
848-
tools = self._worker.tools()
849-
# Handle no tools case
850-
if not tools:
851-
logger.debug(f"TOOL: No tools found for '{name}'")
852-
message = "No tools available on server"
853-
raise MCPToolNotFoundError(message, tool_name=name)
854-
855-
# Find the specified tool
856-
tool_dict = {t.name: t for t in tools}
857-
logger.debug(f"TOOL: Available tools: {list(tool_dict.keys())}")
858-
859-
tool_info: types.Tool | None = tool_dict.get(name)
860-
861-
if not tool_info:
862-
available = list(tool_dict.keys())
863-
logger.debug(f"TOOL: Tool '{name}' not found in available tools")
864-
message = f"Tool '{name}' not found on server. Available tools: {', '.join(available)}"
865-
raise MCPToolNotFoundError(message, tool_name=name, available_tools=available)
866-
856+
logger.debug(f"TOOL: Connecting to MCP server for '{name}'")
857+
tool_info = self._connect_and_initialize(name)
867858
logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class")
859+
868860
# Initialize the parent class
869861
super().__init__(
870862
name=name,
@@ -897,6 +889,36 @@ def __init__(
897889
message = f"Failed to initialize MCPTool '{name}': {error_message}"
898890
raise MCPConnectionError(message=message, server_info=server_info, operation="initialize") from e
899891

892+
def _connect_and_initialize(self, tool_name: str) -> types.Tool:
893+
"""
894+
Connect to the MCP server and retrieve the tool schema.
895+
896+
:param tool_name: Name of the tool to look for
897+
:returns: The tool schema for this tool
898+
:raises MCPToolNotFoundError: If the tool is not found on the server
899+
"""
900+
client = self._server_info.create_client()
901+
worker = _MCPClientSessionManager(client, timeout=self._connection_timeout)
902+
tools = worker.tools()
903+
904+
# Handle no tools case
905+
if not tools:
906+
message = "No tools available on server"
907+
raise MCPToolNotFoundError(message, tool_name=tool_name)
908+
909+
# Find the specified tool
910+
tool = next((t for t in tools if t.name == tool_name), None)
911+
if tool is None:
912+
available = [t.name for t in tools]
913+
msg = f"Tool '{tool_name}' not found on server. Available tools: {', '.join(available)}"
914+
raise MCPToolNotFoundError(msg, tool_name=tool_name, available_tools=available)
915+
916+
# Publish connection
917+
self._client = client
918+
self._worker = worker
919+
920+
return tool
921+
900922
def _invoke_tool(self, **kwargs: Any) -> str:
901923
"""
902924
Synchronous tool invocation.
@@ -906,12 +928,13 @@ def _invoke_tool(self, **kwargs: Any) -> str:
906928
"""
907929
logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}")
908930
try:
931+
# Connect on first use if eager_connect is turned off
932+
self.warm_up()
909933

910934
async def invoke():
911935
logger.debug(f"TOOL: Inside invoke coroutine for '{self.name}'")
912-
result = await asyncio.wait_for(
913-
self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout
914-
)
936+
client = cast(MCPClient, self._client)
937+
result = await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
915938
logger.debug(f"TOOL: Invoke successful for '{self.name}'")
916939
return result
917940

@@ -939,7 +962,9 @@ async def ainvoke(self, **kwargs: Any) -> str:
939962
:raises TimeoutError: If the operation times out
940963
"""
941964
try:
942-
return await asyncio.wait_for(self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
965+
self.warm_up()
966+
client = cast(MCPClient, self._client)
967+
return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
943968
except asyncio.TimeoutError as e:
944969
message = f"Tool invocation timed out after {self._invocation_timeout} seconds"
945970
raise TimeoutError(message) from e
@@ -949,6 +974,14 @@ async def ainvoke(self, **kwargs: Any) -> str:
949974
message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}"
950975
raise MCPInvocationError(message, self.name, kwargs) from e
951976

977+
def warm_up(self) -> None:
978+
"""Connect and fetch the tool schema if eager_connect is turned off."""
979+
with self._lock:
980+
if self._client is not None:
981+
return
982+
tool = self._connect_and_initialize(self.name)
983+
self.parameters = tool.inputSchema
984+
952985
def to_dict(self) -> dict[str, Any]:
953986
"""
954987
Serializes the MCPTool to a dictionary.
@@ -966,6 +999,7 @@ def to_dict(self) -> dict[str, Any]:
966999
"server_info": self._server_info.to_dict(),
9671000
"connection_timeout": self._connection_timeout,
9681001
"invocation_timeout": self._invocation_timeout,
1002+
"eager_connect": self._eager_connect,
9691003
}
9701004
return {
9711005
"type": generate_qualified_class_name(type(self)),
@@ -998,6 +1032,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
9981032
# Handle backward compatibility for timeout parameters
9991033
connection_timeout = inner_data.get("connection_timeout", 30)
10001034
invocation_timeout = inner_data.get("invocation_timeout", 30)
1035+
eager_connect = inner_data.get("eager_connect", False) # because False is the default
10011036

10021037
# Create a new MCPTool instance with the deserialized parameters
10031038
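# A connection to the MCP server is established here only when eager_connect is True;
# otherwise it is deferred until warm_up() or the first tool invocation.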
@@ -1007,6 +1042,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
10071042
server_info=server_info,
10081043
connection_timeout=connection_timeout,
10091044
invocation_timeout=invocation_timeout,
1045+
eager_connect=eager_connect,
10101046
)
10111047

10121048
def close(self):

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __init__(
120120
tool_names: list[str] | None = None,
121121
connection_timeout: float = 30.0,
122122
invocation_timeout: float = 30.0,
123+
eager_connect: bool = False,
123124
):
124125
"""
125126
Initialize the MCP toolset.
@@ -129,15 +130,48 @@ def __init__(
129130
matching names will be added to the toolset.
130131
:param connection_timeout: Timeout in seconds for server connection
131132
:param invocation_timeout: Default timeout in seconds for tool invocations
133+
:param eager_connect: If True, connect to server and load tools during initialization.
134+
If False (default), defer connection to warm_up.
132135
:raises MCPToolNotFoundError: If any of the specified tool names are not found on the server
133136
"""
134137
# Store configuration
135138
self.server_info = server_info
136139
self.tool_names = tool_names
137140
self.connection_timeout = connection_timeout
138141
self.invocation_timeout = invocation_timeout
142+
self.eager_connect = eager_connect
143+
self._warmup_called = False
144+
145+
if not eager_connect:
146+
# Do not connect during validation; expose a toolset with one fake tool to pass validation
147+
placeholder_tool = Tool(
148+
name=f"mcp_not_connected_placeholder_{id(self)}",
149+
description="Placeholder tool initialised when eager_connect is turned off",
150+
parameters={"type": "object", "properties": {}, "additionalProperties": True},
151+
function=lambda: None,
152+
)
153+
super().__init__(tools=[placeholder_tool])
154+
else:
155+
tools = self._connect_and_load_tools()
156+
super().__init__(tools=tools)
157+
self._warmup_called = True
158+
159+
def warm_up(self) -> None:
160+
"""Connect and load tools when eager_connect is turned off.
161+
162+
This method is automatically called by ``ToolInvoker.warm_up()`` and ``Pipeline.warm_up()``.
163+
You can also call it directly before using the toolset to ensure all tool schemas
164+
are available without performing a real invocation.
165+
"""
166+
if self._warmup_called:
167+
return
168+
169+
# connect and load tools never adds duplicate tools, set the tools attribute directly
170+
self.tools = self._connect_and_load_tools()
171+
self._warmup_called = True
139172

140-
# Connect and load tools
173+
def _connect_and_load_tools(self) -> list[Tool]:
174+
"""Connect and load tools."""
141175
try:
142176
# Create the client and spin up a worker so open/close happen in the
143177
# same coroutine, avoiding AnyIO cancel-scope issues.
@@ -195,9 +229,7 @@ def invoke_tool(**kwargs: Any) -> Any:
195229
)
196230
haystack_tools.append(tool)
197231

198-
# Initialize parent class with complete tools list
199-
super().__init__(tools=haystack_tools)
200-
232+
return haystack_tools
201233
except Exception as e:
202234
# We need to close because we could connect properly, retrieve tools yet
203235
# fail because of an MCPToolNotFoundError
@@ -273,6 +305,7 @@ def to_dict(self) -> dict[str, Any]:
273305
"tool_names": self.tool_names,
274306
"connection_timeout": self.connection_timeout,
275307
"invocation_timeout": self.invocation_timeout,
308+
"eager_connect": self.eager_connect,
276309
},
277310
}
278311

@@ -297,6 +330,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset":
297330
tool_names=inner_data.get("tool_names"),
298331
connection_timeout=inner_data.get("connection_timeout", 30.0),
299332
invocation_timeout=inner_data.get("invocation_timeout", 30.0),
333+
eager_connect=inner_data.get("eager_connect", False),  # fall back to the constructor default (lazy connect), same as MCPTool.from_dict
300334
)
301335

302336
def close(self):

integrations/mcp/tests/test_mcp_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def test_mcp_tool_error_handling_integration(self):
234234
# Use a non-existent server address to force a connection error
235235
server_info = SSEServerInfo(base_url="http://localhost:9999", timeout=1) # Short timeout
236236
with pytest.raises(MCPConnectionError) as exc_info:
237-
MCPTool(name="non_existent_tool", server_info=server_info, connection_timeout=2)
237+
MCPTool(name="non_existent_tool", server_info=server_info, connection_timeout=2, eager_connect=True)
238238

239239
# Check for platform-agnostic error message patterns
240240
error_message = str(exc_info.value)

integrations/mcp/tests/test_mcp_timeout_reconnection.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import subprocess
1414
import sys
1515
import tempfile
16+
import textwrap
1617
import time
1718
from unittest.mock import AsyncMock, MagicMock
1819

@@ -108,40 +109,39 @@ def test_real_sse_reconnection_after_server_restart(self):
108109
try:
109110
# Create server script with cross-platform signal handling
110111
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file:
111-
temp_file.write(
112-
f"""
113-
import sys
114-
import signal
115-
from mcp.server.fastmcp import FastMCP
116-
117-
# Handle shutdown signals gracefully (cross-platform)
118-
def signal_handler(signum, frame):
119-
sys.exit(0)
120-
121-
# Only set up signal handlers that exist on the platform
122-
if hasattr(signal, 'SIGTERM'):
123-
signal.signal(signal.SIGTERM, signal_handler)
124-
if hasattr(signal, 'SIGINT'):
125-
signal.signal(signal.SIGINT, signal_handler)
126-
127-
mcp = FastMCP("Reconnection Test Server", host="127.0.0.1", port={port})
128-
129-
@mcp.tool()
130-
def test_tool(message: str) -> str:
131-
return f"Server response: {{message}}"
132-
133-
if __name__ == "__main__":
134-
try:
135-
print(f"Starting server on port {port}", flush=True)
136-
mcp.run(transport="sse")
137-
except (KeyboardInterrupt, SystemExit):
138-
print("Server shutting down gracefully", flush=True)
139-
sys.exit(0)
140-
except Exception as e:
141-
print(f"Server error: {{e}}", file=sys.stderr, flush=True)
142-
sys.exit(1)
143-
""".encode()
144-
)
112+
script_content = textwrap.dedent(f"""
113+
import sys
114+
import signal
115+
from mcp.server.fastmcp import FastMCP
116+
117+
# Handle shutdown signals gracefully (cross-platform)
118+
def signal_handler(signum, frame):
119+
sys.exit(0)
120+
121+
# Only set up signal handlers that exist on the platform
122+
if hasattr(signal, 'SIGTERM'):
123+
signal.signal(signal.SIGTERM, signal_handler)
124+
if hasattr(signal, 'SIGINT'):
125+
signal.signal(signal.SIGINT, signal_handler)
126+
127+
mcp = FastMCP("Reconnection Test Server", host="127.0.0.1", port={port})
128+
129+
@mcp.tool()
130+
def test_tool(message: str) -> str:
131+
return f"Server response: {{message}}"
132+
133+
if __name__ == "__main__":
134+
try:
135+
print(f"Starting server on port {port}", flush=True)
136+
mcp.run(transport="sse")
137+
except (KeyboardInterrupt, SystemExit):
138+
print("Server shutting down gracefully", flush=True)
139+
sys.exit(0)
140+
except Exception as e:
141+
print(f"Server error: {{e}}", file=sys.stderr, flush=True)
142+
sys.exit(1)
143+
""").strip()
144+
temp_file.write(script_content.encode())
145145
server_script_path = temp_file.name
146146

147147
# Start server

0 commit comments

Comments
 (0)