Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pydantic
import websockets
from openai import AsyncOpenAI
from openai.types.realtime import realtime_audio_config as _rt_audio_config
from openai.types.realtime.conversation_item import (
ConversationItem,
Expand Down Expand Up @@ -81,6 +82,7 @@
from pydantic import Field, TypeAdapter
from typing_extensions import assert_never
from websockets.asyncio.client import ClientConnection
from websockets.typing import Subprotocol

from agents.handoffs import Handoff
from agents.prompts import Prompt
Expand Down Expand Up @@ -138,6 +140,7 @@


_USER_AGENT = f"Agents/Python {__version__}"
_SDK_CLIENT_META = f"openai-agents-sdk.python.{__version__}"

DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = {
"voice": "ash",
Expand Down Expand Up @@ -210,7 +213,6 @@ async def connect(self, options: RealtimeModelConfig) -> None:

self.model = model_settings.get("model_name", self.model)
api_key = await get_api_key(options.get("api_key"))

if "tracing" in model_settings:
self._tracing_config = model_settings["tracing"]
else:
Expand All @@ -219,24 +221,71 @@ async def connect(self, options: RealtimeModelConfig) -> None:
url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}")

headers: dict[str, str] = {}
if options.get("headers") is not None:
subprotocols: list[Subprotocol] = [
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a fundamental change for the issue, but having these is generally recommended and aligning with TS SDK.

Subprotocol("realtime"),
Subprotocol(_SDK_CLIENT_META),
]

custom_headers = options.get("headers")
if custom_headers is not None:
# For customizing request headers
headers.update(options["headers"])
headers.update(custom_headers)
else:
# OpenAI's Realtime API
if not api_key:
raise UserError("API key is required but was not provided.")

headers.update({"Authorization": f"Bearer {api_key}"})
ephemeral_key: str | None
if api_key.startswith("ek_"):
ephemeral_key = api_key
else:
ephemeral_key = await self._maybe_create_client_secret(api_key, self.model)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

creating an ephemeral key for establishing a WS connection is the recommended way while both still work in terms of realtime capabilities. Going with an ephemeral key resolves the tracing feature issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using an ephemeral key for server <-> server communication seems like unnecessary overhead to me.


if ephemeral_key:
subprotocols = [
Subprotocol("realtime"),
Subprotocol(f"openai-insecure-api-key.{ephemeral_key}"),
Subprotocol(_SDK_CLIENT_META),
]
else:
headers["Authorization"] = f"Bearer {api_key}"

self._websocket = await websockets.connect(
url,
user_agent_header=_USER_AGENT,
additional_headers=headers,
subprotocols=tuple(subprotocols),
max_size=None, # Allow any size of message
)
self._websocket_task = asyncio.create_task(self._listen_for_messages())
await self._update_session_config(model_settings)

async def _maybe_create_client_secret(self, api_key: str, model_name: str) -> str | None:
try:
return await self._create_client_secret(api_key, model_name)
except Exception as exc:
logger.warning(
"Failed to create realtime client secret; using API key directly: %s",
exc,
)
return None

async def _create_client_secret(self, api_key: str, model_name: str) -> str:
client = AsyncOpenAI(api_key=api_key)
try:
secret = await client.realtime.client_secrets.create(
session={"type": "realtime", "model": model_name}
)
finally:
await client.close()

value = secret.value if isinstance(getattr(secret, "value", None), str) else None

if value is None:
raise UserError("Realtime client secret response did not include a value.")

return value

async def _send_tracing_config(
self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
) -> None:
Expand Down
109 changes: 107 additions & 2 deletions tests/realtime/test_tracing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import cast
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import AsyncMock, Mock, patch

import pytest
Expand All @@ -8,11 +9,40 @@
from openai.types.realtime.realtime_tracing_config import TracingConfiguration

from agents.realtime.agent import RealtimeAgent
from agents.realtime.model import RealtimeModel
from agents.realtime.model import RealtimeModel, RealtimeModelConfig
from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel
from agents.realtime.session import RealtimeSession


@pytest.fixture(autouse=True)
def mock_client_secret_request(monkeypatch):
records: dict[str, list[dict[str, Any]]] = {"init_kwargs": [], "sessions": []}

class DummySecrets:
async def create(self, *, session: dict[str, Any]) -> SimpleNamespace:
records["sessions"].append(session)
return SimpleNamespace(value="ek_test")

class DummyRealtime:
def __init__(self):
self.client_secrets = DummySecrets()

class DummyClient:
def __init__(self, *args, **kwargs):
records["init_kwargs"].append(kwargs)
self.realtime = DummyRealtime()

async def close(self) -> None:
return None

monkeypatch.setattr(
"agents.realtime.openai_realtime.AsyncOpenAI",
DummyClient,
)

return records


class TestRealtimeTracingIntegration:
"""Test tracing configuration and session.update integration."""

Expand Down Expand Up @@ -62,6 +92,7 @@ async def async_websocket(*args, **kwargs):
"metadata": {"version": "1.0"},
}


# Test without tracing config - should default to "auto"
model2 = OpenAIRealtimeWebSocketModel()
config_no_tracing = {
Expand Down Expand Up @@ -251,3 +282,77 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket):

# When tracing is disabled, model settings should have tracing=None
assert model_settings["tracing"] is None

@pytest.mark.asyncio
async def test_connect_sets_sdk_headers_and_subprotocols(
self,
mock_websocket,
mock_client_secret_request,
):
"""Ensure websocket handshake mirrors Agents JS with client secrets."""
model = OpenAIRealtimeWebSocketModel()
config: RealtimeModelConfig = {
"api_key": "sk-test",
"initial_model_settings": {},
}

captured_kwargs: dict[str, Any] = {}

async def async_websocket(*args, **kwargs):
captured_kwargs.update(kwargs)
return mock_websocket

with patch("websockets.connect", side_effect=async_websocket):
with patch("asyncio.create_task") as mock_create_task:
mock_task = AsyncMock()
mock_create_task.return_value = mock_task
mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1]

await model.connect(config)

headers = captured_kwargs["additional_headers"]
assert "Authorization" not in headers

subprotocols = captured_kwargs["subprotocols"]
assert subprotocols[0] == "realtime"
assert subprotocols[1].startswith("openai-insecure-api-key.ek_test")
assert subprotocols[2].startswith("openai-agents-sdk.python.")
# Ensure client secret API was called once
assert mock_client_secret_request["init_kwargs"] == [{"api_key": "sk-test"}]
assert mock_client_secret_request["sessions"] == [
{"type": "realtime", "model": "gpt-realtime"}
]

@pytest.mark.asyncio
async def test_connect_with_ephemeral_key_skips_client_secret(
self,
mock_websocket,
mock_client_secret_request,
):
"""Ensure pre-generated ek_ keys are used directly without calling the API."""
model = OpenAIRealtimeWebSocketModel()
config: RealtimeModelConfig = {
"api_key": "ek_existing",
"initial_model_settings": {},
}

captured_kwargs: dict[str, Any] = {}

async def async_websocket(*args, **kwargs):
captured_kwargs.update(kwargs)
return mock_websocket

with patch("websockets.connect", side_effect=async_websocket):
with patch("asyncio.create_task") as mock_create_task:
mock_task = AsyncMock()
mock_create_task.return_value = mock_task
mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1]

await model.connect(config)

# No client secret API calls should have been made
assert mock_client_secret_request["init_kwargs"] == []
assert mock_client_secret_request["sessions"] == []

subprotocols = captured_kwargs["subprotocols"]
assert subprotocols[1] == "openai-insecure-api-key.ek_existing"