Skip to content

Commit

Permalink
Merge branch 'huggingface:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
colesmcintosh authored Feb 20, 2025
2 parents 8f941e6 + 7913e98 commit d2a734d
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 35 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ You can run agents from CLI using two commands: `smolagent` and `webagent`.
`smolagent` is a generalist command to run a multi-step `CodeAgent` that can be equipped with various tools.

```bash
smolagent "Plan a trip to Tokyo, Kyoto and Osaka between Mar 28 and Apr 7." --model-type "HfApiModel" --model-id "Qwen/Qwen2.5-Coder-32B-Instruct" --imports "pandas numpy" --tools "web_search translation"
smolagent "Plan a trip to Tokyo, Kyoto and Osaka between Mar 28 and Apr 7." --model-type "HfApiModel" --model-id "Qwen/Qwen2.5-Coder-32B-Instruct" --imports "pandas numpy" --tools "web_search"
```

Meanwhile `webagent` is a specific web-browsing agent using [helium](https://github.com/mherrmann/helium) (read more [here](https://github.com/huggingface/smolagents/blob/main/src/smolagents/vision_web_browser.py)).
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/tutorials/inspect_runs.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ You can then navigate to `http://0.0.0.0:6006/projects/` to inspect your run!

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/inspect_run_phoenix.png">

You can see that the CodeAgent called its managed ToolCallingAgent (by the way, the managed agent could be have been a CodeAgent as well) to ask it to run the web search for the U.S. 2024 growth rate. Then the managed agent returned its report and the manager agent acted upon it to calculate the economy doubling time! Sweet, isn't it?
You can see that the CodeAgent called its managed ToolCallingAgent (by the way, the managed agent could have been a CodeAgent as well) to ask it to run the web search for the U.S. 2024 growth rate. Then the managed agent returned its report and the manager agent acted upon it to calculate the economy doubling time! Sweet, isn't it?

## Setting up telemetry with Langfuse

Expand Down
3 changes: 2 additions & 1 deletion src/smolagents/default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional

from .agent_types import AgentAudio
from .local_python_executor import (
BASE_BUILTIN_MODULES,
BASE_PYTHON_TOOLS,
Expand Down Expand Up @@ -280,6 +279,8 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)

def encode(self, audio):
from .agent_types import AgentAudio

audio = AgentAudio(audio).to_raw()
return self.pre_processor(audio, return_tensors="pt")

Expand Down
57 changes: 30 additions & 27 deletions src/smolagents/tool_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast
import builtins
import inspect
from itertools import zip_longest
from typing import Set

from .utils import BASE_BUILTIN_MODULES, get_source
Expand Down Expand Up @@ -157,41 +157,19 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
Raises all errors encountered, if no error returns None.
"""
errors = []

source = get_source(cls)

tree = ast.parse(source)

if not isinstance(tree.body[0], ast.ClassDef):
raise ValueError("Source code must define a class")

# Check that __init__ method only has arguments with defaults
if not cls.__init__.__qualname__ == "Tool.__init__":
sig = inspect.signature(cls.__init__)
non_default_params = [
arg_name
for arg_name, param in sig.parameters.items()
if arg_name != "self"
and param.default == inspect.Parameter.empty
and param.kind != inspect.Parameter.VAR_KEYWORD # Excludes **kwargs
]
if non_default_params:
errors.append(
f"This tool has required arguments in __init__: {non_default_params}. "
"All parameters of __init__ must have default values!"
)

class_node = tree.body[0]

class ClassLevelChecker(ast.NodeVisitor):
def __init__(self):
self.imported_names = set()
self.complex_attributes = set()
self.class_attributes = set()
self.non_defaults = set()
self.non_literal_defaults = set()
self.in_method = False

def visit_FunctionDef(self, node):
if node.name == "__init__":
self._check_init_function_parameters(node)
old_context = self.in_method
self.in_method = True
self.generic_visit(node)
Expand All @@ -214,14 +192,39 @@ def visit_Assign(self, node):
if isinstance(target, ast.Name):
self.complex_attributes.add(target.id)

def _check_init_function_parameters(self, node):
# Check defaults in parameters
for arg, default in reversed(list(zip_longest(reversed(node.args.args), reversed(node.args.defaults)))):
if default is None:
if arg.arg != "self":
self.non_defaults.add(arg.arg)
elif not isinstance(default, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)):
self.non_literal_defaults.add(arg.arg)

class_level_checker = ClassLevelChecker()
source = get_source(cls)
tree = ast.parse(source)
class_node = tree.body[0]
if not isinstance(class_node, ast.ClassDef):
raise ValueError("Source code must define a class")
class_level_checker.visit(class_node)

errors = []
if class_level_checker.complex_attributes:
errors.append(
f"Complex attributes should be defined in __init__, not as class attributes: "
f"{', '.join(class_level_checker.complex_attributes)}"
)
if class_level_checker.non_defaults:
errors.append(
f"Parameters in __init__ must have default values, found required parameters: "
f"{', '.join(class_level_checker.non_defaults)}"
)
if class_level_checker.non_literal_defaults:
errors.append(
f"Parameters in __init__ must have literal default values, found non-literal defaults: "
f"{', '.join(class_level_checker.non_literal_defaults)}"
)

# Run checks on all methods
for node in class_node.body:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_all_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def setup_class(cls):

load_dotenv()

cls.md_files = list(cls.docs_dir.rglob("*.md"))
cls.md_files = list(cls.docs_dir.rglob("*.mdx"))
if not cls.md_files:
raise ValueError(f"No markdown files found in {cls.docs_dir}")

Expand Down
99 changes: 96 additions & 3 deletions tests/test_tool_validation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,102 @@
import pytest

from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool
from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool
from smolagents.tool_validation import validate_tool_attributes
from smolagents.tools import Tool


@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool])
def test_validate_tool_attributes(tool_class):
UNDEFINED_VARIABLE = "undefined_variable"


@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool])
def test_validate_tool_attributes_with_default_tools(tool_class):
assert validate_tool_attributes(tool_class) is None, f"failed for {tool_class.name} tool"


class ValidTool(Tool):
name = "valid_tool"
description = "A valid tool"
inputs = {"input": {"type": "string", "description": "input"}}
output_type = "string"
simple_attr = "string"
dict_attr = {"key": "value"}

def __init__(self, optional_param="default"):
super().__init__()
self.param = optional_param

def forward(self, input: str) -> str:
return input.upper()


def test_validate_tool_attributes_valid():
assert validate_tool_attributes(ValidTool) is None


class InvalidToolComplexAttrs(Tool):
name = "invalid_tool"
description = "Tool with complex class attributes"
inputs = {"input": {"type": "string", "description": "input"}}
output_type = "string"
complex_attr = [x for x in range(3)] # Complex class attribute

def __init__(self):
super().__init__()

def forward(self, input: str) -> str:
return input


class InvalidToolRequiredParams(Tool):
name = "invalid_tool"
description = "Tool with required params"
inputs = {"input": {"type": "string", "description": "input"}}
output_type = "string"

def __init__(self, required_param, kwarg1=1): # No default value
super().__init__()
self.param = required_param

def forward(self, input: str) -> str:
return input


class InvalidToolNonLiteralDefaultParam(Tool):
name = "invalid_tool"
description = "Tool with non-literal default parameter value"
inputs = {"input": {"type": "string", "description": "input"}}
output_type = "string"

def __init__(self, default_param=UNDEFINED_VARIABLE): # UNDEFINED_VARIABLE as default is non-literal
super().__init__()
self.default_param = default_param

def forward(self, input: str) -> str:
return input


class InvalidToolUndefinedNames(Tool):
name = "invalid_tool"
description = "Tool with undefined names"
inputs = {"input": {"type": "string", "description": "input"}}
output_type = "string"

def forward(self, input: str) -> str:
return UNDEFINED_VARIABLE # Undefined name


@pytest.mark.parametrize(
"tool_class, expected_error",
[
(InvalidToolComplexAttrs, "Complex attributes should be defined in __init__, not as class attributes"),
(InvalidToolRequiredParams, "Parameters in __init__ must have default values, found required parameters"),
(
InvalidToolNonLiteralDefaultParam,
"Parameters in __init__ must have literal default values, found non-literal defaults",
),
(InvalidToolUndefinedNames, "Name 'UNDEFINED_VARIABLE' is undefined"),
],
)
def test_validate_tool_attributes_exceptions(tool_class, expected_error):
with pytest.raises(ValueError, match=expected_error):
validate_tool_attributes(tool_class)
2 changes: 1 addition & 1 deletion tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def forward(self, string_input: str) -> str:
fail_tool = FailTool("dummy_url")
with pytest.raises(Exception) as e:
fail_tool.to_dict()
assert "All parameters of __init__ must have default values!" in str(e)
assert "Parameters in __init__ must have default values, found required parameters" in str(e)

class PassTool(Tool):
name = "specific"
Expand Down

0 comments on commit d2a734d

Please sign in to comment.