Skip to content

Commit b6aba55

Browse files
rm-openainiv-hertz
authored andcommitted
Add is_enabled to handoffs (#925)
Was added to function tools before, now handoffs. Towards #918
1 parent 1549dfc commit b6aba55

File tree

6 files changed

+140
-21
lines changed

6 files changed

+140
-21
lines changed

src/agents/handoffs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .strict_schema import ensure_strict_json_schema
1616
from .tracing.spans import SpanError
1717
from .util import _error_tracing, _json, _transforms
18+
from .util._types import MaybeAwaitable
1819

1920
if TYPE_CHECKING:
2021
from .agent import Agent
@@ -104,6 +105,11 @@ class Handoff(Generic[TContext]):
104105
original (previous) Agent upon completion of its work. If False, after the Agent that received
105106
the handoff completes its work, the interaction will end.
106107
"""
108+
109+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
110+
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
111+
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
112+
a handoff based on your context/state."""
107113

108114
def get_transfer_message(self, agent: Agent[Any]) -> str:
109115
return json.dumps({"assistant": agent.name})
@@ -128,6 +134,7 @@ def handoff(
128134
tool_description_override: str | None = None,
129135
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
130136
should_return_control: bool = False,
137+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
131138
) -> Handoff[TContext]: ...
132139

133140

@@ -141,6 +148,7 @@ def handoff(
141148
tool_name_override: str | None = None,
142149
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
143150
should_return_control: bool = False,
151+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
144152
) -> Handoff[TContext]: ...
145153

146154

@@ -153,6 +161,7 @@ def handoff(
153161
tool_name_override: str | None = None,
154162
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
155163
should_return_control: bool = False,
164+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
156165
) -> Handoff[TContext]: ...
157166

158167

@@ -164,6 +173,7 @@ def handoff(
164173
input_type: type[THandoffInput] | None = None,
165174
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
166175
should_return_control: bool = False,
176+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
167177
) -> Handoff[TContext]:
168178
"""Create a handoff from an agent.
169179
@@ -176,6 +186,9 @@ def handoff(
176186
input_type: the type of the input to the handoff. If provided, the input will be validated
177187
against this type. Only relevant if you pass a function that takes an input.
178188
input_filter: a function that filters the inputs that are passed to the next agent.
189+
is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run
190+
context and agent and returns whether the handoff is enabled. Disabled handoffs are
191+
hidden from the LLM at runtime.
179192
"""
180193
assert (on_handoff and input_type) or not (on_handoff and input_type), (
181194
"You must provide either both on_input and input_type, or neither"
@@ -244,4 +257,5 @@ async def _invoke_handoff(
244257
input_filter=input_filter,
245258
agent_name=agent.name,
246259
should_return_control=should_return_control,
260+
is_enabled=is_enabled,
247261
)

src/agents/run.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import copy
5+
import inspect
56
from dataclasses import dataclass, field
67
from typing import Any, Generic, cast
78

@@ -374,7 +375,8 @@ async def run(
374375
# agent changes, or if the agent loop ends.
375376
if current_span is None:
376377
handoff_names = [
377-
h.agent_name for h in AgentRunner._get_handoffs(current_agent)
378+
h.agent_name
379+
for h in await AgentRunner._get_handoffs(current_agent, context_wrapper)
378380
]
379381
if output_schema := AgentRunner._get_output_schema(current_agent):
380382
output_type_name = output_schema.name()
@@ -668,7 +670,10 @@ async def _start_streaming(
668670
# Start an agent span if we don't have one. This span is ended if the current
669671
# agent changes, or if the agent loop ends.
670672
if current_span is None:
671-
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
673+
handoff_names = [
674+
h.agent_name
675+
for h in await cls._get_handoffs(current_agent, context_wrapper)
676+
]
672677
if output_schema := cls._get_output_schema(current_agent):
673678
output_type_name = output_schema.name()
674679
else:
@@ -833,7 +838,7 @@ async def _run_single_turn_streamed(
833838
agent.get_prompt(context_wrapper),
834839
)
835840

836-
handoffs = cls._get_handoffs(agent)
841+
handoffs = await cls._get_handoffs(agent, context_wrapper)
837842
model = cls._get_model(agent, run_config)
838843
model_settings = agent.model_settings.resolve(run_config.model_settings)
839844
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
@@ -935,7 +940,7 @@ async def _run_single_turn(
935940
)
936941

937942
output_schema = cls._get_output_schema(agent)
938-
handoffs = cls._get_handoffs(agent)
943+
handoffs = await cls._get_handoffs(agent, context_wrapper)
939944
input = ItemHelpers.input_to_new_input_list(original_input)
940945
input.extend([generated_item.to_input_item() for generated_item in generated_items])
941946

@@ -1131,14 +1136,28 @@ def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
11311136
return AgentOutputSchema(agent.output_type)
11321137

11331138
@classmethod
1134-
def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
1139+
async def _get_handoffs(
1140+
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
1141+
) -> list[Handoff]:
11351142
handoffs = []
11361143
for handoff_item in agent.handoffs:
11371144
if isinstance(handoff_item, Handoff):
11381145
handoffs.append(handoff_item)
11391146
elif isinstance(handoff_item, Agent):
11401147
handoffs.append(handoff(handoff_item))
1141-
return handoffs
1148+
1149+
async def _check_handoff_enabled(handoff_obj: Handoff) -> bool:
1150+
attr = handoff_obj.is_enabled
1151+
if isinstance(attr, bool):
1152+
return attr
1153+
res = attr(context_wrapper, agent)
1154+
if inspect.isawaitable(res):
1155+
return bool(await res)
1156+
return bool(res)
1157+
1158+
results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
1159+
enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok]
1160+
return enabled
11421161

11431162
@classmethod
11441163
async def _get_all_tools(

tests/test_agent_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def test_handoff_with_agents():
4343
handoffs=[agent_1, agent_2],
4444
)
4545

46-
handoffs = AgentRunner._get_handoffs(agent_3)
46+
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
4747
assert len(handoffs) == 2
4848

4949
assert handoffs[0].agent_name == "agent_1"
@@ -78,7 +78,7 @@ async def test_handoff_with_handoff_obj():
7878
],
7979
)
8080

81-
handoffs = AgentRunner._get_handoffs(agent_3)
81+
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
8282
assert len(handoffs) == 2
8383

8484
assert handoffs[0].agent_name == "agent_1"
@@ -112,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent():
112112
handoffs=[handoff(agent_1), agent_2],
113113
)
114114

115-
handoffs = AgentRunner._get_handoffs(agent_3)
115+
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
116116
assert len(handoffs) == 2
117117

118118
assert handoffs[0].agent_name == "agent_1"

tests/test_handoff_tool.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,26 @@ def get_len(data: HandoffInputData) -> int:
3838
return input_len + pre_handoff_len + new_items_len
3939

4040

41-
def test_single_handoff_setup():
41+
@pytest.mark.asyncio
42+
async def test_single_handoff_setup():
4243
agent_1 = Agent(name="test_1")
4344
agent_2 = Agent(name="test_2", handoffs=[agent_1])
4445

4546
assert not agent_1.handoffs
4647
assert agent_2.handoffs == [agent_1]
4748

48-
assert not AgentRunner._get_handoffs(agent_1)
49+
assert not (await AgentRunner._get_handoffs(agent_1, RunContextWrapper(agent_1)))
4950

50-
handoff_objects = AgentRunner._get_handoffs(agent_2)
51+
handoff_objects = await AgentRunner._get_handoffs(agent_2, RunContextWrapper(agent_2))
5152
assert len(handoff_objects) == 1
5253
obj = handoff_objects[0]
5354
assert obj.tool_name == Handoff.default_tool_name(agent_1)
5455
assert obj.tool_description == Handoff.default_tool_description(agent_1)
5556
assert obj.agent_name == agent_1.name
5657

5758

58-
def test_multiple_handoffs_setup():
59+
@pytest.mark.asyncio
60+
async def test_multiple_handoffs_setup():
5961
agent_1 = Agent(name="test_1")
6062
agent_2 = Agent(name="test_2")
6163
agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2])
@@ -64,7 +66,7 @@ def test_multiple_handoffs_setup():
6466
assert not agent_1.handoffs
6567
assert not agent_2.handoffs
6668

67-
handoff_objects = AgentRunner._get_handoffs(agent_3)
69+
handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3))
6870
assert len(handoff_objects) == 2
6971
assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1)
7072
assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2)
@@ -76,7 +78,8 @@ def test_multiple_handoffs_setup():
7678
assert handoff_objects[1].agent_name == agent_2.name
7779

7880

79-
def test_custom_handoff_setup():
81+
@pytest.mark.asyncio
82+
async def test_custom_handoff_setup():
8083
agent_1 = Agent(name="test_1")
8184
agent_2 = Agent(name="test_2")
8285
agent_3 = Agent(
@@ -95,7 +98,7 @@ def test_custom_handoff_setup():
9598
assert not agent_1.handoffs
9699
assert not agent_2.handoffs
97100

98-
handoff_objects = AgentRunner._get_handoffs(agent_3)
101+
handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3))
99102
assert len(handoff_objects) == 2
100103

101104
first_handoff = handoff_objects[0]
@@ -284,3 +287,86 @@ def test_get_transfer_message_is_valid_json() -> None:
284287
obj = handoff(agent)
285288
transfer = obj.get_transfer_message(agent)
286289
assert json.loads(transfer) == {"assistant": agent.name}
290+
291+
292+
def test_handoff_is_enabled_bool():
293+
"""Test that handoff respects is_enabled boolean parameter."""
294+
agent = Agent(name="test")
295+
296+
# Test enabled handoff (default)
297+
handoff_enabled = handoff(agent)
298+
assert handoff_enabled.is_enabled is True
299+
300+
# Test explicitly enabled handoff
301+
handoff_explicit_enabled = handoff(agent, is_enabled=True)
302+
assert handoff_explicit_enabled.is_enabled is True
303+
304+
# Test disabled handoff
305+
handoff_disabled = handoff(agent, is_enabled=False)
306+
assert handoff_disabled.is_enabled is False
307+
308+
309+
@pytest.mark.asyncio
310+
async def test_handoff_is_enabled_callable():
311+
"""Test that handoff respects is_enabled callable parameter."""
312+
agent = Agent(name="test")
313+
314+
# Test callable that returns True
315+
def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
316+
return True
317+
318+
handoff_callable_enabled = handoff(agent, is_enabled=always_enabled)
319+
assert callable(handoff_callable_enabled.is_enabled)
320+
result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent)
321+
assert result is True
322+
323+
# Test callable that returns False
324+
def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
325+
return False
326+
327+
handoff_callable_disabled = handoff(agent, is_enabled=always_disabled)
328+
assert callable(handoff_callable_disabled.is_enabled)
329+
result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent)
330+
assert result is False
331+
332+
# Test async callable
333+
async def async_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
334+
return True
335+
336+
handoff_async_enabled = handoff(agent, is_enabled=async_enabled)
337+
assert callable(handoff_async_enabled.is_enabled)
338+
result = await handoff_async_enabled.is_enabled(RunContextWrapper(agent), agent) # type: ignore
339+
assert result is True
340+
341+
342+
@pytest.mark.asyncio
343+
async def test_handoff_is_enabled_filtering_integration():
344+
"""Integration test that disabled handoffs are filtered out by the runner."""
345+
346+
# Set up agents
347+
agent_1 = Agent(name="agent_1")
348+
agent_2 = Agent(name="agent_2")
349+
agent_3 = Agent(name="agent_3")
350+
351+
# Create main agent with mixed enabled/disabled handoffs
352+
main_agent = Agent(
353+
name="main_agent",
354+
handoffs=[
355+
handoff(agent_1, is_enabled=True), # enabled
356+
handoff(agent_2, is_enabled=False), # disabled
357+
handoff(agent_3, is_enabled=lambda ctx, agent: True), # enabled callable
358+
],
359+
)
360+
361+
context_wrapper = RunContextWrapper(main_agent)
362+
363+
# Get filtered handoffs using the runner's method
364+
filtered_handoffs = await AgentRunner._get_handoffs(main_agent, context_wrapper)
365+
366+
# Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out
367+
assert len(filtered_handoffs) == 2
368+
369+
# Check that the correct agents are present
370+
agent_names = {h.agent_name for h in filtered_handoffs}
371+
assert agent_names == {"agent_1", "agent_3"}
372+
assert "agent_2" not in agent_names

tests/test_run_step_execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,8 @@ async def get_execute_result(
369369
previous_agents: list[Agent[Any]] | None = None,
370370
) -> SingleStepResult:
371371
output_schema = AgentRunner._get_output_schema(agent)
372-
handoffs = AgentRunner._get_handoffs(agent)
373372
previous_agents = previous_agents if previous_agents is not None else []
373+
handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None))
374374

375375
processed_response = RunImpl.process_model_response(
376376
agent=agent,

tests/test_run_step_processing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly():
186186
agent=agent_3,
187187
response=response,
188188
output_schema=None,
189-
handoffs=AgentRunner._get_handoffs(agent_3),
189+
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
190190
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
191191
)
192192
assert len(result.handoffs) == 1, "Should have a handoff here"
@@ -216,7 +216,7 @@ async def test_missing_handoff_fails():
216216
agent=agent_3,
217217
response=response,
218218
output_schema=None,
219-
handoffs=AgentRunner._get_handoffs(agent_3),
219+
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
220220
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
221221
)
222222

@@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error():
239239
agent=agent_3,
240240
response=response,
241241
output_schema=None,
242-
handoffs=AgentRunner._get_handoffs(agent_3),
242+
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
243243
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
244244
)
245245
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
@@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly():
471471
agent=agent_3,
472472
response=response,
473473
output_schema=None,
474-
handoffs=AgentRunner._get_handoffs(agent_3),
474+
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
475475
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
476476
)
477477
assert result.functions and len(result.functions) == 1

0 commit comments

Comments
 (0)