diff --git a/README.md b/README.md index 79551f013..250ed5539 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ You can run agents from CLI using two commands: `smolagent` and `webagent`. `smolagent` is a generalist command to run a multi-step `CodeAgent` that can be equipped with various tools. ```bash -smolagent "Plan a trip to Tokyo, Kyoto and Osaka between Mar 28 and Apr 7." --model-type "HfApiModel" --model-id "Qwen/Qwen2.5-Coder-32B-Instruct" --imports "pandas numpy" --tools "web_search translation" +smolagent "Plan a trip to Tokyo, Kyoto and Osaka between Mar 28 and Apr 7." --model-type "HfApiModel" --model-id "Qwen/Qwen2.5-Coder-32B-Instruct" --imports "pandas numpy" --tools "web_search" ``` Meanwhile `webagent` is a specific web-browsing agent using [helium](https://github.com/mherrmann/helium) (read more [here](https://github.com/huggingface/smolagents/blob/main/src/smolagents/vision_web_browser.py)). diff --git a/docs/source/en/tutorials/inspect_runs.mdx b/docs/source/en/tutorials/inspect_runs.mdx index 113d2e56e..4ade8427b 100644 --- a/docs/source/en/tutorials/inspect_runs.mdx +++ b/docs/source/en/tutorials/inspect_runs.mdx @@ -97,7 +97,7 @@ You can then navigate to `http://0.0.0.0:6006/projects/` to inspect your run! -You can see that the CodeAgent called its managed ToolCallingAgent (by the way, the managed agent could be have been a CodeAgent as well) to ask it to run the web search for the U.S. 2024 growth rate. Then the managed agent returned its report and the manager agent acted upon it to calculate the economy doubling time! Sweet, isn't it? +You can see that the CodeAgent called its managed ToolCallingAgent (by the way, the managed agent could have been a CodeAgent as well) to ask it to run the web search for the U.S. 2024 growth rate. Then the managed agent returned its report and the manager agent acted upon it to calculate the economy doubling time! Sweet, isn't it? ## Setting up telemetry with Langfuse diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 2d0ac5880..2ea7834f6 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -18,7 +18,6 @@ from dataclasses import dataclass from typing import Any, Dict, Optional -from .agent_types import AgentAudio from .local_python_executor import ( BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, @@ -280,6 +279,8 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls, *args, **kwargs) def encode(self, audio): + from .agent_types import AgentAudio + audio = AgentAudio(audio).to_raw() return self.pre_processor(audio, return_tensors="pt") diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index b8665c270..125e68993 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -1,6 +1,6 @@ import ast import builtins -import inspect +from itertools import zip_longest from typing import Set from .utils import BASE_BUILTIN_MODULES, get_source @@ -157,41 +157,19 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: Raises all errors encountered, if no error returns None. """ - errors = [] - - source = get_source(cls) - - tree = ast.parse(source) - - if not isinstance(tree.body[0], ast.ClassDef): - raise ValueError("Source code must define a class") - - # Check that __init__ method only has arguments with defaults - if not cls.__init__.__qualname__ == "Tool.__init__": - sig = inspect.signature(cls.__init__) - non_default_params = [ - arg_name - for arg_name, param in sig.parameters.items() - if arg_name != "self" - and param.default == inspect.Parameter.empty - and param.kind != inspect.Parameter.VAR_KEYWORD # Excludes **kwargs - ] - if non_default_params: - errors.append( - f"This tool has required arguments in __init__: {non_default_params}. " - "All parameters of __init__ must have default values!" - ) - - class_node = tree.body[0] class ClassLevelChecker(ast.NodeVisitor): def __init__(self): self.imported_names = set() self.complex_attributes = set() self.class_attributes = set() + self.non_defaults = set() + self.non_literal_defaults = set() self.in_method = False def visit_FunctionDef(self, node): + if node.name == "__init__": + self._check_init_function_parameters(node) old_context = self.in_method self.in_method = True self.generic_visit(node) @@ -214,14 +192,39 @@ def visit_Assign(self, node): if isinstance(target, ast.Name): self.complex_attributes.add(target.id) + def _check_init_function_parameters(self, node): + # Check defaults in parameters + for arg, default in reversed(list(zip_longest(reversed(node.args.args), reversed(node.args.defaults)))): + if default is None: + if arg.arg != "self": + self.non_defaults.add(arg.arg) + elif not isinstance(default, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)): + self.non_literal_defaults.add(arg.arg) + class_level_checker = ClassLevelChecker() + source = get_source(cls) + tree = ast.parse(source) + class_node = tree.body[0] + if not isinstance(class_node, ast.ClassDef): + raise ValueError("Source code must define a class") class_level_checker.visit(class_node) + errors = [] if class_level_checker.complex_attributes: errors.append( f"Complex attributes should be defined in __init__, not as class attributes: " f"{', '.join(class_level_checker.complex_attributes)}" ) + if class_level_checker.non_defaults: + errors.append( + f"Parameters in __init__ must have default values, found required parameters: " + f"{', '.join(class_level_checker.non_defaults)}" + ) + if class_level_checker.non_literal_defaults: + errors.append( + f"Parameters in __init__ must have literal default values, found non-literal defaults: " + f"{', '.join(class_level_checker.non_literal_defaults)}" + ) # Run checks on all methods for node in class_node.body: diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index 7dcbf5838..6df8dc76c 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -94,7 +94,7 @@ def setup_class(cls): load_dotenv() - cls.md_files = list(cls.docs_dir.rglob("*.md")) + cls.md_files = list(cls.docs_dir.rglob("*.mdx")) if not cls.md_files: raise ValueError(f"No markdown files found in {cls.docs_dir}") diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py index 16cbe21d7..f3a94ded2 100644 --- a/tests/test_tool_validation.py +++ b/tests/test_tool_validation.py @@ -1,9 +1,102 @@ import pytest -from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool +from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool from smolagents.tool_validation import validate_tool_attributes +from smolagents.tools import Tool -@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool]) -def test_validate_tool_attributes(tool_class): +UNDEFINED_VARIABLE = "undefined_variable" + + +@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool]) +def test_validate_tool_attributes_with_default_tools(tool_class): assert validate_tool_attributes(tool_class) is None, f"failed for {tool_class.name} tool" + + +class ValidTool(Tool): + name = "valid_tool" + description = "A valid tool" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + simple_attr = "string" + dict_attr = {"key": "value"} + + def __init__(self, optional_param="default"): + super().__init__() + self.param = optional_param + + def forward(self, input: str) -> str: + return input.upper() + + +def test_validate_tool_attributes_valid(): + assert validate_tool_attributes(ValidTool) is None + + +class InvalidToolComplexAttrs(Tool): + name = "invalid_tool" + description = "Tool with complex class attributes" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + complex_attr = [x for x in range(3)] # Complex class attribute + + def __init__(self): + super().__init__() + + def forward(self, input: str) -> str: + return input + + +class InvalidToolRequiredParams(Tool): + name = "invalid_tool" + description = "Tool with required params" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + + def __init__(self, required_param, kwarg1=1): # No default value + super().__init__() + self.param = required_param + + def forward(self, input: str) -> str: + return input + + +class InvalidToolNonLiteralDefaultParam(Tool): + name = "invalid_tool" + description = "Tool with non-literal default parameter value" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + + def __init__(self, default_param=UNDEFINED_VARIABLE): # UNDEFINED_VARIABLE as default is non-literal + super().__init__() + self.default_param = default_param + + def forward(self, input: str) -> str: + return input + + +class InvalidToolUndefinedNames(Tool): + name = "invalid_tool" + description = "Tool with undefined names" + inputs = {"input": {"type": "string", "description": "input"}} + output_type = "string" + + def forward(self, input: str) -> str: + return UNDEFINED_VARIABLE # Undefined name + + +@pytest.mark.parametrize( + "tool_class, expected_error", + [ + (InvalidToolComplexAttrs, "Complex attributes should be defined in __init__, not as class attributes"), + (InvalidToolRequiredParams, "Parameters in __init__ must have default values, found required parameters"), + ( + InvalidToolNonLiteralDefaultParam, + "Parameters in __init__ must have literal default values, found non-literal defaults", + ), + (InvalidToolUndefinedNames, "Name 'UNDEFINED_VARIABLE' is undefined"), + ], +) +def test_validate_tool_attributes_exceptions(tool_class, expected_error): + with pytest.raises(ValueError, match=expected_error): + validate_tool_attributes(tool_class) diff --git a/tests/test_tools.py b/tests/test_tools.py index cb8a8eeaa..4ac48e07d 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -235,7 +235,7 @@ def forward(self, string_input: str) -> str: fail_tool = FailTool("dummy_url") with pytest.raises(Exception) as e: fail_tool.to_dict() - assert "All parameters of __init__ must have default values!" in str(e) + assert "Parameters in __init__ must have default values, found required parameters" in str(e) class PassTool(Tool): name = "specific"