Skip to content

Commit

Permalink
Handle functions with *args and **kwargs (#106)
Browse files Browse the repository at this point in the history
* Add create_model_from_function

* Use function signature to determine param kind for FunctionCall creation

* Add tests for function schema with args, kwargs

* Fix type hints

* Add tests for args/kwargs with no type hint

* Remove type hint from test functions
  • Loading branch information
jackmpcollins authored Feb 18, 2024
1 parent 5ac0100 commit ac6c60b
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 19 deletions.
91 changes: 77 additions & 14 deletions src/magentic/chat_model/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,20 +215,46 @@ def serialize_args(self, value: BaseModelT) -> str:
return value.model_dump_json()


def create_model_from_function(func: Callable[..., Any]) -> type[BaseModel]:
"""Create a Pydantic model from a function signature."""
# https://github.com/pydantic/pydantic/issues/3585#issuecomment-1002745763
fields: dict[str, Any] = {}
for param in inspect.signature(func).parameters.values():
# *args
if param.kind is inspect.Parameter.VAR_POSITIONAL:
fields[param.name] = (
(
list[param.annotation] # type: ignore[name-defined]
if param.annotation != inspect._empty
else list[Any]
),
param.default if param.default != inspect._empty else [],
)
continue

# **kwargs
if param.kind is inspect.Parameter.VAR_KEYWORD:
fields[param.name] = (
dict[str, param.annotation] # type: ignore[name-defined]
if param.annotation != inspect._empty
else dict[str, Any],
param.default if param.default != inspect._empty else {},
)
continue

fields[param.name] = (
(param.annotation if param.annotation != inspect._empty else Any),
(param.default if param.default != inspect._empty else ...),
)
return create_model("FuncModel", **fields)


class FunctionCallFunctionSchema(BaseFunctionSchema[FunctionCall[T]], Generic[T]):
"""FunctionSchema for FunctionCall."""

def __init__(self, func: Callable[..., T]):
self._func = func
# https://github.com/pydantic/pydantic/issues/3585#issuecomment-1002745763
fields: dict[str, Any] = {
param.name: (
(param.annotation if param.annotation != inspect._empty else Any),
(param.default if param.default != inspect._empty else ...),
)
for param in inspect.signature(func).parameters.values()
}
self._model = create_model("FuncModel", **fields)
self._model = create_model_from_function(func)

@property
def name(self) -> str:
Expand All @@ -246,13 +272,50 @@ def parameters(self) -> dict[str, Any]:

def parse_args(self, arguments: Iterable[str]) -> FunctionCall[T]:
model = self._model.model_validate_json("".join(arguments))
args = {attr: getattr(model, attr) for attr in model.model_fields_set}
return FunctionCall(self._func, **args)
supplied_params = [
param
for param in inspect.signature(self._func).parameters.values()
if param.name in model.model_fields_set
]

args_positional_only = [
getattr(model, param.name)
for param in supplied_params
if param.kind == param.POSITIONAL_ONLY
]
args_positional_or_keyword = [
getattr(model, param.name)
for param in supplied_params
if param.kind == param.POSITIONAL_OR_KEYWORD
]
args_var_positional = [
arg
for param in supplied_params
if param.kind == param.VAR_POSITIONAL
for arg in getattr(model, param.name)
]
args_keyword_only = {
param.name: getattr(model, param.name)
for param in supplied_params
if param.kind == param.KEYWORD_ONLY
}
args_var_keyword = {
name: value
for param in supplied_params
if param.kind == param.VAR_KEYWORD
for name, value in getattr(model, param.name).items()
}
return FunctionCall(
self._func,
*args_positional_only,
*args_positional_or_keyword,
*args_var_positional,
**args_keyword_only,
**args_var_keyword,
)

def serialize_args(self, value: FunctionCall[T]) -> str:
return cast(
str, self._model(**value.arguments).model_dump_json(exclude_unset=True)
)
return self._model(**value.arguments).model_dump_json(exclude_unset=True)


def function_schema_for_type(type_: type[Any]) -> BaseFunctionSchema[Any]:
Expand Down
117 changes: 112 additions & 5 deletions tests/test_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,22 @@ def plus_default_value(a: int, b: int = 3) -> int:
return a + b


def plus_with_args(a: int, *args: int) -> int:
return a + sum(args)


def plus_with_args_no_type_hints(a, *args):
return a + sum(args)


def plus_with_kwargs(a: int, **kwargs: int) -> int:
return a + sum(kwargs.values())


def plus_with_kwargs_no_type_hints(a, **kwargs):
return a + sum(kwargs.values())


def plus_with_annotated(
a: Annotated[int, Field(description="First number")],
b: Annotated[int, Field(description="Second number")],
Expand Down Expand Up @@ -600,6 +616,81 @@ def plus_with_basemodel(a: IntModel, b: IntModel) -> IntModel:
},
},
),
(
plus_with_args,
{
"name": "plus_with_args",
"parameters": {
"type": "object",
"properties": {
"a": {"title": "A", "type": "integer"},
"args": {
"default": [],
"items": {"type": "integer"},
"title": "Args",
"type": "array",
},
},
"required": ["a"],
},
},
),
(
plus_with_args_no_type_hints,
{
"name": "plus_with_args_no_type_hints",
"parameters": {
"type": "object",
"properties": {
"a": {"title": "A"},
"args": {
"default": [],
"items": {},
"title": "Args",
"type": "array",
},
},
"required": ["a"],
},
},
),
(
plus_with_kwargs,
{
"name": "plus_with_kwargs",
"parameters": {
"type": "object",
"properties": {
"a": {"title": "A", "type": "integer"},
"kwargs": {
"additionalProperties": {"type": "integer"},
"default": {},
"title": "Kwargs",
"type": "object",
},
},
"required": ["a"],
},
},
),
(
plus_with_kwargs_no_type_hints,
{
"name": "plus_with_kwargs_no_type_hints",
"parameters": {
"type": "object",
"properties": {
"a": {"title": "A"},
"kwargs": {
"default": {},
"title": "Kwargs",
"type": "object",
},
},
"required": ["a"],
},
},
),
(
plus_with_annotated,
{
Expand Down Expand Up @@ -662,22 +753,38 @@ def test_function_call_function_schema_with_default_value():


function_call_function_schema_args_test_cases = [
(plus, '{"a": 1, "b": 2}', FunctionCall(plus, a=1, b=2)),
(plus, '{"a": 1, "b": 2}', FunctionCall(plus, 1, 2)),
(
plus_no_type_hints,
'{"a": 1, "b": 2}',
FunctionCall(plus_no_type_hints, a=1, b=2),
FunctionCall(plus_no_type_hints, 1, 2),
),
(plus_default_value, '{"a": 1}', FunctionCall(plus_default_value, 1)),
(plus_with_args, '{"a": 1, "args": [2, 3]}', FunctionCall(plus_with_args, 1, 2, 3)),
(
plus_with_args_no_type_hints,
'{"a": 1, "args": [2, 3]}',
FunctionCall(plus_with_args_no_type_hints, 1, 2, 3),
),
(
plus_with_kwargs,
'{"a": 1, "kwargs": {"b": 2, "c": 3}}',
FunctionCall(plus_with_kwargs, 1, b=2, c=3),
),
(
plus_with_kwargs_no_type_hints,
'{"a": 1, "kwargs": {"b": 2, "c": 3}}',
FunctionCall(plus_with_kwargs_no_type_hints, 1, b=2, c=3),
),
(plus_default_value, '{"a": 1}', FunctionCall(plus_default_value, a=1)),
(
plus_with_annotated,
'{"a": 1, "b": 2}',
FunctionCall(plus_with_annotated, a=1, b=2),
FunctionCall(plus_with_annotated, 1, 2),
),
(
plus_with_basemodel,
'{"a": {"value": 1}, "b": {"value": 2}}',
FunctionCall(plus_with_basemodel, a=IntModel(value=1), b=IntModel(value=2)),
FunctionCall(plus_with_basemodel, IntModel(value=1), IntModel(value=2)),
),
]

Expand Down

0 comments on commit ac6c60b

Please sign in to comment.