Add logprobs #971

Open · wants to merge 2 commits into base `main`
16 changes: 15 additions & 1 deletion docs/ja/models/index.md
@@ -103,7 +103,7 @@ When using the OpenAI Responses API, a few optional parameters such as `user` and `service_tier` …

```python
from agents import Agent, ModelSettings

english_agent = Agent(
    name="English agent",
    instructions="You only speak English",
    model="gpt-4o",
    # ... (context lines collapsed in the diff view)
)
```

@@ -114,6 +114,20 @@

To get token log probabilities back from the Responses API, set
`top_logprobs` on `ModelSettings`.

```python
from agents import Agent, ModelSettings

agent = Agent(
    name="English agent",
    instructions="You only speak English",
    model="gpt-4o",
    model_settings=ModelSettings(top_logprobs=2),
)
```

## Common issues with using other LLM providers

### Tracing client error 401
14 changes: 14 additions & 0 deletions docs/models/index.md
@@ -109,6 +109,20 @@

```python
english_agent = Agent(
    # ... (context lines collapsed in the diff view)
)
```

You can also request token log probabilities when using the Responses API by
setting `top_logprobs` in `ModelSettings`.

```python
from agents import Agent, ModelSettings

agent = Agent(
    name="English agent",
    instructions="You only speak English",
    model="gpt-4o",
    model_settings=ModelSettings(top_logprobs=2),
)
```

## Common issues with using other LLM providers

### Tracing client error 401
17 changes: 11 additions & 6 deletions src/agents/model_settings.py
@@ -17,9 +17,9 @@
class _OmitTypeAnnotation:
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        def validate_from_none(value: None) -> _Omit:
            return _Omit()
@@ -39,13 +39,14 @@ def validate_from_none(value: None) -> _Omit:
                    from_none_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
        )


Omit = Annotated[_Omit, _OmitTypeAnnotation]
Headers: TypeAlias = Mapping[str, Union[str, Omit]]


@dataclass
class ModelSettings:
"""Settings to use when calling an LLM.
@@ -107,6 +108,10 @@ class ModelSettings:
"""Additional output data to include in the model response.
[include parameter](https://platform.openai.com/docs/api-reference/responses/create#responses-create-include)"""

top_logprobs: int | None = None
"""Number of top tokens to return logprobs for. Setting this will
automatically include ``"message.output_text.logprobs"`` in the response."""

extra_query: Query | None = None
"""Additional query fields to provide with the request.
Defaults to None if not provided."""
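For context on how the new field composes with the existing ones, here is a minimal sketch, not part of the diff, using `ModelSettings.resolve`; the merge semantics shown are an assumption based on the `resolve` tests later in this PR (fields left as `None` on the override fall back to the base settings):

```python
from agents import ModelSettings

base = ModelSettings(temperature=0.5)
override = ModelSettings(top_logprobs=2)

# resolve() merges two settings objects; the assumption here is that
# None fields on the override inherit the base value.
resolved = base.resolve(override)
assert resolved.temperature == 0.5  # kept from the base settings
assert resolved.top_logprobs == 2   # taken from the override
```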
15 changes: 11 additions & 4 deletions src/agents/models/openai_responses.py
@@ -3,7 +3,7 @@
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast, overload

from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
from openai.types import ChatModel
@@ -246,9 +246,12 @@ async def _fetch_response(
        converted_tools = Converter.convert_tools(tools, handoffs)
        response_format = Converter.get_response_format(output_schema)

        include_set: set[str] = set(converted_tools.includes)
        if model_settings.response_include is not None:
            include_set.update(model_settings.response_include)
        if model_settings.top_logprobs is not None:
            include_set.add("message.output_text.logprobs")
        include = cast(list[ResponseIncludable], list(include_set))

        if _debug.DONT_LOG_MODEL_DATA:
            logger.debug("Calling LLM")
@@ -263,6 +266,10 @@
f"Previous response id: {previous_response_id}\n"
)

        extra_args = dict(model_settings.extra_args or {})
        if model_settings.top_logprobs is not None:
            extra_args["top_logprobs"] = model_settings.top_logprobs

        return await self._client.responses.create(
            previous_response_id=self._non_null_or_not_given(previous_response_id),
            instructions=self._non_null_or_not_given(system_instructions),
@@ -285,7 +292,7 @@
            store=self._non_null_or_not_given(model_settings.store),
            reasoning=self._non_null_or_not_given(model_settings.reasoning),
            metadata=self._non_null_or_not_given(model_settings.metadata),
            **extra_args,
        )

    def _get_client(self) -> AsyncOpenAI:
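As a standalone illustration of the include-merging logic in `_fetch_response` above, here is the same set union in plain Python; the `file_search_call.results` value is a hypothetical converter-derived include, while the other two literals come from this PR:

```python
# Sketch of the merge: converter-derived includes, user-requested includes,
# and the logprobs include are unioned, so duplicates collapse to one entry.
converted_includes = ["file_search_call.results"]    # stand-in for Converter.convert_tools output
response_include = ["reasoning.encrypted_content"]   # from ModelSettings.response_include
top_logprobs = 2                                     # from ModelSettings.top_logprobs

include_set: set[str] = set(converted_includes)
include_set.update(response_include)
if top_logprobs is not None:
    include_set.add("message.output_text.logprobs")

print(sorted(include_set))
# ['file_search_call.results', 'message.output_text.logprobs', 'reasoning.encrypted_content']
```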
4 changes: 3 additions & 1 deletion tests/model_settings/test_serialization.py
@@ -47,6 +47,7 @@ def test_all_fields_serialization() -> None:
        store=False,
        include_usage=False,
        response_include=["reasoning.encrypted_content"],
        top_logprobs=1,
        extra_query={"foo": "bar"},
        extra_body={"foo": "bar"},
        extra_headers={"foo": "bar"},
@@ -135,8 +136,8 @@ def test_extra_args_resolve_both_none() -> None:
    assert resolved.temperature == 0.5
    assert resolved.top_p == 0.9

def test_pydantic_serialization() -> None:
    """Tests whether ModelSettings can be serialized with Pydantic."""

    # First, let's create a ModelSettings instance
@@ -153,6 +154,7 @@ def test_pydantic_serialization() -> None:
        metadata={"foo": "bar"},
        store=False,
        include_usage=False,
        top_logprobs=1,
        extra_query={"foo": "bar"},
        extra_body={"foo": "bar"},
        extra_headers={"foo": "bar"},
50 changes: 50 additions & 0 deletions tests/test_logprobs.py
@@ -0,0 +1,50 @@
import pytest
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from agents import ModelSettings, ModelTracing, OpenAIResponsesModel


class DummyResponses:
    async def create(self, **kwargs):
        self.kwargs = kwargs

        class DummyResponse:
            id = "dummy"
            output = []
            usage = type(
                "Usage",
                (),
                {
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0,
                    "input_tokens_details": InputTokensDetails(cached_tokens=0),
                    "output_tokens_details": OutputTokensDetails(reasoning_tokens=0),
                },
            )()

        return DummyResponse()


class DummyClient:
    def __init__(self):
        self.responses = DummyResponses()


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_top_logprobs_param_passed():
    client = DummyClient()
    model = OpenAIResponsesModel(model="gpt-4", openai_client=client)  # type: ignore
    await model.get_response(
        system_instructions=None,
        input="hi",
        model_settings=ModelSettings(top_logprobs=2),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )
    assert client.responses.kwargs["top_logprobs"] == 2
    assert "message.output_text.logprobs" in client.responses.kwargs["include"]