Skip to content

Commit a63c53b

Browse files
authored
Added support for "return" handoffs (#1)
1 parent 0eee6b8 commit a63c53b

File tree

5 files changed

+127
-15
lines changed

5 files changed

+127
-15
lines changed

src/agents/_run_impl.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ class NextStepHandoff:
176176
new_agent: Agent[Any]
177177

178178

179+
@dataclass
180+
class NextStepHandoffReturnControl:
181+
previous_agent: Agent[Any]
182+
183+
179184
@dataclass
180185
class NextStepFinalOutput:
181186
output: Any
@@ -201,7 +206,9 @@ class SingleStepResult:
201206
new_step_items: list[RunItem]
202207
"""Items generated during this current step."""
203208

204-
next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain
209+
next_step: (
210+
NextStepHandoff | NextStepFinalOutput | NextStepRunAgain | NextStepHandoffReturnControl
211+
)
205212
"""The next step to take."""
206213

207214
@property
@@ -238,6 +245,7 @@ async def execute_tools_and_side_effects(
238245
hooks: RunHooks[TContext],
239246
context_wrapper: RunContextWrapper[TContext],
240247
run_config: RunConfig,
248+
previous_agents: list[Agent],
241249
) -> SingleStepResult:
242250
# Make a copy of the generated items
243251
pre_step_items = list(pre_step_items)
@@ -286,6 +294,7 @@ async def execute_tools_and_side_effects(
286294
hooks=hooks,
287295
context_wrapper=context_wrapper,
288296
run_config=run_config,
297+
previous_agents=previous_agents,
289298
)
290299

291300
# Next, we'll check if the tool use should result in a final output
@@ -316,6 +325,7 @@ async def execute_tools_and_side_effects(
316325
final_output=check_tool_use.final_output,
317326
hooks=hooks,
318327
context_wrapper=context_wrapper,
328+
previous_agents=previous_agents,
319329
)
320330

321331
# Now we can check if the model also produced a final output
@@ -340,6 +350,7 @@ async def execute_tools_and_side_effects(
340350
final_output=final_output,
341351
hooks=hooks,
342352
context_wrapper=context_wrapper,
353+
previous_agents=previous_agents,
343354
)
344355
elif (
345356
not output_schema or output_schema.is_plain_text()
@@ -353,6 +364,7 @@ async def execute_tools_and_side_effects(
353364
final_output=potential_final_output_text or "",
354365
hooks=hooks,
355366
context_wrapper=context_wrapper,
367+
previous_agents=previous_agents,
356368
)
357369
else:
358370
# If there's no final output, we can just run again
@@ -663,6 +675,7 @@ async def execute_handoffs(
663675
hooks: RunHooks[TContext],
664676
context_wrapper: RunContextWrapper[TContext],
665677
run_config: RunConfig,
678+
previous_agents: list[Agent[TContext]],
666679
) -> SingleStepResult:
667680
# If there is more than one handoff, add tool responses that reject those handoffs
668681
multiple_handoffs = len(run_handoffs) > 1
@@ -684,6 +697,8 @@ async def execute_handoffs(
684697
actual_handoff = run_handoffs[0]
685698
with handoff_span(from_agent=agent.name) as span_handoff:
686699
handoff = actual_handoff.handoff
700+
if handoff.should_return_control:
701+
previous_agents.append(agent)
687702
new_agent: Agent[Any] = await handoff.on_invoke_handoff(
688703
context_wrapper, actual_handoff.tool_call.arguments
689704
)
@@ -825,16 +840,21 @@ async def execute_final_output(
825840
final_output: Any,
826841
hooks: RunHooks[TContext],
827842
context_wrapper: RunContextWrapper[TContext],
843+
previous_agents: list[Agent[TContext]],
828844
) -> SingleStepResult:
845+
is_returning_control = len(previous_agents) > 0
829846
# Run the on_end hooks
830-
await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)
831-
847+
await cls.run_final_output_hooks(
848+
agent, hooks, context_wrapper, final_output, is_returning_control
849+
)
832850
return SingleStepResult(
833851
original_input=original_input,
834852
model_response=new_response,
835853
pre_step_items=pre_step_items,
836854
new_step_items=new_step_items,
837-
next_step=NextStepFinalOutput(final_output),
855+
next_step=NextStepHandoffReturnControl(previous_agents.pop())
856+
if is_returning_control
857+
else NextStepFinalOutput(final_output),
838858
)
839859

840860
@classmethod
@@ -844,13 +864,19 @@ async def run_final_output_hooks(
844864
hooks: RunHooks[TContext],
845865
context_wrapper: RunContextWrapper[TContext],
846866
final_output: Any,
867+
is_returning_control: bool,
847868
):
848-
await asyncio.gather(
849-
hooks.on_agent_end(context_wrapper, agent, final_output),
850-
agent.hooks.on_end(context_wrapper, agent, final_output)
851-
if agent.hooks
852-
else _coro.noop_coroutine(),
853-
)
869+
# If the agent is not returning control, run the hooks
870+
if not is_returning_control:
871+
await asyncio.gather(
872+
hooks.on_agent_end(context_wrapper, agent, final_output),
873+
agent.hooks.on_end(context_wrapper, agent, final_output)
874+
if agent.hooks
875+
else _coro.noop_coroutine(),
876+
)
877+
# If the agent is returning control, only run the current agent's hooks
878+
elif agent.hooks:
879+
await agent.hooks.on_end(context_wrapper, agent, final_output)
854880

855881
@classmethod
856882
async def run_single_input_guardrail(

src/agents/handoffs.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ class Handoff(Generic[TContext]):
9999
True, as it increases the likelihood of correct JSON input.
100100
"""
101101

102+
should_return_control: bool = False
103+
"""Whether the Agent that receives control during a handoff should return control to the
104+
original (previous) Agent upon completion of its work. If False, after the Agent that received
105+
the handoff completes its work, the interaction will end.
106+
"""
107+
102108
def get_transfer_message(self, agent: Agent[Any]) -> str:
103109
return json.dumps({"assistant": agent.name})
104110

@@ -121,6 +127,7 @@ def handoff(
121127
tool_name_override: str | None = None,
122128
tool_description_override: str | None = None,
123129
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
130+
should_return_control: bool = False,
124131
) -> Handoff[TContext]: ...
125132

126133

@@ -133,6 +140,7 @@ def handoff(
133140
tool_description_override: str | None = None,
134141
tool_name_override: str | None = None,
135142
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
143+
should_return_control: bool = False,
136144
) -> Handoff[TContext]: ...
137145

138146

@@ -144,6 +152,7 @@ def handoff(
144152
tool_description_override: str | None = None,
145153
tool_name_override: str | None = None,
146154
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
155+
should_return_control: bool = False,
147156
) -> Handoff[TContext]: ...
148157

149158

@@ -154,6 +163,7 @@ def handoff(
154163
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
155164
input_type: type[THandoffInput] | None = None,
156165
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
166+
should_return_control: bool = False,
157167
) -> Handoff[TContext]:
158168
"""Create a handoff from an agent.
159169
@@ -168,7 +178,7 @@ def handoff(
168178
input_filter: a function that filters the inputs that are passed to the next agent.
169179
"""
170180
assert (on_handoff and input_type) or not (on_handoff and input_type), (
171-
"You must provide either both on_handoff and input_type, or neither"
181+
"You must provide either both on_input and input_type, or neither"
172182
)
173183
type_adapter: TypeAdapter[Any] | None
174184
if input_type is not None:
@@ -233,4 +243,5 @@ async def _invoke_handoff(
233243
on_invoke_handoff=_invoke_handoff,
234244
input_filter=input_filter,
235245
agent_name=agent.name,
246+
should_return_control=should_return_control,
236247
)

src/agents/run.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
AgentToolUseTracker,
1212
NextStepFinalOutput,
1313
NextStepHandoff,
14+
NextStepHandoffReturnControl,
1415
NextStepRunAgain,
1516
QueueCompleteSentinel,
1617
RunImpl,
@@ -119,6 +120,7 @@ async def run(
119120
hooks: RunHooks[TContext] | None = None,
120121
run_config: RunConfig | None = None,
121122
previous_response_id: str | None = None,
123+
previous_agents: list[Agent[TContext]] | None = None,
122124
) -> RunResult:
123125
"""Run a workflow starting at the given agent. The agent will run in a loop until a final
124126
output is generated. The loop runs like so:
@@ -154,6 +156,8 @@ async def run(
154156
hooks = RunHooks[Any]()
155157
if run_config is None:
156158
run_config = RunConfig()
159+
if previous_agents is None:
160+
previous_agents = []
157161

158162
tool_use_tracker = AgentToolUseTracker()
159163

@@ -235,6 +239,7 @@ async def run(
235239
should_run_agent_start_hooks=should_run_agent_start_hooks,
236240
tool_use_tracker=tool_use_tracker,
237241
previous_response_id=previous_response_id,
242+
previous_agents=previous_agents,
238243
),
239244
)
240245
else:
@@ -249,6 +254,7 @@ async def run(
249254
should_run_agent_start_hooks=should_run_agent_start_hooks,
250255
tool_use_tracker=tool_use_tracker,
251256
previous_response_id=previous_response_id,
257+
previous_agents=previous_agents,
252258
)
253259
should_run_agent_start_hooks = False
254260

@@ -273,8 +279,13 @@ async def run(
273279
output_guardrail_results=output_guardrail_results,
274280
context_wrapper=context_wrapper,
275281
)
276-
elif isinstance(turn_result.next_step, NextStepHandoff):
277-
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
282+
elif isinstance(turn_result.next_step, NextStepHandoff) or isinstance(
283+
turn_result.next_step, NextStepHandoffReturnControl
284+
):
285+
if isinstance(turn_result.next_step, NextStepHandoffReturnControl):
286+
current_agent = turn_result.next_step.previous_agent
287+
else:
288+
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
278289
current_span.finish(reset_current=True)
279290
current_span = None
280291
should_run_agent_start_hooks = True
@@ -367,6 +378,7 @@ def run_streamed(
367378
hooks: RunHooks[TContext] | None = None,
368379
run_config: RunConfig | None = None,
369380
previous_response_id: str | None = None,
381+
previous_agents: list[Agent[TContext]] | None = None,
370382
) -> RunResultStreaming:
371383
"""Run a workflow starting at the given agent in streaming mode. The returned result object
372384
contains a method you can use to stream semantic events as they are generated.
@@ -402,6 +414,8 @@ def run_streamed(
402414
hooks = RunHooks[Any]()
403415
if run_config is None:
404416
run_config = RunConfig()
417+
if previous_agents is None:
418+
previous_agents = []
405419

406420
# If there's already a trace, we don't create a new one. In addition, we can't end the
407421
# trace here, because the actual work is done in `stream_events` and this method ends
@@ -450,6 +464,7 @@ def run_streamed(
450464
context_wrapper=context_wrapper,
451465
run_config=run_config,
452466
previous_response_id=previous_response_id,
467+
previous_agents=previous_agents,
453468
)
454469
)
455470
return streamed_result
@@ -508,6 +523,7 @@ async def _run_streamed_impl(
508523
context_wrapper: RunContextWrapper[TContext],
509524
run_config: RunConfig,
510525
previous_response_id: str | None,
526+
previous_agents: list[Agent[TContext]],
511527
):
512528
if streamed_result.trace:
513529
streamed_result.trace.start(mark_as_current=True)
@@ -581,6 +597,7 @@ async def _run_streamed_impl(
581597
tool_use_tracker,
582598
all_tools,
583599
previous_response_id,
600+
previous_agents,
584601
)
585602
should_run_agent_start_hooks = False
586603

@@ -590,8 +607,14 @@ async def _run_streamed_impl(
590607
streamed_result.input = turn_result.original_input
591608
streamed_result.new_items = turn_result.generated_items
592609

593-
if isinstance(turn_result.next_step, NextStepHandoff):
594-
current_agent = turn_result.next_step.new_agent
610+
if isinstance(turn_result.next_step, NextStepHandoff) or isinstance(
611+
turn_result.next_step, NextStepHandoffReturnControl
612+
):
613+
if isinstance(turn_result.next_step, NextStepHandoff):
614+
current_agent = turn_result.next_step.new_agent
615+
else:
616+
current_agent = turn_result.next_step.previous_agent
617+
595618
current_span.finish(reset_current=True)
596619
current_span = None
597620
should_run_agent_start_hooks = True
@@ -666,6 +689,7 @@ async def _run_single_turn_streamed(
666689
tool_use_tracker: AgentToolUseTracker,
667690
all_tools: list[Tool],
668691
previous_response_id: str | None,
692+
previous_agents: list[Agent[TContext]],
669693
) -> SingleStepResult:
670694
if should_run_agent_start_hooks:
671695
await asyncio.gather(
@@ -746,6 +770,7 @@ async def _run_single_turn_streamed(
746770
context_wrapper=context_wrapper,
747771
run_config=run_config,
748772
tool_use_tracker=tool_use_tracker,
773+
previous_agents=previous_agents,
749774
)
750775

751776
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
@@ -765,6 +790,7 @@ async def _run_single_turn(
765790
should_run_agent_start_hooks: bool,
766791
tool_use_tracker: AgentToolUseTracker,
767792
previous_response_id: str | None,
793+
previous_agents: list[Agent[TContext]],
768794
) -> SingleStepResult:
769795
# Ensure we run the hooks before anything else
770796
if should_run_agent_start_hooks:
@@ -809,6 +835,7 @@ async def _run_single_turn(
809835
context_wrapper=context_wrapper,
810836
run_config=run_config,
811837
tool_use_tracker=tool_use_tracker,
838+
previous_agents=previous_agents,
812839
)
813840

814841
@classmethod
@@ -826,6 +853,7 @@ async def _get_single_step_result_from_response(
826853
context_wrapper: RunContextWrapper[TContext],
827854
run_config: RunConfig,
828855
tool_use_tracker: AgentToolUseTracker,
856+
previous_agents: list[Agent[TContext]],
829857
) -> SingleStepResult:
830858
processed_response = RunImpl.process_model_response(
831859
agent=agent,
@@ -847,6 +875,7 @@ async def _get_single_step_result_from_response(
847875
hooks=hooks,
848876
context_wrapper=context_wrapper,
849877
run_config=run_config,
878+
previous_agents=previous_agents,
850879
)
851880

852881
@classmethod

src/agents/tool_context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
def _assert_must_pass_tool_call_id() -> str:
88
raise ValueError("tool_call_id must be passed to ToolContext")
99

10+
1011
@dataclass
1112
class ToolContext(RunContextWrapper[TContext]):
1213
"""The context of a tool call."""

0 commit comments

Comments
 (0)