Skip to content

Commit a8ae92f

Browse files
committed
Update tools param to Optional[Union[list[Union[Tool, Toolset]], Toolset]]
1 parent fe60c76 commit a8ae92f

File tree

15 files changed

+685
-95
lines changed

15 files changed

+685
-95
lines changed

haystack/components/agents/agent.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import inspect
66
from dataclasses import dataclass
7-
from typing import Any, Optional, Union
7+
from typing import Any, Optional, Union, cast
88

99
from haystack import logging, tracing
1010
from haystack.components.generators.chat.types import ChatGenerator
@@ -25,7 +25,13 @@
2525
from haystack.dataclasses import ChatMessage, ChatRole
2626
from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot, PipelineSnapshot, ToolBreakpoint
2727
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
28-
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
28+
from haystack.tools import (
29+
Tool,
30+
Toolset,
31+
deserialize_tools_or_toolset_inplace,
32+
flatten_tools_or_toolsets,
33+
serialize_tools_or_toolset,
34+
)
2935
from haystack.utils import _deserialize_value_with_schema
3036
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
3137
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -97,7 +103,7 @@ def __init__(
97103
self,
98104
*,
99105
chat_generator: ChatGenerator,
100-
tools: Optional[Union[list[Tool], Toolset]] = None,
106+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset]] = None,
101107
system_prompt: Optional[str] = None,
102108
exit_conditions: Optional[list[str]] = None,
103109
state_schema: Optional[dict[str, Any]] = None,
@@ -110,7 +116,7 @@ def __init__(
110116
Initialize the agent component.
111117
112118
:param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
113-
:param tools: List of Tool objects or a Toolset that the agent can use.
119+
:param tools: A list of Tool and/or Toolset objects, or a single Toolset that the agent can use.
114120
:param system_prompt: System prompt for the agent.
115121
:param exit_conditions: List of conditions that will cause the agent to return.
116122
Can include "text" if the agent should return when it generates a message without tool calls,
@@ -134,7 +140,7 @@ def __init__(
134140
"The Agent component requires a chat generator that supports tools."
135141
)
136142

137-
valid_exits = ["text"] + [tool.name for tool in tools or []]
143+
valid_exits = ["text"] + [tool.name for tool in flatten_tools_or_toolsets(tools)]
138144
if exit_conditions is None:
139145
exit_conditions = ["text"]
140146
if not all(condition in valid_exits for condition in exit_conditions):
@@ -259,7 +265,7 @@ def _initialize_fresh_execution(
259265
requires_async: bool,
260266
*,
261267
system_prompt: Optional[str] = None,
262-
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
268+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset, list[str]]] = None,
263269
**kwargs,
264270
) -> _ExecutionContext:
265271
"""
@@ -302,8 +308,8 @@ def _initialize_fresh_execution(
302308
)
303309

304310
def _select_tools(
305-
self, tools: Optional[Union[list[Tool], Toolset, list[str]]] = None
306-
) -> Union[list[Tool], Toolset]:
311+
self, tools: Optional[Union[list[Union[Tool, Toolset]], Toolset, list[str]]] = None
312+
) -> Union[list[Union[Tool, Toolset]], Toolset]:
307313
"""
308314
Select tools for the current run based on the provided tools parameter.
309315
@@ -314,32 +320,43 @@ def _select_tools(
314320
or if any provided tool name is not valid.
315321
:raises TypeError: If tools is not a list of Tool objects, a Toolset, or a list of tool names (strings).
316322
"""
317-
selected_tools: Union[list[Tool], Toolset] = self.tools
318-
if isinstance(tools, Toolset) or isinstance(tools, list) and all(isinstance(t, Tool) for t in tools):
319-
selected_tools = tools # type: ignore[assignment] # mypy thinks this could still be list[str]
320-
elif isinstance(tools, list) and all(isinstance(t, str) for t in tools):
323+
if tools is None:
324+
return self.tools
325+
326+
if isinstance(tools, list) and all(isinstance(t, str) for t in tools):
321327
if not self.tools:
322328
raise ValueError("No tools were configured for the Agent at initialization.")
323-
selected_tool_names: list[str] = tools # type: ignore[assignment] # mypy thinks this could still be list[Tool] or Toolset
324-
valid_tool_names = {tool.name for tool in self.tools}
329+
available_tools = flatten_tools_or_toolsets(self.tools)
330+
selected_tool_names = cast(list[str], tools)
331+
valid_tool_names = {tool.name for tool in available_tools}
325332
invalid_tool_names = {name for name in selected_tool_names if name not in valid_tool_names}
326333
if invalid_tool_names:
327334
raise ValueError(
328335
f"The following tool names are not valid: {invalid_tool_names}. "
329336
f"Valid tool names are: {valid_tool_names}."
330337
)
331-
selected_tools = [tool for tool in self.tools if tool.name in selected_tool_names]
332-
elif tools is not None:
333-
raise TypeError("tools must be a list of Tool objects, a Toolset, or a list of tool names (strings).")
334-
return selected_tools
338+
return [tool for tool in available_tools if tool.name in selected_tool_names]
339+
340+
if isinstance(tools, Toolset):
341+
return tools
342+
343+
if isinstance(tools, list) and tools and isinstance(tools[0], Toolset):
344+
return cast(list[Union[Tool, Toolset]], tools)
345+
346+
if isinstance(tools, list) and all(isinstance(t, Tool) for t in tools):
347+
return cast(list[Union[Tool, Toolset]], tools)
348+
349+
raise TypeError(
350+
"tools must be a list of Tool and/or Toolset objects, a Toolset, or a list of tool names (strings)."
351+
)
335352

336353
def _initialize_from_snapshot(
337354
self,
338355
snapshot: AgentSnapshot,
339356
streaming_callback: Optional[StreamingCallbackT],
340357
requires_async: bool,
341358
*,
342-
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
359+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset, list[str]]] = None,
343360
) -> _ExecutionContext:
344361
"""
345362
Initialize execution context from an AgentSnapshot.
@@ -459,7 +476,7 @@ def run( # noqa: PLR0915
459476
break_point: Optional[AgentBreakpoint] = None,
460477
snapshot: Optional[AgentSnapshot] = None,
461478
system_prompt: Optional[str] = None,
462-
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
479+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset, list[str]]] = None,
463480
**kwargs: Any,
464481
) -> dict[str, Any]:
465482
"""
@@ -616,7 +633,7 @@ async def run_async(
616633
break_point: Optional[AgentBreakpoint] = None,
617634
snapshot: Optional[AgentSnapshot] = None,
618635
system_prompt: Optional[str] = None,
619-
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
636+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset, list[str]]] = None,
620637
**kwargs: Any,
621638
) -> dict[str, Any]:
622639
"""

haystack/components/generators/chat/azure.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Toolset,
1818
_check_duplicate_tool_names,
1919
deserialize_tools_or_toolset_inplace,
20+
flatten_tools_or_toolsets,
2021
serialize_tools_or_toolset,
2122
)
2223
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@@ -84,7 +85,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
8485
max_retries: Optional[int] = None,
8586
generation_kwargs: Optional[dict[str, Any]] = None,
8687
default_headers: Optional[dict[str, str]] = None,
87-
tools: Optional[Union[list[Tool], Toolset]] = None,
88+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset]] = None,
8889
tools_strict: bool = False,
8990
*,
9091
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
@@ -138,8 +139,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
138139
the `response_format` must be a JSON schema and not a Pydantic model.
139140
:param default_headers: Default headers to use for the AzureOpenAI client.
140141
:param tools:
141-
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
142-
list of `Tool` objects or a `Toolset` instance.
142+
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
143143
:param tools_strict:
144144
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
145145
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
@@ -179,7 +179,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
179179
self.default_headers = default_headers or {}
180180
self.azure_ad_token_provider = azure_ad_token_provider
181181
self.http_client_kwargs = http_client_kwargs
182-
_check_duplicate_tool_names(list(tools or []))
182+
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
183183
self.tools = tools
184184
self.tools_strict = tools_strict
185185

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Toolset,
2626
_check_duplicate_tool_names,
2727
deserialize_tools_or_toolset_inplace,
28+
flatten_tools_or_toolsets,
2829
serialize_tools_or_toolset,
2930
)
3031
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@@ -94,7 +95,7 @@ def _convert_hfapi_tool_calls(hfapi_tool_calls: Optional[list["ChatCompletionOut
9495

9596

9697
def _convert_tools_to_hfapi_tools(
97-
tools: Optional[Union[list[Tool], Toolset]],
98+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset]],
9899
) -> Optional[list["ChatCompletionInputTool"]]:
99100
if not tools:
100101
return None
@@ -103,7 +104,7 @@ def _convert_tools_to_hfapi_tools(
103104
parameters_name = "arguments" if hasattr(ChatCompletionInputFunctionDefinition, "arguments") else "parameters"
104105

105106
hf_tools = []
106-
for tool in tools:
107+
for tool in flatten_tools_or_toolsets(tools):
107108
hf_tools_args = {"name": tool.name, "description": tool.description, parameters_name: tool.parameters}
108109

109110
hf_tools.append(
@@ -298,7 +299,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
298299
generation_kwargs: Optional[dict[str, Any]] = None,
299300
stop_words: Optional[list[str]] = None,
300301
streaming_callback: Optional[StreamingCallbackT] = None,
301-
tools: Optional[Union[list[Tool], Toolset]] = None,
302+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset]] = None,
302303
):
303304
"""
304305
Initialize the HuggingFaceAPIChatGenerator instance.
@@ -328,10 +329,10 @@ def __init__( # pylint: disable=too-many-positional-arguments
328329
:param streaming_callback:
329330
An optional callable for handling streaming responses.
330331
:param tools:
331-
A list of tools or a Toolset for which the model can prepare calls.
332+
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
332333
The chosen model should support tool/function calling, according to the model card.
333334
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
334-
unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
335+
unexpected behavior.
335336
"""
336337

337338
huggingface_hub_import.check()
@@ -364,7 +365,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
364365

365366
if tools and streaming_callback is not None:
366367
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
367-
_check_duplicate_tool_names(list(tools or []))
368+
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
368369

369370
# handle generation kwargs setup
370371
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -423,7 +424,7 @@ def run(
423424
self,
424425
messages: list[ChatMessage],
425426
generation_kwargs: Optional[dict[str, Any]] = None,
426-
tools: Optional[Union[list[Tool], Toolset]] = None,
427+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset]] = None,
427428
streaming_callback: Optional[StreamingCallbackT] = None,
428429
):
429430
"""
@@ -452,7 +453,8 @@ def run(
452453
tools = tools or self.tools
453454
if tools and self.streaming_callback:
454455
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
455-
_check_duplicate_tool_names(list(tools or []))
456+
flat_tools = flatten_tools_or_toolsets(tools)
457+
_check_duplicate_tool_names(flat_tools)
456458

457459
# validate and select the streaming callback
458460
streaming_callback = select_streaming_callback(
@@ -462,9 +464,6 @@ def run(
462464
if streaming_callback:
463465
return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
464466

465-
if tools and isinstance(tools, Toolset):
466-
tools = list(tools)
467-
468467
hf_tools = _convert_tools_to_hfapi_tools(tools)
469468

470469
return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
@@ -474,7 +473,7 @@ async def run_async(
474473
self,
475474
messages: list[ChatMessage],
476475
generation_kwargs: Optional[dict[str, Any]] = None,
477-
tools: Optional[Union[list[Tool], Toolset]] = None,
476+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset]] = None,
478477
streaming_callback: Optional[StreamingCallbackT] = None,
479478
):
480479
"""
@@ -506,17 +505,15 @@ async def run_async(
506505
tools = tools or self.tools
507506
if tools and self.streaming_callback:
508507
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
509-
_check_duplicate_tool_names(list(tools or []))
508+
flat_tools = flatten_tools_or_toolsets(tools)
509+
_check_duplicate_tool_names(flat_tools)
510510

511511
# validate and select the streaming callback
512512
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
513513

514514
if streaming_callback:
515515
return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
516516

517-
if tools and isinstance(tools, Toolset):
518-
tools = list(tools)
519-
520517
hf_tools = _convert_tools_to_hfapi_tools(tools)
521518

522519
return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)

0 commit comments

Comments
 (0)