Skip to content

Commit 24088d3

Browse files
committed
improve the concurrency of event handling
1 parent 7f1672a commit 24088d3

File tree

2 files changed

+121
-8
lines changed

2 files changed

+121
-8
lines changed

src/agents/agent.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import inspect
66
from collections.abc import Awaitable
77
from dataclasses import dataclass, field
8-
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
8+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeAlias, cast
99

1010
from openai.types.responses.response_prompt_param import ResponsePromptParam
11-
from typing_extensions import NotRequired, TypeAlias, TypedDict
11+
from typing_extensions import NotRequired, TypedDict
1212

1313
from .agent_output import AgentOutputSchemaBase
1414
from .guardrail import InputGuardrail, OutputGuardrail
@@ -457,12 +457,12 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
457457
conversation_id=conversation_id,
458458
session=session,
459459
)
460-
async for event in run_result.stream_events():
461-
payload: AgentToolStreamEvent = {
462-
"event": event,
463-
"agent": self,
464-
"tool_call": getattr(context, "tool_call", None),
465-
}
460+
# Dispatch callbacks in the background so slow handlers do not block
461+
# event consumption.
462+
event_queue: asyncio.Queue[AgentToolStreamEvent | None] = asyncio.Queue()
463+
464+
async def _run_handler(payload: AgentToolStreamEvent) -> None:
465+
"""Execute the user callback while capturing exceptions."""
466466
try:
467467
maybe_result = on_stream(payload)
468468
if inspect.isawaitable(maybe_result):
@@ -472,6 +472,34 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
472472
"Error while handling on_stream event for agent tool %s.",
473473
self.name,
474474
)
475+
476+
async def dispatch_stream_events() -> None:
477+
while True:
478+
payload = await event_queue.get()
479+
is_sentinel = payload is None # None marks the end of the stream.
480+
try:
481+
if payload is not None:
482+
await _run_handler(payload)
483+
finally:
484+
event_queue.task_done()
485+
486+
if is_sentinel:
487+
break
488+
489+
dispatch_task = asyncio.create_task(dispatch_stream_events())
490+
491+
try:
492+
async for event in run_result.stream_events():
493+
payload: AgentToolStreamEvent = {
494+
"event": event,
495+
"agent": self,
496+
"tool_call": getattr(context, "tool_call", None),
497+
}
498+
await event_queue.put(payload)
499+
finally:
500+
await event_queue.put(None)
501+
await event_queue.join()
502+
await dispatch_task
475503
else:
476504
run_result = await Runner.run(
477505
starting_agent=self,

tests/test_agent_as_tool.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from typing import Any, cast
45

56
import pytest
@@ -612,6 +613,90 @@ def sync_handler(event: AgentToolStreamEvent) -> None:
612613
assert calls == ["raw_response_event"]
613614

614615

616+
@pytest.mark.asyncio
617+
async def test_agent_as_tool_streaming_dispatches_without_blocking(
618+
monkeypatch: pytest.MonkeyPatch,
619+
) -> None:
620+
"""on_stream handlers should not block streaming iteration."""
621+
agent = Agent(name="nonblocking_agent")
622+
623+
first_handler_started = asyncio.Event()
624+
allow_handler_to_continue = asyncio.Event()
625+
second_event_yielded = asyncio.Event()
626+
second_event_handled = asyncio.Event()
627+
628+
first_event = RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))
629+
second_event = RawResponsesStreamEvent(
630+
data=cast(Any, {"type": "output_text_delta", "delta": "hi"})
631+
)
632+
633+
class DummyStreamingResult:
634+
def __init__(self) -> None:
635+
self.final_output = "ok"
636+
637+
async def stream_events(self):
638+
yield first_event
639+
second_event_yielded.set()
640+
yield second_event
641+
642+
dummy_result = DummyStreamingResult()
643+
644+
monkeypatch.setattr(
645+
Runner, "run_streamed", classmethod(lambda *args, **kwargs: dummy_result)
646+
)
647+
monkeypatch.setattr(
648+
Runner,
649+
"run",
650+
classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))),
651+
)
652+
653+
async def on_stream(payload: AgentToolStreamEvent) -> None:
654+
if payload["event"] is first_event:
655+
first_handler_started.set()
656+
await allow_handler_to_continue.wait()
657+
else:
658+
second_event_handled.set()
659+
660+
tool_call = ResponseFunctionToolCall(
661+
id="call_nonblocking",
662+
arguments='{"input": "go"}',
663+
call_id="call-nonblocking",
664+
name="nonblocking_tool",
665+
type="function_call",
666+
)
667+
668+
tool = cast(
669+
FunctionTool,
670+
agent.as_tool(
671+
tool_name="nonblocking_tool",
672+
tool_description="Uses non-blocking streaming handler",
673+
on_stream=on_stream,
674+
),
675+
)
676+
tool_context = ToolContext(
677+
context=None,
678+
tool_name="nonblocking_tool",
679+
tool_call_id=tool_call.call_id,
680+
tool_arguments=tool_call.arguments,
681+
tool_call=tool_call,
682+
)
683+
684+
async def _invoke_tool() -> Any:
685+
return await tool.on_invoke_tool(tool_context, '{"input": "go"}')
686+
687+
invoke_task: asyncio.Task[Any] = asyncio.create_task(_invoke_tool())
688+
689+
await asyncio.wait_for(first_handler_started.wait(), timeout=1.0)
690+
await asyncio.wait_for(second_event_yielded.wait(), timeout=1.0)
691+
assert invoke_task.done() is False
692+
693+
allow_handler_to_continue.set()
694+
await asyncio.wait_for(second_event_handled.wait(), timeout=1.0)
695+
output = await asyncio.wait_for(invoke_task, timeout=1.0)
696+
697+
assert output == "ok"
698+
699+
615700
@pytest.mark.asyncio
616701
async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call(
617702
monkeypatch: pytest.MonkeyPatch,

0 commit comments

Comments
 (0)