Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(#94): handle special parameters (self/cls) correctly with context parameter #137

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,30 +228,36 @@ def function_schema(
takes_context = False
filtered_params = []

if params:
first_name, first_param = params[0]
# Prefer the evaluated type hint if available
ann = type_hints.get(first_name, first_param.annotation)
# Helper function to check if a parameter is a special method parameter
def is_special_param(name: str) -> bool:
return name in ("self", "cls")

# Helper function to check if a parameter is a context parameter
def is_context_param(name: str, param: inspect.Parameter) -> bool:
ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper:
takes_context = True # Mark that the function takes context
else:
filtered_params.append((first_name, first_param))
else:
return origin is RunContextWrapper
return False

if params:
first_name, first_param = params[0]

# Handle special first parameter cases
if is_context_param(first_name, first_param):
takes_context = True
elif not is_special_param(first_name):
filtered_params.append((first_name, first_param))

# For parameters other than the first, raise error if any use RunContextWrapper.
# For parameters other than the first, handle special cases and context
for name, param in params[1:]:
ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper:
raise UserError(
f"RunContextWrapper param found at non-first position in function"
f" {func.__name__}"
)
filtered_params.append((name, param))
if is_context_param(name, param):
raise UserError(
f"RunContextWrapper param found at non-first position in function"
f" {func.__name__}"
)
if not is_special_param(name):
filtered_params.append((name, param))

# We will collect field definitions for create_model as a dict:
# field_name -> (type_annotation, default_value_or_Field(...))
Expand Down
84 changes: 84 additions & 0 deletions tests/test_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,87 @@ def func(**kwargs: dict[str, int]):
assert properties.get("kwargs").get("type") == "object"
# The values in the dict are integers.
assert properties.get("kwargs").get("additionalProperties").get("type") == "integer"


def test_context_with_special_params():
"""Test that context parameter works correctly with special parameters (self/cls)."""
class TestClass:
def instance_method_with_context(self, ctx: RunContextWrapper[str], a: int) -> str:
return f"instance {a}"

@classmethod
def class_method_with_context(cls, ctx: RunContextWrapper[str], a: int) -> str:
return f"class {a}"

# Test instance method
instance = TestClass()
func_schema = function_schema(instance.instance_method_with_context)
assert func_schema.takes_context
assert func_schema.params_json_schema.get("title") == "instance_method_with_context_args"

# Verify only 'a' is in the schema, not 'self' or 'ctx'
properties = func_schema.params_json_schema.get("properties", {})
assert "a" in properties
assert "self" not in properties
assert "ctx" not in properties

# Test class method
func_schema = function_schema(TestClass.class_method_with_context)
assert func_schema.takes_context
assert func_schema.params_json_schema.get("title") == "class_method_with_context_args"

# Verify only 'a' is in the schema, not 'cls' or 'ctx'
properties = func_schema.params_json_schema.get("properties", {})
assert "a" in properties
assert "cls" not in properties
assert "ctx" not in properties

# Test actual function calls
context = RunContextWrapper(context="test")

# Test instance method call
parsed = func_schema.params_pydantic_model(**{"a": 42})
args, kwargs_dict = func_schema.to_call_args(parsed)
result = instance.instance_method_with_context(context, *args, **kwargs_dict)
assert result == "instance 42"

# Test class method call
parsed = func_schema.params_pydantic_model(**{"a": 42})
args, kwargs_dict = func_schema.to_call_args(parsed)
result = TestClass.class_method_with_context(context, *args, **kwargs_dict)
assert result == "class 42"


def test_context_with_other_params():
"""Test that context parameter works correctly with other parameters."""
def func_with_context_and_params(
ctx: RunContextWrapper[str],
a: int,
b: str = "default",
) -> str:
return f"{a} {b}"

func_schema = function_schema(func_with_context_and_params)
assert func_schema.takes_context
assert func_schema.params_json_schema.get("title") == "func_with_context_and_params_args"

# Verify schema only contains 'a' and 'b', not 'ctx'
properties = func_schema.params_json_schema.get("properties", {})
assert "a" in properties
assert "b" in properties
assert "ctx" not in properties

# Test function call
context = RunContextWrapper(context="test")

# Test with default value
parsed = func_schema.params_pydantic_model(**{"a": 42})
args, kwargs_dict = func_schema.to_call_args(parsed)
result = func_with_context_and_params(context, *args, **kwargs_dict)
assert result == "42 default"

# Test with explicit value
parsed = func_schema.params_pydantic_model(**{"a": 42, "b": "explicit"})
args, kwargs_dict = func_schema.to_call_args(parsed)
result = func_with_context_and_params(context, *args, **kwargs_dict)
assert result == "42 explicit"