Skip to content

Commit

Permalink
Fix and test validate_tools_and_managed_agents (#731)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova authored Feb 26, 2025
1 parent 84089bc commit 9498094
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
8 changes: 3 additions & 5 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,9 @@ def _setup_tools(self, tools, add_base_tools):
def _validate_tools_and_managed_agents(self, tools, managed_agents):
tool_and_managed_agent_names = [tool.name for tool in tools]
if managed_agents is not None:
for agent in managed_agents:
tool_and_managed_agent_names.append(agent.name)
for tool in agent.tools.values():
if tool.name != "final_answer":
tool_and_managed_agent_names.append(tool.name)
tool_and_managed_agent_names += [agent.name for agent in managed_agents]
if self.name:
tool_and_managed_agent_names.append(self.name)
if len(tool_and_managed_agent_names) != len(set(tool_and_managed_agent_names)):
raise ValueError(
"Each tool or managed_agent should have a unique name! You passed these duplicate names: "
Expand Down
66 changes: 65 additions & 1 deletion tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tempfile
import unittest
import uuid
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from unittest.mock import MagicMock

Expand All @@ -41,7 +42,7 @@
MessageRole,
TransformersModel,
)
from smolagents.tools import tool
from smolagents.tools import Tool, tool
from smolagents.utils import BASE_BUILTIN_MODULES


Expand Down Expand Up @@ -574,6 +575,24 @@ def forward(self, answer) -> str:
return answer + "CUSTOM"


class MockTool(Tool):
def __init__(self, name):
self.name = name
self.description = "Mock tool description"
self.inputs = {}
self.output_type = "string"

def forward(self):
return "Mock tool output"


class MockAgent:
def __init__(self, name, tools, description="Mock agent description"):
self.name = name
self.tools = {t.name: t for t in tools}
self.description = description


class TestMultiStepAgent:
def test_instantiation_disables_logging_to_terminal(self):
fake_model = MagicMock()
Expand Down Expand Up @@ -784,6 +803,51 @@ def test_provide_final_answer(self, images, expected_messages_list):
for content, expected_content in zip(message["content"], expected_message["content"]):
assert content == expected_content

@pytest.mark.parametrize(
"tools, managed_agents, name, expectation",
[
# Valid case: no duplicates
(
[MockTool("tool1"), MockTool("tool2")],
[MockAgent("agent1", [MockTool("tool3")])],
"test_agent",
does_not_raise(),
),
# Invalid case: duplicate tool names
([MockTool("tool1"), MockTool("tool1")], [], "test_agent", pytest.raises(ValueError)),
# Invalid case: tool name same as managed agent name
(
[MockTool("tool1")],
[MockAgent("tool1", [MockTool("final_answer")])],
"test_agent",
pytest.raises(ValueError),
),
# Valid case: tool name same as managed agent's tool name
([MockTool("tool1")], [MockAgent("agent1", [MockTool("tool1")])], "test_agent", does_not_raise()),
# Invalid case: duplicate managed agent name and managed agent tool name
([MockTool("tool1")], [], "tool1", pytest.raises(ValueError)),
# Valid case: duplicate tool names across managed agents
(
[MockTool("tool1")],
[
MockAgent("agent1", [MockTool("tool2"), MockTool("final_answer")]),
MockAgent("agent2", [MockTool("tool2"), MockTool("final_answer")]),
],
"test_agent",
does_not_raise(),
),
],
)
def test_validate_tools_and_managed_agents(self, tools, managed_agents, name, expectation):
fake_model = MagicMock()
with expectation:
MultiStepAgent(
tools=tools,
model=fake_model,
name=name,
managed_agents=managed_agents,
)


class TestCodeAgent:
@pytest.mark.parametrize("provide_run_summary", [False, True])
Expand Down

0 comments on commit 9498094

Please sign in to comment.