44
55import inspect
66from dataclasses import dataclass
7- from typing import Any , Optional , Union
7+ from typing import Any , Optional , Union , cast
88
99from haystack import logging , tracing
1010from haystack .components .generators .chat .types import ChatGenerator
2525from haystack .dataclasses import ChatMessage , ChatRole
2626from haystack .dataclasses .breakpoints import AgentBreakpoint , AgentSnapshot , PipelineSnapshot , ToolBreakpoint
2727from haystack .dataclasses .streaming_chunk import StreamingCallbackT , select_streaming_callback
28- from haystack .tools import Tool , Toolset , deserialize_tools_or_toolset_inplace , serialize_tools_or_toolset
28+ from haystack .tools import (
29+ Tool ,
30+ Toolset ,
31+ deserialize_tools_or_toolset_inplace ,
32+ flatten_tools_or_toolsets ,
33+ serialize_tools_or_toolset ,
34+ )
2935from haystack .utils import _deserialize_value_with_schema
3036from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
3137from haystack .utils .deserialization import deserialize_chatgenerator_inplace
@@ -97,7 +103,7 @@ def __init__(
97103 self ,
98104 * ,
99105 chat_generator : ChatGenerator ,
100- tools : Optional [Union [list [Tool ], Toolset ]] = None ,
106+ tools : Optional [Union [list [Union [ Tool , Toolset ] ], Toolset ]] = None ,
101107 system_prompt : Optional [str ] = None ,
102108 exit_conditions : Optional [list [str ]] = None ,
103109 state_schema : Optional [dict [str , Any ]] = None ,
@@ -110,7 +116,7 @@ def __init__(
110116 Initialize the agent component.
111117
112118 :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
113- :param tools: List of Tool objects or a Toolset that the agent can use.
119+ :param tools: A list of Tool and/or Toolset objects, or a single Toolset that the agent can use.
114120 :param system_prompt: System prompt for the agent.
115121 :param exit_conditions: List of conditions that will cause the agent to return.
116122 Can include "text" if the agent should return when it generates a message without tool calls,
@@ -134,7 +140,7 @@ def __init__(
134140 "The Agent component requires a chat generator that supports tools."
135141 )
136142
137- valid_exits = ["text" ] + [tool .name for tool in tools or [] ]
143+ valid_exits = ["text" ] + [tool .name for tool in flatten_tools_or_toolsets ( tools ) ]
138144 if exit_conditions is None :
139145 exit_conditions = ["text" ]
140146 if not all (condition in valid_exits for condition in exit_conditions ):
@@ -259,7 +265,7 @@ def _initialize_fresh_execution(
259265 requires_async : bool ,
260266 * ,
261267 system_prompt : Optional [str ] = None ,
262- tools : Optional [Union [list [Tool ], Toolset , list [str ]]] = None ,
268+ tools : Optional [Union [list [Union [ Tool , Toolset ] ], Toolset , list [str ]]] = None ,
263269 ** kwargs ,
264270 ) -> _ExecutionContext :
265271 """
@@ -302,8 +308,8 @@ def _initialize_fresh_execution(
302308 )
303309
304310 def _select_tools (
305- self , tools : Optional [Union [list [Tool ], Toolset , list [str ]]] = None
306- ) -> Union [list [Tool ], Toolset ]:
311+ self , tools : Optional [Union [list [Union [ Tool , Toolset ] ], Toolset , list [str ]]] = None
312+ ) -> Union [list [Union [ Tool , Toolset ] ], Toolset ]:
307313 """
308314 Select tools for the current run based on the provided tools parameter.
309315
@@ -314,32 +320,43 @@ def _select_tools(
314320 or if any provided tool name is not valid.
315321 :raises TypeError: If tools is not a list of Tool objects, a Toolset, or a list of tool names (strings).
316322 """
317- selected_tools : Union [ list [ Tool ], Toolset ] = self . tools
318- if isinstance ( tools , Toolset ) or isinstance ( tools , list ) and all ( isinstance ( t , Tool ) for t in tools ):
319- selected_tools = tools # type: ignore[assignment] # mypy thinks this could still be list[str]
320- elif isinstance (tools , list ) and all (isinstance (t , str ) for t in tools ):
323+ if tools is None :
324+ return self . tools
325+
326+ if isinstance (tools , list ) and all (isinstance (t , str ) for t in tools ):
321327 if not self .tools :
322328 raise ValueError ("No tools were configured for the Agent at initialization." )
323- selected_tool_names : list [str ] = tools # type: ignore[assignment] # mypy thinks this could still be list[Tool] or Toolset
324- valid_tool_names = {tool .name for tool in self .tools }
329+ available_tools = flatten_tools_or_toolsets (self .tools )
330+ selected_tool_names = cast (list [str ], tools )
331+ valid_tool_names = {tool .name for tool in available_tools }
325332 invalid_tool_names = {name for name in selected_tool_names if name not in valid_tool_names }
326333 if invalid_tool_names :
327334 raise ValueError (
328335 f"The following tool names are not valid: { invalid_tool_names } . "
329336 f"Valid tool names are: { valid_tool_names } ."
330337 )
331- selected_tools = [tool for tool in self .tools if tool .name in selected_tool_names ]
332- elif tools is not None :
333- raise TypeError ("tools must be a list of Tool objects, a Toolset, or a list of tool names (strings)." )
334- return selected_tools
338+ return [tool for tool in available_tools if tool .name in selected_tool_names ]
339+
340+ if isinstance (tools , Toolset ):
341+ return tools
342+
343+ if isinstance (tools , list ) and tools and isinstance (tools [0 ], Toolset ):
344+ return cast (list [Union [Tool , Toolset ]], tools )
345+
346+ if isinstance (tools , list ) and all (isinstance (t , Tool ) for t in tools ):
347+ return cast (list [Union [Tool , Toolset ]], tools )
348+
349+ raise TypeError (
350+ "tools must be a list of Tool and/or Toolset objects, a Toolset, or a list of tool names (strings)."
351+ )
335352
336353 def _initialize_from_snapshot (
337354 self ,
338355 snapshot : AgentSnapshot ,
339356 streaming_callback : Optional [StreamingCallbackT ],
340357 requires_async : bool ,
341358 * ,
342- tools : Optional [Union [list [Tool ], Toolset , list [str ]]] = None ,
359+ tools : Optional [Union [list [Union [ Tool , Toolset ] ], Toolset , list [str ]]] = None ,
343360 ) -> _ExecutionContext :
344361 """
345362 Initialize execution context from an AgentSnapshot.
@@ -459,7 +476,7 @@ def run( # noqa: PLR0915
459476 break_point : Optional [AgentBreakpoint ] = None ,
460477 snapshot : Optional [AgentSnapshot ] = None ,
461478 system_prompt : Optional [str ] = None ,
462- tools : Optional [Union [list [Tool ], Toolset , list [str ]]] = None ,
479+ tools : Optional [Union [list [Union [ Tool , Toolset ] ], Toolset , list [str ]]] = None ,
463480 ** kwargs : Any ,
464481 ) -> dict [str , Any ]:
465482 """
@@ -616,7 +633,7 @@ async def run_async(
616633 break_point : Optional [AgentBreakpoint ] = None ,
617634 snapshot : Optional [AgentSnapshot ] = None ,
618635 system_prompt : Optional [str ] = None ,
619- tools : Optional [Union [list [Tool ], Toolset , list [str ]]] = None ,
636+ tools : Optional [Union [list [Union [ Tool , Toolset ] ], Toolset , list [str ]]] = None ,
620637 ** kwargs : Any ,
621638 ) -> dict [str , Any ]:
622639 """
0 commit comments