|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | from typing import Any, cast |
4 | 5 |
|
5 | 6 | import pytest |
@@ -612,6 +613,90 @@ def sync_handler(event: AgentToolStreamEvent) -> None: |
612 | 613 | assert calls == ["raw_response_event"] |
613 | 614 |
|
614 | 615 |
|
| 616 | +@pytest.mark.asyncio |
| 617 | +async def test_agent_as_tool_streaming_dispatches_without_blocking( |
| 618 | + monkeypatch: pytest.MonkeyPatch, |
| 619 | +) -> None: |
| 620 | + """on_stream handlers should not block streaming iteration.""" |
| 621 | + agent = Agent(name="nonblocking_agent") |
| 622 | + |
| 623 | + first_handler_started = asyncio.Event() |
| 624 | + allow_handler_to_continue = asyncio.Event() |
| 625 | + second_event_yielded = asyncio.Event() |
| 626 | + second_event_handled = asyncio.Event() |
| 627 | + |
| 628 | + first_event = RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) |
| 629 | + second_event = RawResponsesStreamEvent( |
| 630 | + data=cast(Any, {"type": "output_text_delta", "delta": "hi"}) |
| 631 | + ) |
| 632 | + |
| 633 | + class DummyStreamingResult: |
| 634 | + def __init__(self) -> None: |
| 635 | + self.final_output = "ok" |
| 636 | + |
| 637 | + async def stream_events(self): |
| 638 | + yield first_event |
| 639 | + second_event_yielded.set() |
| 640 | + yield second_event |
| 641 | + |
| 642 | + dummy_result = DummyStreamingResult() |
| 643 | + |
| 644 | + monkeypatch.setattr( |
| 645 | + Runner, "run_streamed", classmethod(lambda *args, **kwargs: dummy_result) |
| 646 | + ) |
| 647 | + monkeypatch.setattr( |
| 648 | + Runner, |
| 649 | + "run", |
| 650 | + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), |
| 651 | + ) |
| 652 | + |
| 653 | + async def on_stream(payload: AgentToolStreamEvent) -> None: |
| 654 | + if payload["event"] is first_event: |
| 655 | + first_handler_started.set() |
| 656 | + await allow_handler_to_continue.wait() |
| 657 | + else: |
| 658 | + second_event_handled.set() |
| 659 | + |
| 660 | + tool_call = ResponseFunctionToolCall( |
| 661 | + id="call_nonblocking", |
| 662 | + arguments='{"input": "go"}', |
| 663 | + call_id="call-nonblocking", |
| 664 | + name="nonblocking_tool", |
| 665 | + type="function_call", |
| 666 | + ) |
| 667 | + |
| 668 | + tool = cast( |
| 669 | + FunctionTool, |
| 670 | + agent.as_tool( |
| 671 | + tool_name="nonblocking_tool", |
| 672 | + tool_description="Uses non-blocking streaming handler", |
| 673 | + on_stream=on_stream, |
| 674 | + ), |
| 675 | + ) |
| 676 | + tool_context = ToolContext( |
| 677 | + context=None, |
| 678 | + tool_name="nonblocking_tool", |
| 679 | + tool_call_id=tool_call.call_id, |
| 680 | + tool_arguments=tool_call.arguments, |
| 681 | + tool_call=tool_call, |
| 682 | + ) |
| 683 | + |
| 684 | + async def _invoke_tool() -> Any: |
| 685 | + return await tool.on_invoke_tool(tool_context, '{"input": "go"}') |
| 686 | + |
| 687 | + invoke_task: asyncio.Task[Any] = asyncio.create_task(_invoke_tool()) |
| 688 | + |
| 689 | + await asyncio.wait_for(first_handler_started.wait(), timeout=1.0) |
| 690 | + await asyncio.wait_for(second_event_yielded.wait(), timeout=1.0) |
| 691 | + assert invoke_task.done() is False |
| 692 | + |
| 693 | + allow_handler_to_continue.set() |
| 694 | + await asyncio.wait_for(second_event_handled.wait(), timeout=1.0) |
| 695 | + output = await asyncio.wait_for(invoke_task, timeout=1.0) |
| 696 | + |
| 697 | + assert output == "ok" |
| 698 | + |
| 699 | + |
615 | 700 | @pytest.mark.asyncio |
616 | 701 | async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call( |
617 | 702 | monkeypatch: pytest.MonkeyPatch, |
|
0 commit comments