diff --git a/python/packages/autogen-core/src/autogen_core/models/_types.py b/python/packages/autogen-core/src/autogen_core/models/_types.py
index fb118562e4d3..a3d6af1edde4 100644
--- a/python/packages/autogen-core/src/autogen_core/models/_types.py
+++ b/python/packages/autogen-core/src/autogen_core/models/_types.py
@@ -52,7 +52,7 @@ class RequestUsage:
     completion_tokens: int
 
 
-FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
+FinishReasons = Literal["stop", "length", "function_calls", "content_filter", "unknown"]
 
 
 @dataclass
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index 5b9f51129a88..ad9ce6e84712 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -30,6 +30,7 @@
     Image,
     MessageHandlerContext,
 )
+from autogen_core.models import FinishReasons
 from autogen_core.logging import LLMCallEvent
 from autogen_core.models import (
     AssistantMessage,
@@ -327,6 +328,21 @@ def assert_valid_name(name: str) -> str:
     return name
 
 
+def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
+    if stop_reason is None:
+        return "unknown"
+
+    # Convert to lower case
+    stop_reason = stop_reason.lower()
+
+    KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
+        "end_turn": "stop",
+        "tool_calls": "function_calls",
+    }
+
+    return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")
+
+
 class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     def __init__(
         self,
@@ -747,8 +763,8 @@ async def create_stream(
         else:
             prompt_tokens = 0
 
-        if stop_reason is None:
-            raise ValueError("No stop reason found")
+        if stop_reason == "function_call":
+            raise ValueError("Function calls are not supported in this context")
 
         content: Union[str, List[FunctionCall]]
         if len(content_deltas) > 1:
@@ -770,13 +786,9 @@ async def create_stream(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
         )
-        if stop_reason == "function_call":
-            raise ValueError("Function calls are not supported in this context")
-        if stop_reason == "tool_calls":
-            stop_reason = "function_calls"
 
         result = CreateResult(
-            finish_reason=stop_reason,  # type: ignore
+            finish_reason=normalize_stop_reason(stop_reason),
            content=content,
             usage=usage,
             cached=False,
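
For illustration only (not part of the diff): a minimal sketch of how the new normalize_stop_reason helper behaves, inferred solely from the mapping added above. The import path is assumed from the file touched by this diff.

# Illustrative sketch: expected behavior of normalize_stop_reason per this diff.
# The module path below is assumed from python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py.
from autogen_ext.models.openai._openai_client import normalize_stop_reason

assert normalize_stop_reason(None) == "unknown"                 # a missing stop reason no longer raises
assert normalize_stop_reason("END_TURN") == "stop"              # matching is case-insensitive
assert normalize_stop_reason("tool_calls") == "function_calls"  # tool calls map to the FinishReasons literal
assert normalize_stop_reason("not_a_reason") == "unknown"       # anything unmapped falls back to "unknown"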