From 5a42eeafca1e9370cb16e217629cb5f3374ca9ca Mon Sep 17 00:00:00 2001 From: Parteek Date: Wed, 19 Feb 2025 13:00:11 +0530 Subject: [PATCH 1/5] Remove translation tool from README (#705) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 79551f013..250ed5539 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ You can run agents from CLI using two commands: `smolagent` and `webagent`. `smolagent` is a generalist command to run a multi-step `CodeAgent` that can be equipped with various tools. ```bash -smolagent "Plan a trip to Tokyo, Kyoto and Osaka between Mar 28 and Apr 7." --model-type "HfApiModel" --model-id "Qwen/Qwen2.5-Coder-32B-Instruct" --imports "pandas numpy" --tools "web_search translation" +smolagent "Plan a trip to Tokyo, Kyoto and Osaka between Mar 28 and Apr 7." --model-type "HfApiModel" --model-id "Qwen/Qwen2.5-Coder-32B-Instruct" --imports "pandas numpy" --tools "web_search" ``` Meanwhile `webagent` is a specific web-browsing agent using [helium](https://github.com/mherrmann/helium) (read more [here](https://github.com/huggingface/smolagents/blob/main/src/smolagents/vision_web_browser.py)). From ca5885540c5780668bdbea7d0fdb8e45d8c90fdd Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 19 Feb 2025 11:00:53 +0100 Subject: [PATCH 2/5] Fix test docs (#701) --- tests/test_all_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index 7dcbf5838..6df8dc76c 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -94,7 +94,7 @@ def setup_class(cls): load_dotenv() - cls.md_files = list(cls.docs_dir.rglob("*.md")) + cls.md_files = list(cls.docs_dir.rglob("*.mdx")) if not cls.md_files: raise ValueError(f"No markdown files found in {cls.docs_dir}") From 516f23878ba930d0bd2a15ab2ab991ef572a6be8 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 19 Feb 2025 11:01:01 +0100 Subject: [PATCH 3/5] Fix SpeechToTextTool (#706) --- src/smolagents/default_tools.py | 3 ++- tests/test_tool_validation.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 2d0ac5880..2ea7834f6 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -18,7 +18,6 @@ from dataclasses import dataclass from typing import Any, Dict, Optional -from .agent_types import AgentAudio from .local_python_executor import ( BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, @@ -280,6 +279,8 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls, *args, **kwargs) def encode(self, audio): + from .agent_types import AgentAudio + audio = AgentAudio(audio).to_raw() return self.pre_processor(audio, return_tensors="pt") diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py index 16cbe21d7..0a7c6c04a 100644 --- a/tests/test_tool_validation.py +++ b/tests/test_tool_validation.py @@ -1,9 +1,9 @@ import pytest -from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool +from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool from smolagents.tool_validation import validate_tool_attributes -@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool]) +@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool]) def test_validate_tool_attributes(tool_class): assert validate_tool_attributes(tool_class) is None, f"failed for {tool_class.name} tool" From e0345cc9e020892b77902576704f4f246b7c26b2 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Wed, 19 Feb 2025 17:02:26 +0100 Subject: [PATCH 4/5] Fixed typo in Telemetry docs (#710) --- docs/source/en/tutorials/inspect_runs.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/tutorials/inspect_runs.mdx b/docs/source/en/tutorials/inspect_runs.mdx index 113d2e56e..4ade8427b 100644 --- a/docs/source/en/tutorials/inspect_runs.mdx +++ b/docs/source/en/tutorials/inspect_runs.mdx @@ -97,7 +97,7 @@ You can then navigate to `http://0.0.0.0:6006/projects/` to inspect your run! -You can see that the CodeAgent called its managed ToolCallingAgent (by the way, the managed agent could be have been a CodeAgent as well) to ask it to run the web search for the U.S. 2024 growth rate. Then the managed agent returned its report and the manager agent acted upon it to calculate the economy doubling time! Sweet, isn't it? +You can see that the CodeAgent called its managed ToolCallingAgent (by the way, the managed agent could have been a CodeAgent as well) to ask it to run the web search for the U.S. 2024 growth rate. Then the managed agent returned its report and the manager agent acted upon it to calculate the economy doubling time! Sweet, isn't it? ## Setting up telemetry with Langfuse From 7913e9827af6aa071524fd5e5594150ac719e4cb Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 19 Feb 2025 17:08:45 +0100 Subject: [PATCH 5/5] Fix validate_tool_attributes for non-literal defaults (#711) --- src/smolagents/tool_validation.py | 57 ++++++++++--------- tests/test_tool_validation.py | 95 ++++++++++++++++++++++++++++++- tests/test_tools.py | 2 +- 3 files changed, 125 insertions(+), 29 deletions(-) diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index b8665c270..125e68993 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -1,6 +1,6 @@ import ast import builtins -import inspect +from itertools import zip_longest from typing import Set from .utils import BASE_BUILTIN_MODULES, get_source @@ -157,41 +157,19 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: Raises all errors encountered, if no error returns None. """ - errors = [] - - source = get_source(cls) - - tree = ast.parse(source) - - if not isinstance(tree.body[0], ast.ClassDef): - raise ValueError("Source code must define a class") - - # Check that __init__ method only has arguments with defaults - if not cls.__init__.__qualname__ == "Tool.__init__": - sig = inspect.signature(cls.__init__) - non_default_params = [ - arg_name - for arg_name, param in sig.parameters.items() - if arg_name != "self" - and param.default == inspect.Parameter.empty - and param.kind != inspect.Parameter.VAR_KEYWORD # Excludes **kwargs - ] - if non_default_params: - errors.append( - f"This tool has required arguments in __init__: {non_default_params}. " - "All parameters of __init__ must have default values!" - ) - - class_node = tree.body[0] class ClassLevelChecker(ast.NodeVisitor): def __init__(self): self.imported_names = set() self.complex_attributes = set() self.class_attributes = set() + self.non_defaults = set() + self.non_literal_defaults = set() self.in_method = False def visit_FunctionDef(self, node): + if node.name == "__init__": + self._check_init_function_parameters(node) old_context = self.in_method self.in_method = True self.generic_visit(node) @@ -214,14 +192,39 @@ def visit_Assign(self, node): if isinstance(target, ast.Name): self.complex_attributes.add(target.id) + def _check_init_function_parameters(self, node): + # Check defaults in parameters + for arg, default in reversed(list(zip_longest(reversed(node.args.args), reversed(node.args.defaults)))): + if default is None: + if arg.arg != "self": + self.non_defaults.add(arg.arg) + elif not isinstance(default, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)): + self.non_literal_defaults.add(arg.arg) + class_level_checker = ClassLevelChecker() + source = get_source(cls) + tree = ast.parse(source) + class_node = tree.body[0] + if not isinstance(class_node, ast.ClassDef): + raise ValueError("Source code must define a class") class_level_checker.visit(class_node) + errors = [] if class_level_checker.complex_attributes: errors.append( f"Complex attributes should be defined in __init__, not as class attributes: " f"{', '.join(class_level_checker.complex_attributes)}" ) + if class_level_checker.non_defaults: + errors.append( + f"Parameters in __init__ must have default values, found required parameters: " + f"{', '.join(class_level_checker.non_defaults)}" + ) + if class_level_checker.non_literal_defaults: + errors.append( + f"Parameters in __init__ must have literal default values, found non-literal defaults: " + f"{', '.join(class_level_checker.non_literal_defaults)}" + ) # Run checks on all methods for node in class_node.body: diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py index 0a7c6c04a..f3a94ded2 100644 --- a/tests/test_tool_validation.py +++ b/tests/test_tool_validation.py @@ -2,8 +2,101 @@ from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool from smolagents.tool_validation import validate_tool_attributes +from smolagents.tools import Tool + + +UNDEFINED_VARIABLE = "undefined_variable" @pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool]) -def test_validate_tool_attributes(tool_class): +def test_validate_tool_attributes_with_default_tools(tool_class): assert validate_tool_attributes(tool_class) is None, f"failed for {tool_class.name} tool" + + +class ValidTool(Tool): + name = "valid_tool" + description = "A valid tool" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + simple_attr = "string" + dict_attr = {"key": "value"} + + def __init__(self, optional_param="default"): + super().__init__() + self.param = optional_param + + def forward(self, input: str) -> str: + return input.upper() + + +def test_validate_tool_attributes_valid(): + assert validate_tool_attributes(ValidTool) is None + + +class InvalidToolComplexAttrs(Tool): + name = "invalid_tool" + description = "Tool with complex class attributes" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + complex_attr = [x for x in range(3)] # Complex class attribute + + def __init__(self): + super().__init__() + + def forward(self, input: str) -> str: + return input + + +class InvalidToolRequiredParams(Tool): + name = "invalid_tool" + description = "Tool with required params" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + + def __init__(self, required_param, kwarg1=1): # No default value + super().__init__() + self.param = required_param + + def forward(self, input: str) -> str: + return input + + +class InvalidToolNonLiteralDefaultParam(Tool): + name = "invalid_tool" + description = "Tool with non-literal default parameter value" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + + def __init__(self, default_param=UNDEFINED_VARIABLE): # UNDEFINED_VARIABLE as default is non-literal + super().__init__() + self.default_param = default_param + + def forward(self, input: str) -> str: + return input + + +class InvalidToolUndefinedNames(Tool): + name = "invalid_tool" + description = "Tool with undefined names" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + + def forward(self, input: str) -> str: + return UNDEFINED_VARIABLE # Undefined name + + +@pytest.mark.parametrize( + "tool_class, expected_error", + [ + (InvalidToolComplexAttrs, "Complex attributes should be defined in __init__, not as class attributes"), + (InvalidToolRequiredParams, "Parameters in __init__ must have default values, found required parameters"), + ( + InvalidToolNonLiteralDefaultParam, + "Parameters in __init__ must have literal default values, found non-literal defaults", + ), + (InvalidToolUndefinedNames, "Name 'UNDEFINED_VARIABLE' is undefined"), + ], +) +def test_validate_tool_attributes_exceptions(tool_class, expected_error): + with pytest.raises(ValueError, match=expected_error): + validate_tool_attributes(tool_class) diff --git a/tests/test_tools.py b/tests/test_tools.py index cb8a8eeaa..4ac48e07d 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -235,7 +235,7 @@ def forward(self, string_input: str) -> str: fail_tool = FailTool("dummy_url") with pytest.raises(Exception) as e: fail_tool.to_dict() - assert "All parameters of __init__ must have default values!" in str(e) + assert "Parameters in __init__ must have default values, found required parameters" in str(e) class PassTool(Tool): name = "specific"