Skip to content

Support composite arg type parsing in dspy.Tool #8095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging
from typing import Any, Callable, Literal, get_origin
from typing import Any, Callable, Literal

from litellm import ContextWindowExceededError
from pydantic import BaseModel

import dspy
from dspy.primitives.program import Module
Expand Down Expand Up @@ -81,18 +80,7 @@ def forward(self, **input_args):
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
parsed_tool_args = {}
tool = self.tools[pred.next_tool_name]
for k, v in pred.next_tool_args.items():
if hasattr(tool, "arg_types") and k in tool.arg_types:
arg_type = tool.arg_types[k]
if isinstance((origin := get_origin(arg_type) or arg_type), type) and issubclass(
origin, BaseModel
):
parsed_tool_args[k] = arg_type.model_validate(v)
continue
parsed_tool_args[k] = v
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**parsed_tool_args)
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
except Exception as e:
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}"

Expand Down
19 changes: 17 additions & 2 deletions dspy/primitives/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Callable, Optional, get_origin, get_type_hints

from jsonschema import ValidationError, validate
from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel, TypeAdapter, create_model

from dspy.utils.callback import with_callbacks

Expand Down Expand Up @@ -130,6 +130,19 @@ def _parse_function(self, func: Callable, arg_desc: dict[str, str] = None):
self.args = self.args or args
self.arg_types = self.arg_types or arg_types

def _parse_args(self, **kwargs):
parsed_kwargs = {}
for k, v in kwargs.items():
if k in self.arg_types and self.arg_types[k] != any:
# Create a pydantic model wrapper with a dummy field `value` to parse the arg to the correct type.
# This is specifically useful for handling nested Pydantic models like `list[list[MyPydanticModel]]`
pydantic_wrapper = create_model("Wrapper", value=(self.arg_types[k], ...))
parsed = pydantic_wrapper.model_validate({"value": v})
parsed_kwargs[k] = parsed.value
else:
parsed_kwargs[k] = v
return parsed_kwargs

@with_callbacks
def __call__(self, **kwargs):
for k, v in kwargs.items():
Expand All @@ -142,4 +155,6 @@ def __call__(self, **kwargs):
validate(instance=instance, schema=self.args[k])
except ValidationError as e:
raise ValueError(f"Arg {k} is invalid: {e.message}")
return self.func(**kwargs)

parsed_kwargs = self._parse_args(**kwargs)
return self.func(**parsed_kwargs)
34 changes: 34 additions & 0 deletions tests/primitives/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,37 @@ def foo(x=100):
tool = Tool(foo)
assert tool.args["x"]["default"] == 100
assert not hasattr(tool.args["x"], "type")


def test_tool_call_parses_args():
Copy link
Collaborator

@TomeHirata TomeHirata Apr 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we also add a unit test for the nested type like list[DummyModel]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good call!

tool = Tool(dummy_with_pydantic)

args = {
"model": {
"field1": "hello",
"field2": 123,
}
}

result = tool(**args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: What are the types we officially support as tool-callable for ReAct? Python primitives + Pydantic?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, everything supported by typing + pydantic

assert result == "hello 123"


def test_tool_call_parses_nested_list_of_pydantic_model():
def dummy_function(x: list[list[DummyModel]]):
return x

tool = Tool(dummy_function)
args = {
"x": [
[
{
"field1": "hello",
"field2": 123,
}
]
]
}

result = tool(**args)
assert result == [[DummyModel(field1="hello", field2=123)]]