-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'huggingface:main' into main
- Loading branch information
Showing
7 changed files
with
132 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,102 @@ | ||
import pytest | ||
|
||
from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool | ||
from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool | ||
from smolagents.tool_validation import validate_tool_attributes | ||
from smolagents.tools import Tool | ||
|
||
|
||
@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool]) | ||
def test_validate_tool_attributes(tool_class): | ||
UNDEFINED_VARIABLE = "undefined_variable" | ||
|
||
|
||
@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool]) | ||
def test_validate_tool_attributes_with_default_tools(tool_class): | ||
assert validate_tool_attributes(tool_class) is None, f"failed for {tool_class.name} tool" | ||
|
||
|
||
class ValidTool(Tool): | ||
name = "valid_tool" | ||
description = "A valid tool" | ||
inputs = {"input": {"type": "string", "description": "input"}} | ||
output_type = "string" | ||
simple_attr = "string" | ||
dict_attr = {"key": "value"} | ||
|
||
def __init__(self, optional_param="default"): | ||
super().__init__() | ||
self.param = optional_param | ||
|
||
def forward(self, input: str) -> str: | ||
return input.upper() | ||
|
||
|
||
def test_validate_tool_attributes_valid(): | ||
assert validate_tool_attributes(ValidTool) is None | ||
|
||
|
||
class InvalidToolComplexAttrs(Tool): | ||
name = "invalid_tool" | ||
description = "Tool with complex class attributes" | ||
inputs = {"input": {"type": "string", "description": "input"}} | ||
output_type = "string" | ||
complex_attr = [x for x in range(3)] # Complex class attribute | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input: str) -> str: | ||
return input | ||
|
||
|
||
class InvalidToolRequiredParams(Tool): | ||
name = "invalid_tool" | ||
description = "Tool with required params" | ||
inputs = {"input": {"type": "string", "description": "input"}} | ||
output_type = "string" | ||
|
||
def __init__(self, required_param, kwarg1=1): # No default value | ||
super().__init__() | ||
self.param = required_param | ||
|
||
def forward(self, input: str) -> str: | ||
return input | ||
|
||
|
||
class InvalidToolNonLiteralDefaultParam(Tool): | ||
name = "invalid_tool" | ||
description = "Tool with non-literal default parameter value" | ||
inputs = {"input": {"type": "string", "description": "input"}} | ||
output_type = "string" | ||
|
||
def __init__(self, default_param=UNDEFINED_VARIABLE): # UNDEFINED_VARIABLE as default is non-literal | ||
super().__init__() | ||
self.default_param = default_param | ||
|
||
def forward(self, input: str) -> str: | ||
return input | ||
|
||
|
||
class InvalidToolUndefinedNames(Tool): | ||
name = "invalid_tool" | ||
description = "Tool with undefined names" | ||
inputs = {"input": {"type": "string", "description": "input"}} | ||
output_type = "string" | ||
|
||
def forward(self, input: str) -> str: | ||
return UNDEFINED_VARIABLE # Undefined name | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"tool_class, expected_error", | ||
[ | ||
(InvalidToolComplexAttrs, "Complex attributes should be defined in __init__, not as class attributes"), | ||
(InvalidToolRequiredParams, "Parameters in __init__ must have default values, found required parameters"), | ||
( | ||
InvalidToolNonLiteralDefaultParam, | ||
"Parameters in __init__ must have literal default values, found non-literal defaults", | ||
), | ||
(InvalidToolUndefinedNames, "Name 'UNDEFINED_VARIABLE' is undefined"), | ||
], | ||
) | ||
def test_validate_tool_attributes_exceptions(tool_class, expected_error): | ||
with pytest.raises(ValueError, match=expected_error): | ||
validate_tool_attributes(tool_class) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters