Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stronger tracing tests with inline-snapshot #25

Merged
Merged 7 commits on Mar 17, 2025 (author and source/target branch names are missing from this capture).
Merged
115 changes: 114 additions & 1 deletion tests/test_agent_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import asyncio

import pytest
from inline_snapshot import snapshot

from agents import Agent, RunConfig, Runner, trace

from .fake_model import FakeModel
from .test_responses import get_text_message
-from .testing_processor import fetch_ordered_spans, fetch_traces
+from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces


@pytest.mark.asyncio
Expand All @@ -25,6 +26,25 @@ async def test_single_run_is_single_trace():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)

spans = fetch_ordered_spans()
assert len(spans) == 1, (
f"Got {len(spans)}, but expected 1: the agent span. data:"
Expand Down Expand Up @@ -52,6 +72,39 @@ async def test_multiple_runs_are_multiple_traces():
traces = fetch_traces()
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
]
)

spans = fetch_ordered_spans()
assert len(spans) == 2, f"Got {len(spans)}, but expected 2: agent span per run"

Expand Down Expand Up @@ -79,6 +132,43 @@ async def test_wrapped_trace_is_single_trace():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test_workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
],
}
]
)

spans = fetch_ordered_spans()
assert len(spans) == 3, f"Got {len(spans)}, but expected 3: the agent span per run"

Expand All @@ -97,6 +187,8 @@ async def test_parent_disabled_trace_disabled_agent_trace():

traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])

spans = fetch_ordered_spans()
assert len(spans) == 0, (
f"Expected no spans, got {len(spans)}, with {[x.span_data for x in spans]}"
Expand All @@ -116,6 +208,8 @@ async def test_manual_disabling_works():

traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])

spans = fetch_ordered_spans()
assert len(spans) == 0, f"Got {len(spans)}, but expected no spans"

Expand Down Expand Up @@ -164,6 +258,25 @@ async def test_not_starting_streaming_creates_trace():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)

spans = fetch_ordered_spans()
assert len(spans) == 1, f"Got {len(spans)}, but expected 1: the agent span"

Expand Down
33 changes: 32 additions & 1 deletion tests/test_responses_tracing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytest
from inline_snapshot import snapshot
from openai import AsyncOpenAI
from openai.types.responses import ResponseCompletedEvent

from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
from agents.tracing.span_data import ResponseSpanData
from tests import fake_model

-from .testing_processor import fetch_ordered_spans
+from .testing_processor import fetch_normalized_spans, fetch_ordered_spans


class DummyTracing:
Expand Down Expand Up @@ -54,6 +55,15 @@ async def dummy_fetch_response(
"instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED
)

assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"children": [{"type": "response", "data": {"response_id": "dummy-id"}}],
}
]
)

spans = fetch_ordered_spans()
assert len(spans) == 1

Expand Down Expand Up @@ -82,6 +92,10 @@ async def dummy_fetch_response(
"instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED_WITHOUT_DATA
)

assert fetch_normalized_spans() == snapshot(
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)

spans = fetch_ordered_spans()
assert len(spans) == 1
assert spans[0].span_data.response is None
Expand All @@ -107,6 +121,8 @@ async def dummy_fetch_response(
"instr", "input", ModelSettings(), [], None, [], ModelTracing.DISABLED
)

assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])

spans = fetch_ordered_spans()
assert len(spans) == 0

Expand Down Expand Up @@ -139,6 +155,15 @@ async def __aiter__(self):
):
pass

assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"children": [{"type": "response", "data": {"response_id": "dummy-id-123"}}],
}
]
)

spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
Expand Down Expand Up @@ -174,6 +199,10 @@ async def __aiter__(self):
):
pass

assert fetch_normalized_spans() == snapshot(
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)

spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
Expand Down Expand Up @@ -208,5 +237,7 @@ async def __aiter__(self):
):
pass

assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])

spans = fetch_ordered_spans()
assert len(spans) == 0
Loading