68 changes: 44 additions & 24 deletions litellm/responses/main.py
@@ -1,9 +1,10 @@
import asyncio
import contextvars
from functools import partial
from typing import Any, Coroutine, Dict, Iterable, List, Literal, Optional, Union
from typing import Any, Coroutine, Dict, Iterable, List, Literal, Optional, Type, Union

import httpx
from pydantic import BaseModel

import litellm
from litellm.constants import request_timeout
@@ -135,9 +136,10 @@ async def aresponses_api_with_mcp(
)

# Parse MCP tools and separate from other tools
mcp_tools_with_litellm_proxy, other_tools = (
LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)
)
(
mcp_tools_with_litellm_proxy,
other_tools,
) = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)

# Get available tools from MCP manager if we have MCP tools
openai_tools = []
@@ -254,6 +256,7 @@ async def aresponses(
stream: Optional[bool] = None,
temperature: Optional[float] = None,
text: Optional["ResponseText"] = None,
text_format: Optional[Union[Type["BaseModel"], dict]] = None,
tool_choice: Optional[ToolChoice] = None,
tools: Optional[Iterable[ToolParam]] = None,
top_p: Optional[float] = None,
@@ -279,6 +282,14 @@ async def aresponses(
loop = asyncio.get_event_loop()
kwargs["aresponses"] = True

# Convert text_format to text parameter if provided
text = ResponsesAPIRequestUtils.convert_text_format_to_text_param(
text_format=text_format, text=text
)
if text is not None:
# Update local_vars to include the converted text parameter
local_vars["text"] = text

# get custom llm provider so we can use this for mapping exceptions
if custom_llm_provider is None:
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
@@ -367,6 +378,7 @@ def responses(
stream: Optional[bool] = None,
temperature: Optional[float] = None,
text: Optional["ResponseText"] = None,
text_format: Optional[Union[Type["BaseModel"], dict]] = None,
tool_choice: Optional[ToolChoice] = None,
tools: Optional[Iterable[ToolParam]] = None,
top_p: Optional[float] = None,
@@ -399,6 +411,14 @@ def responses(
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
_is_async = kwargs.pop("aresponses", False) is True

# Convert text_format to text parameter if provided
text = ResponsesAPIRequestUtils.convert_text_format_to_text_param(
text_format=text_format, text=text
)
if text is not None:
# Update local_vars to include the converted text parameter
local_vars["text"] = text

# get llm provider logic
litellm_params = GenericLiteLLMParams(**kwargs)

@@ -432,11 +452,11 @@ def responses(
)

# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)

local_vars.update(kwargs)
@@ -628,11 +648,11 @@ def delete_responses(
raise ValueError("custom_llm_provider is required but passed as None")

# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)

if responses_api_provider_config is None:
@@ -807,11 +827,11 @@ def get_responses(
raise ValueError("custom_llm_provider is required but passed as None")

# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)

if responses_api_provider_config is None:
@@ -963,11 +983,11 @@ def list_input_items(
if custom_llm_provider is None:
raise ValueError("custom_llm_provider is required but passed as None")

responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)

if responses_api_provider_config is None:
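Taken together, the main.py changes add a text_format keyword to both responses() and aresponses() and convert it into the text payload before the provider config is resolved. A minimal usage sketch of the new parameter, assuming an OpenAI key is configured in the environment; the Answer model and prompt are illustrative, not part of this PR:

import litellm
from pydantic import BaseModel


class Answer(BaseModel):
    # Hypothetical structured-output model, for illustration only
    city: str
    confidence: float


# text_format is converted internally into text={"format": {...}}; if an
# explicit text= argument is passed, text_format is ignored.
response = litellm.responses(
    model="gpt-4o",
    input="What is the capital of France?",
    text_format=Answer,
)
print(response.output)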
50 changes: 48 additions & 2 deletions litellm/responses/utils.py
@@ -1,5 +1,17 @@
import base64
from typing import Any, Dict, List, Optional, Union, cast, get_type_hints, overload
from typing import (
Any,
Dict,
List,
Optional,
Type,
Union,
cast,
get_type_hints,
overload,
)

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
@@ -8,6 +20,7 @@
ResponseAPIUsage,
ResponsesAPIOptionalRequestParams,
ResponsesAPIResponse,
ResponseText,
)
from litellm.types.responses.main import DecodedResponseId
from litellm.types.utils import SpecialEnums, Usage
Expand All @@ -24,7 +37,6 @@ def _check_valid_arg(
custom_llm_provider: Optional[str],
model: str,
):

if supported_params is None:
return
unsupported_params = {}
@@ -302,6 +314,40 @@ def decode_previous_response_id_to_original_previous_response_id(
)
return decoded_response_id.get("response_id", previous_response_id)

@staticmethod
def convert_text_format_to_text_param(
text_format: Optional[Union[Type["BaseModel"], dict]],
text: Optional["ResponseText"] = None,
) -> Optional["ResponseText"]:
"""
Convert text_format parameter to text parameter for the responses API.

Args:
text_format: Pydantic model class or dict to convert to response format
text: Existing text parameter (if provided, text_format is ignored)

Returns:
ResponseText object with the converted format, or None if conversion fails
"""
if text_format is not None and text is None:
from litellm.llms.base_llm.base_utils import type_to_response_format_param

# Convert Pydantic model to response format
response_format = type_to_response_format_param(text_format)
if response_format is not None:
# Create ResponseText object with the format
# The responses API expects the format to have name at the top level
text = {
"format": {
"type": response_format["type"],
"name": response_format["json_schema"]["name"],
"schema": response_format["json_schema"]["schema"],
"strict": response_format["json_schema"]["strict"],
}
}

Review comment: Bug: Code Assumes Fixed Dictionary Structure

In convert_text_format_to_text_param, the code directly accesses nested keys within response_format (e.g., response_format["json_schema"]["name"], response_format["type"]). This assumes a specific structure for the dictionary returned by type_to_response_format_param. If the actual structure doesn't match, it will raise a KeyError.

Reply from the PR author:

type_to_response_format_param returns

    {
        "type": "json_schema",
        "json_schema": {
            "schema": schema,
            "name": response_format.name,
            "strict": True,
        },
    }

so the keys will always be there.

return text
return text


class ResponseAPILoggingUtils:
@staticmethod
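To make the shapes discussed in the review thread above concrete, here is a rough, self-contained sketch of calling the new helper directly. It assumes ResponsesAPIRequestUtils lives in litellm.responses.utils (as this diff suggests) and that the emitted schema follows the shape quoted in the author's reply; treat it as illustration rather than a contract:

from pydantic import BaseModel

from litellm.responses.utils import ResponsesAPIRequestUtils


class TestResponse(BaseModel):
    answer: str
    confidence: float


# With text=None (the default), the Pydantic class is run through
# type_to_response_format_param and re-nested under "format".
text_param = ResponsesAPIRequestUtils.convert_text_format_to_text_param(
    text_format=TestResponse,
)

# Expected rough shape:
# {"format": {"type": "json_schema", "name": "TestResponse",
#             "schema": {...}, "strict": True}}
print(text_param["format"]["type"])   # json_schema
print(text_param["format"]["name"])   # TestResponse

# If an explicit text value is already provided, text_format is ignored
# and the existing value is returned unchanged.
existing = {"format": {"type": "json_object"}}
assert ResponsesAPIRequestUtils.convert_text_format_to_text_param(
    text_format=TestResponse, text=existing
) is existing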
1 change: 0 additions & 1 deletion tests/llm_responses_api_testing/base_responses_api.py
@@ -590,4 +590,3 @@ def test_openai_responses_api_dict_input_filtering(self):
assert function_call_item["status"] == "completed", "status value should be preserved"

print("✅ OpenAI Responses API dict input filtering test passed")

@@ -667,5 +667,4 @@ def test_get_supported_openai_params():
assert "temperature" in params
assert "stream" in params
assert "background" in params
assert "stream" in params

assert "stream" in params
161 changes: 161 additions & 0 deletions tests/test_litellm/responses/test_text_format_conversion.py
@@ -0,0 +1,161 @@
import json
import os
import sys

import pytest
from pydantic import BaseModel

sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path

import litellm
from litellm.types.llms.openai import (
IncompleteDetails,
ResponseAPIUsage,
ResponsesAPIResponse,
)


class TestTextFormatConversion:
"""Test text_format to text parameter conversion for responses API"""

def get_base_completion_call_args(self):
"""Get base arguments for completion call"""
return {
"model": "gpt-4o",
"api_key": "test-key",
"api_base": "https://api.openai.com/v1",
}

@pytest.mark.asyncio
async def test_text_format_to_text_conversion(self):
"""
Test that when text_format parameter is passed to litellm.aresponses,
it gets converted to text parameter in the raw API call to OpenAI.
"""
from unittest.mock import AsyncMock, patch

class TestResponse(BaseModel):
"""Test Pydantic model for structured output"""

answer: str
confidence: float

class MockResponse:
"""Mock response class for testing"""

def __init__(self, json_data, status_code):
self._json_data = json_data
self.status_code = status_code
self.text = json.dumps(json_data)

def json(self):
return self._json_data

# Mock response from OpenAI
mock_response = {
"id": "resp_123",
"object": "response",
"created_at": 1741476542,
"status": "completed",
"model": "gpt-4o",
"output": [
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{
"type": "output_text",
"text": '{"answer": "Paris", "confidence": 0.95}',
"annotations": [],
}
],
}
],
"parallel_tool_calls": True,
"usage": {
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30,
"output_tokens_details": {"reasoning_tokens": 0},
},
"text": {"format": {"type": "json_object"}},
"error": None,
"incomplete_details": None,
"instructions": None,
"metadata": {},
"temperature": 1.0,
"tool_choice": "auto",
"tools": [],
"top_p": 1.0,
"max_output_tokens": None,
"previous_response_id": None,
"reasoning": {"effort": None, "summary": None},
"truncation": "disabled",
"user": None,
}

base_completion_call_args = self.get_base_completion_call_args()

with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
new_callable=AsyncMock,
) as mock_post:
# Configure the mock to return our response
mock_post.return_value = MockResponse(mock_response, 200)

litellm._turn_on_debug()
litellm.set_verbose = True

# Call aresponses with text_format parameter
response = await litellm.aresponses(
input="What is the capital of France?",
text_format=TestResponse,
**base_completion_call_args,
)

# Verify the request was made correctly
mock_post.assert_called_once()
request_body = mock_post.call_args.kwargs["json"]
print("Request body:", json.dumps(request_body, indent=4))

# Validate that text_format was converted to text parameter
assert (
"text" in request_body
), "text parameter should be present in request body"
assert (
"text_format" not in request_body
), "text_format should not be in request body"

# Validate the text parameter structure
text_param = request_body["text"]
assert "format" in text_param, "text parameter should have format field"
assert (
text_param["format"]["type"] == "json_schema"
), "format type should be json_schema"
assert "name" in text_param["format"], "format should have name field"
assert (
text_param["format"]["name"] == "TestResponse"
), "format name should match Pydantic model name"
assert "schema" in text_param["format"], "format should have schema field"
assert "strict" in text_param["format"], "format should have strict field"

# Validate the schema structure
schema = text_param["format"]["schema"]
assert schema["type"] == "object", "schema type should be object"
assert "properties" in schema, "schema should have properties"
assert (
"answer" in schema["properties"]
), "schema should have answer property"
assert (
"confidence" in schema["properties"]
), "schema should have confidence property"

# Validate other request parameters
assert request_body["input"] == "What is the capital of France?"

# Validate the response
print("Response:", json.dumps(response, indent=4, default=str))