Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast

from haystack import logging, tracing
from haystack.components.generators.chat.types import ChatGenerator
Expand All @@ -25,7 +25,14 @@
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot, PipelineSnapshot, ToolBreakpoint
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
from haystack.tools import (
Tool,
Toolset,
ToolsType,
deserialize_tools_or_toolset_inplace,
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
)
from haystack.utils import _deserialize_value_with_schema
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
Expand Down Expand Up @@ -97,7 +104,7 @@ def __init__(
self,
*,
chat_generator: ChatGenerator,
tools: Optional[Union[list[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
system_prompt: Optional[str] = None,
exit_conditions: Optional[list[str]] = None,
state_schema: Optional[dict[str, Any]] = None,
Expand All @@ -110,7 +117,7 @@ def __init__(
Initialize the agent component.

:param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
:param tools: List of Tool objects or a Toolset that the agent can use.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset that the agent can use.
:param system_prompt: System prompt for the agent.
:param exit_conditions: List of conditions that will cause the agent to return.
Can include "text" if the agent should return when it generates a message without tool calls,
Expand All @@ -134,7 +141,7 @@ def __init__(
"The Agent component requires a chat generator that supports tools."
)

valid_exits = ["text"] + [tool.name for tool in tools or []]
valid_exits = ["text"] + [tool.name for tool in flatten_tools_or_toolsets(tools)]
if exit_conditions is None:
exit_conditions = ["text"]
if not all(condition in valid_exits for condition in exit_conditions):
Expand Down Expand Up @@ -259,7 +266,7 @@ def _initialize_fresh_execution(
requires_async: bool,
*,
system_prompt: Optional[str] = None,
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
tools: Optional[Union[ToolsType, list[str]]] = None,
**kwargs,
) -> _ExecutionContext:
"""
Expand Down Expand Up @@ -301,9 +308,7 @@ def _initialize_fresh_execution(
tool_invoker_inputs=tool_invoker_inputs,
)

def _select_tools(
self, tools: Optional[Union[list[Tool], Toolset, list[str]]] = None
) -> Union[list[Tool], Toolset]:
def _select_tools(self, tools: Optional[Union[ToolsType, list[str]]] = None) -> ToolsType:
"""
Select tools for the current run based on the provided tools parameter.

Expand All @@ -314,32 +319,40 @@ def _select_tools(
or if any provided tool name is not valid.
:raises TypeError: If tools is not a list of Tool objects, a Toolset, or a list of tool names (strings).
"""
selected_tools: Union[list[Tool], Toolset] = self.tools
if isinstance(tools, Toolset) or isinstance(tools, list) and all(isinstance(t, Tool) for t in tools):
selected_tools = tools # type: ignore[assignment] # mypy thinks this could still be list[str]
elif isinstance(tools, list) and all(isinstance(t, str) for t in tools):
if tools is None:
return self.tools

if isinstance(tools, list) and all(isinstance(t, str) for t in tools):
if not self.tools:
raise ValueError("No tools were configured for the Agent at initialization.")
selected_tool_names: list[str] = tools # type: ignore[assignment] # mypy thinks this could still be list[Tool] or Toolset
valid_tool_names = {tool.name for tool in self.tools}
available_tools = flatten_tools_or_toolsets(self.tools)
selected_tool_names = cast(list[str], tools) # mypy thinks this could still be list[Tool] or Toolset
valid_tool_names = {tool.name for tool in available_tools}
invalid_tool_names = {name for name in selected_tool_names if name not in valid_tool_names}
if invalid_tool_names:
raise ValueError(
f"The following tool names are not valid: {invalid_tool_names}. "
f"Valid tool names are: {valid_tool_names}."
)
selected_tools = [tool for tool in self.tools if tool.name in selected_tool_names]
elif tools is not None:
raise TypeError("tools must be a list of Tool objects, a Toolset, or a list of tool names (strings).")
return selected_tools
return [tool for tool in available_tools if tool.name in selected_tool_names]

if isinstance(tools, Toolset):
return tools

if isinstance(tools, list):
return cast(list[Union[Tool, Toolset]], tools) # mypy can't narrow the Union type from isinstance check

raise TypeError(
"tools must be a list of Tool and/or Toolset objects, a Toolset, or a list of tool names (strings)."
)

def _initialize_from_snapshot(
self,
snapshot: AgentSnapshot,
streaming_callback: Optional[StreamingCallbackT],
requires_async: bool,
*,
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
tools: Optional[Union[ToolsType, list[str]]] = None,
) -> _ExecutionContext:
"""
Initialize execution context from an AgentSnapshot.
Expand Down Expand Up @@ -459,7 +472,7 @@ def run( # noqa: PLR0915
break_point: Optional[AgentBreakpoint] = None,
snapshot: Optional[AgentSnapshot] = None,
system_prompt: Optional[str] = None,
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
tools: Optional[Union[ToolsType, list[str]]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -616,7 +629,7 @@ async def run_async(
break_point: Optional[AgentBreakpoint] = None,
snapshot: Optional[AgentSnapshot] = None,
system_prompt: Optional[str] = None,
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
tools: Optional[Union[ToolsType, list[str]]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
Expand Down
11 changes: 5 additions & 6 deletions haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.tools import (
Tool,
Toolset,
ToolsType,
_check_duplicate_tool_names,
deserialize_tools_or_toolset_inplace,
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
max_retries: Optional[int] = None,
generation_kwargs: Optional[dict[str, Any]] = None,
default_headers: Optional[dict[str, str]] = None,
tools: Optional[Union[list[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
tools_strict: bool = False,
*,
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
Expand Down Expand Up @@ -138,8 +138,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
the `response_format` must be a JSON schema and not a Pydantic model.
:param default_headers: Default headers to use for the AzureOpenAI client.
:param tools:
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
list of `Tool` objects or a `Toolset` instance.
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
:param tools_strict:
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
Expand Down Expand Up @@ -179,7 +178,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider
self.http_client_kwargs = http_client_kwargs
_check_duplicate_tool_names(list(tools or []))
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
self.tools = tools
self.tools_strict = tools_strict

Expand Down
16 changes: 8 additions & 8 deletions haystack/components/generators/chat/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from __future__ import annotations

import asyncio
from typing import Any, Union
from typing import Any, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.chat.types import ChatGenerator
from haystack.dataclasses import ChatMessage, StreamingCallbackT
from haystack.tools import Tool, Toolset
from haystack.tools import ToolsType
from haystack.utils.deserialization import deserialize_component_inplace

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,7 +86,7 @@ def _run_single_sync( # pylint: disable=too-many-positional-arguments
gen: Any,
messages: list[ChatMessage],
generation_kwargs: Union[dict[str, Any], None],
tools: Union[list[Tool], Toolset, None],
tools: Optional[ToolsType],
streaming_callback: Union[StreamingCallbackT, None],
) -> dict[str, Any]:
return gen.run(
Expand All @@ -98,7 +98,7 @@ async def _run_single_async( # pylint: disable=too-many-positional-arguments
gen: Any,
messages: list[ChatMessage],
generation_kwargs: Union[dict[str, Any], None],
tools: Union[list[Tool], Toolset, None],
tools: Optional[ToolsType],
streaming_callback: Union[StreamingCallbackT, None],
) -> dict[str, Any]:
if hasattr(gen, "run_async") and callable(gen.run_async):
Expand All @@ -121,15 +121,15 @@ def run(
self,
messages: list[ChatMessage],
generation_kwargs: Union[dict[str, Any], None] = None,
tools: Union[list[Tool], Toolset, None] = None,
tools: Optional[ToolsType] = None,
streaming_callback: Union[StreamingCallbackT, None] = None,
) -> dict[str, Any]:
"""
Execute chat generators sequentially until one succeeds.

:param messages: The conversation history as a list of ChatMessage instances.
:param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens).
:param tools: Optional Tool instances or Toolset for function calling capabilities.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for function calling capabilities.
:param streaming_callback: Optional callable for handling streaming responses.
:returns: A dictionary with:
- "replies": Generated ChatMessage instances from the first successful generator.
Expand Down Expand Up @@ -174,15 +174,15 @@ async def run_async(
self,
messages: list[ChatMessage],
generation_kwargs: Union[dict[str, Any], None] = None,
tools: Union[list[Tool], Toolset, None] = None,
tools: Optional[ToolsType] = None,
streaming_callback: Union[StreamingCallbackT, None] = None,
) -> dict[str, Any]:
"""
Asynchronously execute chat generators sequentially until one succeeds.

:param messages: The conversation history as a list of ChatMessage instances.
:param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens).
:param tools: Optional Tool instances or Toolset for function calling capabilities.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for function calling capabilities.
:param streaming_callback: Optional callable for handling streaming responses.
:returns: A dictionary with:
- "replies": Generated ChatMessage instances from the first successful generator.
Expand Down
34 changes: 14 additions & 20 deletions haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from haystack.dataclasses.streaming_chunk import FinishReason
from haystack.lazy_imports import LazyImport
from haystack.tools import (
Tool,
Toolset,
ToolsType,
_check_duplicate_tool_names,
deserialize_tools_or_toolset_inplace,
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
Expand Down Expand Up @@ -93,17 +93,15 @@ def _convert_hfapi_tool_calls(hfapi_tool_calls: Optional[list["ChatCompletionOut
return tool_calls


def _convert_tools_to_hfapi_tools(
tools: Optional[Union[list[Tool], Toolset]],
) -> Optional[list["ChatCompletionInputTool"]]:
def _convert_tools_to_hfapi_tools(tools: Optional[ToolsType]) -> Optional[list["ChatCompletionInputTool"]]:
if not tools:
return None

# huggingface_hub<0.31.0 uses "arguments", huggingface_hub>=0.31.0 uses "parameters"
parameters_name = "arguments" if hasattr(ChatCompletionInputFunctionDefinition, "arguments") else "parameters"

hf_tools = []
for tool in tools:
for tool in flatten_tools_or_toolsets(tools):
hf_tools_args = {"name": tool.name, "description": tool.description, parameters_name: tool.parameters}

hf_tools.append(
Expand Down Expand Up @@ -298,7 +296,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
generation_kwargs: Optional[dict[str, Any]] = None,
stop_words: Optional[list[str]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[list[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
):
"""
Initialize the HuggingFaceAPIChatGenerator instance.
Expand Down Expand Up @@ -328,10 +326,10 @@ def __init__( # pylint: disable=too-many-positional-arguments
:param streaming_callback:
An optional callable for handling streaming responses.
:param tools:
A list of tools or a Toolset for which the model can prepare calls.
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
The chosen model should support tool/function calling, according to the model card.
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
unexpected behavior.
"""

huggingface_hub_import.check()
Expand Down Expand Up @@ -364,7 +362,7 @@ def __init__( # pylint: disable=too-many-positional-arguments

if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(list(tools or []))
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
Expand Down Expand Up @@ -423,7 +421,7 @@ def run(
self,
messages: list[ChatMessage],
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[Union[list[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
):
"""
Expand Down Expand Up @@ -452,7 +450,8 @@ def run(
tools = tools or self.tools
if tools and self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(list(tools or []))
flat_tools = flatten_tools_or_toolsets(tools)
_check_duplicate_tool_names(flat_tools)

# validate and select the streaming callback
streaming_callback = select_streaming_callback(
Expand All @@ -462,9 +461,6 @@ def run(
if streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

if tools and isinstance(tools, Toolset):
tools = list(tools)

hf_tools = _convert_tools_to_hfapi_tools(tools)

return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
Expand All @@ -474,7 +470,7 @@ async def run_async(
self,
messages: list[ChatMessage],
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[Union[list[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
):
"""
Expand Down Expand Up @@ -506,17 +502,15 @@ async def run_async(
tools = tools or self.tools
if tools and self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(list(tools or []))
flat_tools = flatten_tools_or_toolsets(tools)
_check_duplicate_tool_names(flat_tools)

# validate and select the streaming callback
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

if streaming_callback:
return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)

if tools and isinstance(tools, Toolset):
tools = list(tools)

hf_tools = _convert_tools_to_hfapi_tools(tools)

return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
Expand Down
Loading
Loading