Skip to content

Revert "Add is_enabled to handoffs (#925)" #982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions src/agents/handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from .strict_schema import ensure_strict_json_schema
from .tracing.spans import SpanError
from .util import _error_tracing, _json, _transforms
from .util._types import MaybeAwaitable

if TYPE_CHECKING:
from .agent import Agent
Expand Down Expand Up @@ -100,11 +99,6 @@ class Handoff(Generic[TContext]):
True, as it increases the likelihood of correct JSON input.
"""

is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
a handoff based on your context/state."""

def get_transfer_message(self, agent: Agent[Any]) -> str:
return json.dumps({"assistant": agent.name})

Expand All @@ -127,7 +121,6 @@ def handoff(
tool_name_override: str | None = None,
tool_description_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext]: ...


Expand All @@ -140,7 +133,6 @@ def handoff(
tool_description_override: str | None = None,
tool_name_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext]: ...


Expand All @@ -152,7 +144,6 @@ def handoff(
tool_description_override: str | None = None,
tool_name_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext]: ...


Expand All @@ -163,7 +154,6 @@ def handoff(
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
input_type: type[THandoffInput] | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext]:
"""Create a handoff from an agent.

Expand All @@ -176,9 +166,6 @@ def handoff(
input_type: the type of the input to the handoff. If provided, the input will be validated
against this type. Only relevant if you pass a function that takes an input.
input_filter: a function that filters the inputs that are passed to the next agent.
is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run
context and agent and returns whether the handoff is enabled. Disabled handoffs are
hidden from the LLM at runtime.
"""
assert (on_handoff and input_type) or not (on_handoff and input_type), (
"You must provide either both on_handoff and input_type, or neither"
Expand Down Expand Up @@ -246,5 +233,4 @@ async def _invoke_handoff(
on_invoke_handoff=_invoke_handoff,
input_filter=input_filter,
agent_name=agent.name,
is_enabled=is_enabled,
)
31 changes: 6 additions & 25 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import copy
import inspect
from dataclasses import dataclass, field
from typing import Any, Generic, cast

Expand Down Expand Up @@ -362,8 +361,7 @@ async def run(
# agent changes, or if the agent loop ends.
if current_span is None:
handoff_names = [
h.agent_name
for h in await AgentRunner._get_handoffs(current_agent, context_wrapper)
h.agent_name for h in AgentRunner._get_handoffs(current_agent)
]
if output_schema := AgentRunner._get_output_schema(current_agent):
output_type_name = output_schema.name()
Expand Down Expand Up @@ -643,10 +641,7 @@ async def _start_streaming(
# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
if current_span is None:
handoff_names = [
h.agent_name
for h in await cls._get_handoffs(current_agent, context_wrapper)
]
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
if output_schema := cls._get_output_schema(current_agent):
output_type_name = output_schema.name()
else:
Expand Down Expand Up @@ -803,7 +798,7 @@ async def _run_single_turn_streamed(
agent.get_prompt(context_wrapper),
)

handoffs = await cls._get_handoffs(agent, context_wrapper)
handoffs = cls._get_handoffs(agent)
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
Expand Down Expand Up @@ -903,7 +898,7 @@ async def _run_single_turn(
)

output_schema = cls._get_output_schema(agent)
handoffs = await cls._get_handoffs(agent, context_wrapper)
handoffs = cls._get_handoffs(agent)
input = ItemHelpers.input_to_new_input_list(original_input)
input.extend([generated_item.to_input_item() for generated_item in generated_items])

Expand Down Expand Up @@ -1096,28 +1091,14 @@ def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
return AgentOutputSchema(agent.output_type)

@classmethod
async def _get_handoffs(
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
) -> list[Handoff]:
def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
handoffs = []
for handoff_item in agent.handoffs:
if isinstance(handoff_item, Handoff):
handoffs.append(handoff_item)
elif isinstance(handoff_item, Agent):
handoffs.append(handoff(handoff_item))

async def _check_handoff_enabled(handoff_obj: Handoff) -> bool:
attr = handoff_obj.is_enabled
if isinstance(attr, bool):
return attr
res = attr(context_wrapper, agent)
if inspect.isawaitable(res):
return bool(await res)
return bool(res)

results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok]
return enabled
return handoffs

@classmethod
async def _get_all_tools(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def test_handoff_with_agents():
handoffs=[agent_1, agent_2],
)

handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
handoffs = AgentRunner._get_handoffs(agent_3)
assert len(handoffs) == 2

assert handoffs[0].agent_name == "agent_1"
Expand Down Expand Up @@ -78,7 +78,7 @@ async def test_handoff_with_handoff_obj():
],
)

handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
handoffs = AgentRunner._get_handoffs(agent_3)
assert len(handoffs) == 2

assert handoffs[0].agent_name == "agent_1"
Expand Down Expand Up @@ -112,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent():
handoffs=[handoff(agent_1), agent_2],
)

handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
handoffs = AgentRunner._get_handoffs(agent_3)
assert len(handoffs) == 2

assert handoffs[0].agent_name == "agent_1"
Expand Down
100 changes: 7 additions & 93 deletions tests/test_handoff_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,24 @@ def get_len(data: HandoffInputData) -> int:
return input_len + pre_handoff_len + new_items_len


@pytest.mark.asyncio
async def test_single_handoff_setup():
def test_single_handoff_setup():
agent_1 = Agent(name="test_1")
agent_2 = Agent(name="test_2", handoffs=[agent_1])

assert not agent_1.handoffs
assert agent_2.handoffs == [agent_1]

assert not (await AgentRunner._get_handoffs(agent_1, RunContextWrapper(agent_1)))
assert not AgentRunner._get_handoffs(agent_1)

handoff_objects = await AgentRunner._get_handoffs(agent_2, RunContextWrapper(agent_2))
handoff_objects = AgentRunner._get_handoffs(agent_2)
assert len(handoff_objects) == 1
obj = handoff_objects[0]
assert obj.tool_name == Handoff.default_tool_name(agent_1)
assert obj.tool_description == Handoff.default_tool_description(agent_1)
assert obj.agent_name == agent_1.name


@pytest.mark.asyncio
async def test_multiple_handoffs_setup():
def test_multiple_handoffs_setup():
agent_1 = Agent(name="test_1")
agent_2 = Agent(name="test_2")
agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2])
Expand All @@ -66,7 +64,7 @@ async def test_multiple_handoffs_setup():
assert not agent_1.handoffs
assert not agent_2.handoffs

handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3))
handoff_objects = AgentRunner._get_handoffs(agent_3)
assert len(handoff_objects) == 2
assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1)
assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2)
Expand All @@ -78,8 +76,7 @@ async def test_multiple_handoffs_setup():
assert handoff_objects[1].agent_name == agent_2.name


@pytest.mark.asyncio
async def test_custom_handoff_setup():
def test_custom_handoff_setup():
agent_1 = Agent(name="test_1")
agent_2 = Agent(name="test_2")
agent_3 = Agent(
Expand All @@ -98,7 +95,7 @@ async def test_custom_handoff_setup():
assert not agent_1.handoffs
assert not agent_2.handoffs

handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3))
handoff_objects = AgentRunner._get_handoffs(agent_3)
assert len(handoff_objects) == 2

first_handoff = handoff_objects[0]
Expand Down Expand Up @@ -287,86 +284,3 @@ def test_get_transfer_message_is_valid_json() -> None:
obj = handoff(agent)
transfer = obj.get_transfer_message(agent)
assert json.loads(transfer) == {"assistant": agent.name}


def test_handoff_is_enabled_bool():
"""Test that handoff respects is_enabled boolean parameter."""
agent = Agent(name="test")

# Test enabled handoff (default)
handoff_enabled = handoff(agent)
assert handoff_enabled.is_enabled is True

# Test explicitly enabled handoff
handoff_explicit_enabled = handoff(agent, is_enabled=True)
assert handoff_explicit_enabled.is_enabled is True

# Test disabled handoff
handoff_disabled = handoff(agent, is_enabled=False)
assert handoff_disabled.is_enabled is False


@pytest.mark.asyncio
async def test_handoff_is_enabled_callable():
"""Test that handoff respects is_enabled callable parameter."""
agent = Agent(name="test")

# Test callable that returns True
def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
return True

handoff_callable_enabled = handoff(agent, is_enabled=always_enabled)
assert callable(handoff_callable_enabled.is_enabled)
result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent)
assert result is True

# Test callable that returns False
def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
return False

handoff_callable_disabled = handoff(agent, is_enabled=always_disabled)
assert callable(handoff_callable_disabled.is_enabled)
result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent)
assert result is False

# Test async callable
async def async_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
return True

handoff_async_enabled = handoff(agent, is_enabled=async_enabled)
assert callable(handoff_async_enabled.is_enabled)
result = await handoff_async_enabled.is_enabled(RunContextWrapper(agent), agent) # type: ignore
assert result is True


@pytest.mark.asyncio
async def test_handoff_is_enabled_filtering_integration():
"""Integration test that disabled handoffs are filtered out by the runner."""

# Set up agents
agent_1 = Agent(name="agent_1")
agent_2 = Agent(name="agent_2")
agent_3 = Agent(name="agent_3")

# Create main agent with mixed enabled/disabled handoffs
main_agent = Agent(
name="main_agent",
handoffs=[
handoff(agent_1, is_enabled=True), # enabled
handoff(agent_2, is_enabled=False), # disabled
handoff(agent_3, is_enabled=lambda ctx, agent: True), # enabled callable
],
)

context_wrapper = RunContextWrapper(main_agent)

# Get filtered handoffs using the runner's method
filtered_handoffs = await AgentRunner._get_handoffs(main_agent, context_wrapper)

# Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out
assert len(filtered_handoffs) == 2

# Check that the correct agents are present
agent_names = {h.agent_name for h in filtered_handoffs}
assert agent_names == {"agent_1", "agent_3"}
assert "agent_2" not in agent_names
2 changes: 1 addition & 1 deletion tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ async def get_execute_result(
run_config: RunConfig | None = None,
) -> SingleStepResult:
output_schema = AgentRunner._get_output_schema(agent)
handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None))
handoffs = AgentRunner._get_handoffs(agent)

processed_response = RunImpl.process_model_response(
agent=agent,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_run_step_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly():
agent=agent_3,
response=response,
output_schema=None,
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
handoffs=AgentRunner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert len(result.handoffs) == 1, "Should have a handoff here"
Expand Down Expand Up @@ -216,7 +216,7 @@ async def test_missing_handoff_fails():
agent=agent_3,
response=response,
output_schema=None,
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
handoffs=AgentRunner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)

Expand All @@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error():
agent=agent_3,
response=response,
output_schema=None,
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
handoffs=AgentRunner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
Expand Down Expand Up @@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly():
agent=agent_3,
response=response,
output_schema=None,
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
handoffs=AgentRunner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert result.functions and len(result.functions) == 1
Expand Down