Skip to content

Commit afdd0aa

Browse files
authored
[PY] fix: Map tokens config to max_tokens when non-o1 model is used. (#2151)
## Linked issues closes: #minor ## Details Using the `max_completion_tokens` field is not supported for non-o1 models in Azure OpenAI. #### Change details * In `openai_model.py` if the model is not in the o1 series, then use `max_tokens` field by default. ## Attestation Checklist - [x] My code follows the style guidelines of this project - I have checked for/fixed spelling, linting, and other errors - I have commented my code for clarity - I have made corresponding changes to the documentation (updating the doc strings in the code is sufficient) - My changes generate no new warnings - I have added tests that validate my changes and provide sufficient test coverage. I have tested with: - Local testing - E2E testing in Teams - New and existing unit tests pass locally with my changes
1 parent 4b009fb commit afdd0aa

File tree

4 files changed

+154
-79
lines changed

4 files changed

+154
-79
lines changed

python/packages/ai/teams/ai/models/openai_model.py

Lines changed: 72 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -225,77 +225,15 @@ async def complete_prompt(
225225
if self._options.logger is not None:
226226
self._options.logger.debug(f"PROMPT:\n{res.output}")
227227

228-
messages: List[chat.ChatCompletionMessageParam] = []
229-
230-
for msg in res.output:
231-
param: Union[
232-
chat.ChatCompletionUserMessageParam,
233-
chat.ChatCompletionAssistantMessageParam,
234-
chat.ChatCompletionSystemMessageParam,
235-
chat.ChatCompletionToolMessageParam,
236-
] = chat.ChatCompletionUserMessageParam(
237-
role="user",
238-
content=msg.content if msg.content is not None else "",
239-
)
240-
241-
if msg.name:
242-
setattr(param, "name", msg.name)
243-
244-
if msg.role == "assistant":
245-
param = chat.ChatCompletionAssistantMessageParam(
246-
role="assistant",
247-
content=msg.content if msg.content is not None else "",
248-
)
249-
250-
tool_call_params: List[chat.ChatCompletionMessageToolCallParam] = []
251-
252-
if msg.action_calls and len(msg.action_calls) > 0:
253-
for tool_call in msg.action_calls:
254-
tool_call_params.append(
255-
chat.ChatCompletionMessageToolCallParam(
256-
id=tool_call.id,
257-
function=Function(
258-
name=tool_call.function.name,
259-
arguments=tool_call.function.arguments,
260-
),
261-
type=tool_call.type,
262-
)
263-
)
264-
param["content"] = None
265-
param["tool_calls"] = tool_call_params
266-
267-
if msg.name:
268-
param["name"] = msg.name
269-
270-
elif msg.role == "tool":
271-
param = chat.ChatCompletionToolMessageParam(
272-
role="tool",
273-
tool_call_id=msg.action_call_id if msg.action_call_id else "",
274-
content=msg.content if msg.content else "",
275-
)
276-
elif msg.role == "system":
277-
# o1 models do not support system messages
278-
if is_o1_model:
279-
param = chat.ChatCompletionUserMessageParam(
280-
role="user",
281-
content=msg.content if msg.content is not None else "",
282-
)
283-
else:
284-
param = chat.ChatCompletionSystemMessageParam(
285-
role="system",
286-
content=msg.content if msg.content is not None else "",
287-
)
288-
289-
if msg.name:
290-
param["name"] = msg.name
291-
292-
messages.append(param)
228+
messages: List[chat.ChatCompletionMessageParam]
229+
messages = self._map_messages(res.output, is_o1_model)
293230

294231
try:
295232
extra_body = {}
296233
if template.config.completion.data_sources is not None:
297234
extra_body["data_sources"] = template.config.completion.data_sources
298235

236+
max_tokens = template.config.completion.max_tokens
299237
completion = await self._client.chat.completions.create(
300238
messages=messages,
301239
model=model,
@@ -305,7 +243,8 @@ async def complete_prompt(
305243
frequency_penalty=template.config.completion.frequency_penalty,
306244
top_p=template.config.completion.top_p if not is_o1_model else 1,
307245
temperature=template.config.completion.temperature if not is_o1_model else 1,
308-
max_completion_tokens=template.config.completion.max_tokens,
246+
max_tokens=max_tokens if not is_o1_model else NOT_GIVEN,
247+
max_completion_tokens=max_tokens if is_o1_model else NOT_GIVEN,
309248
tools=tools if len(tools) > 0 else NOT_GIVEN,
310249
tool_choice=tool_choice if len(tools) > 0 else NOT_GIVEN,
311250
parallel_tool_calls=parallel_tool_calls if len(tools) > 0 else NOT_GIVEN,
@@ -436,3 +375,70 @@ async def complete_prompt(
436375
status of {err.code}: {err.message}
437376
""",
438377
)
378+
379+
def _map_messages(self, msgs: List[Message], is_o1_model: bool):
380+
output = []
381+
for msg in msgs:
382+
param: Union[
383+
chat.ChatCompletionUserMessageParam,
384+
chat.ChatCompletionAssistantMessageParam,
385+
chat.ChatCompletionSystemMessageParam,
386+
chat.ChatCompletionToolMessageParam,
387+
] = chat.ChatCompletionUserMessageParam(
388+
role="user",
389+
content=msg.content if msg.content is not None else "",
390+
)
391+
392+
if msg.name:
393+
setattr(param, "name", msg.name)
394+
395+
if msg.role == "assistant":
396+
param = chat.ChatCompletionAssistantMessageParam(
397+
role="assistant",
398+
content=msg.content if msg.content is not None else "",
399+
)
400+
401+
tool_call_params: List[chat.ChatCompletionMessageToolCallParam] = []
402+
403+
if msg.action_calls and len(msg.action_calls) > 0:
404+
for tool_call in msg.action_calls:
405+
tool_call_params.append(
406+
chat.ChatCompletionMessageToolCallParam(
407+
id=tool_call.id,
408+
function=Function(
409+
name=tool_call.function.name,
410+
arguments=tool_call.function.arguments,
411+
),
412+
type=tool_call.type,
413+
)
414+
)
415+
param["content"] = None
416+
param["tool_calls"] = tool_call_params
417+
418+
if msg.name:
419+
param["name"] = msg.name
420+
421+
elif msg.role == "tool":
422+
param = chat.ChatCompletionToolMessageParam(
423+
role="tool",
424+
tool_call_id=msg.action_call_id if msg.action_call_id else "",
425+
content=msg.content if msg.content else "",
426+
)
427+
elif msg.role == "system":
428+
# o1 models do not support system messages
429+
if is_o1_model:
430+
param = chat.ChatCompletionUserMessageParam(
431+
role="user",
432+
content=msg.content if msg.content is not None else "",
433+
)
434+
else:
435+
param = chat.ChatCompletionSystemMessageParam(
436+
role="system",
437+
content=msg.content if msg.content is not None else "",
438+
)
439+
440+
if msg.name:
441+
param["name"] = msg.name
442+
443+
output.append(param)
444+
return output

python/packages/ai/tests/ai/models/test_openai_model.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,18 @@ class MockAsyncCompletions:
100100
should_error = False
101101
has_tool_call = False
102102
has_tool_calls = False
103-
is_o1_model = False
104-
messages = []
103+
create_params = None
105104

106105
def __init__(
107-
self, should_error=False, has_tool_call=False, has_tool_calls=False, is_o1_model=False
106+
self, should_error=False, has_tool_call=False, has_tool_calls=False
108107
) -> None:
109108
self.should_error = should_error
110109
self.has_tool_call = has_tool_call
111110
self.has_tool_calls = has_tool_calls
112-
self.is_o1_model = is_o1_model
113-
self.messages = []
114111

115112
async def create(self, **kwargs) -> chat.ChatCompletion:
113+
self.create_params = kwargs
114+
116115
if self.should_error:
117116
raise openai.BadRequestError(
118117
"bad request",
@@ -126,9 +125,6 @@ async def create(self, **kwargs) -> chat.ChatCompletion:
126125
if self.has_tool_calls:
127126
return await self.handle_tool_calls(**kwargs)
128127

129-
if self.is_o1_model:
130-
self.messages = kwargs["messages"]
131-
132128
return chat.ChatCompletion(
133129
id="",
134130
choices=[
@@ -294,7 +290,6 @@ async def test_should_be_success(self, mock_async_openai):
294290

295291
@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
296292
async def test_o1_model_should_use_user_message_over_system_message(self, mock_async_openai):
297-
mock_async_openai.return_value.chat.completions.is_o1_model = True
298293
context = self.create_mock_context()
299294
state = TurnState()
300295
state.temp = {}
@@ -319,8 +314,81 @@ async def test_o1_model_should_use_user_message_over_system_message(self, mock_a
319314

320315
self.assertTrue(mock_async_openai.called)
321316
self.assertEqual(res.status, "success")
317+
create_params = mock_async_openai.return_value.chat.completions.create_params
318+
self.assertEqual(
319+
create_params["messages"][0]["role"], "user"
320+
)
321+
322+
@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
323+
async def test_o1_model_should_use_max_completion_tokens_param(self, mock_async_openai):
324+
context = self.create_mock_context()
325+
state = TurnState()
326+
state.temp = {}
327+
state.conversation = {}
328+
model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="o1-"))
329+
completion = CompletionConfig(completion_type="chat")
330+
completion.max_tokens = 1000
331+
res = await model.complete_prompt(
332+
context=context,
333+
memory=state,
334+
functions=cast(PromptFunctions, {}),
335+
tokenizer=GPTTokenizer(),
336+
template=PromptTemplate(
337+
name="default",
338+
prompt=Prompt(sections=[TemplateSection("prompt text", "system")]),
339+
config=PromptTemplateConfig(
340+
schema=1.0,
341+
type="completion",
342+
description="test",
343+
completion=completion,
344+
),
345+
),
346+
)
347+
348+
self.assertTrue(mock_async_openai.called)
349+
self.assertEqual(res.status, "success")
350+
create_params = mock_async_openai.return_value.chat.completions.create_params
351+
self.assertEqual(
352+
create_params["max_completion_tokens"], 1000
353+
)
354+
self.assertEqual(
355+
create_params["max_tokens"], openai.NOT_GIVEN
356+
)
357+
358+
@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
359+
async def test_non_o1_model_should_use_max_tokens_param(self, mock_async_openai):
360+
context = self.create_mock_context()
361+
state = TurnState()
362+
state.temp = {}
363+
state.conversation = {}
364+
model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="non-o1"))
365+
completion = CompletionConfig(completion_type="chat")
366+
completion.max_tokens = 1000
367+
res = await model.complete_prompt(
368+
context=context,
369+
memory=state,
370+
functions=cast(PromptFunctions, {}),
371+
tokenizer=GPTTokenizer(),
372+
template=PromptTemplate(
373+
name="default",
374+
prompt=Prompt(sections=[TemplateSection("prompt text", "system")]),
375+
config=PromptTemplateConfig(
376+
schema=1.0,
377+
type="completion",
378+
description="test",
379+
completion=completion,
380+
),
381+
),
382+
)
383+
384+
self.assertTrue(mock_async_openai.called)
385+
self.assertEqual(res.status, "success")
386+
create_params = mock_async_openai.return_value.chat.completions.create_params
387+
self.assertEqual(
388+
create_params["max_tokens"], 1000
389+
)
322390
self.assertEqual(
323-
mock_async_openai.return_value.chat.completions.messages[0]["role"], "user"
391+
create_params["max_completion_tokens"], openai.NOT_GIVEN
324392
)
325393

326394
@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)

python/samples/04.ai.a.twentyQuestions/src/bot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@
4141

4242
if config.OPENAI_KEY:
4343
model = OpenAIModel(
44-
OpenAIModelOptions(api_key=config.OPENAI_KEY, default_model="gpt-3.5-turbo")
44+
OpenAIModelOptions(api_key=config.OPENAI_KEY, default_model="gpt-4o")
4545
)
4646
elif config.AZURE_OPENAI_KEY and config.AZURE_OPENAI_ENDPOINT:
4747
model = OpenAIModel(
4848
AzureOpenAIModelOptions(
4949
api_key=config.AZURE_OPENAI_KEY,
50-
default_model="gpt-35-turbo",
51-
api_version="2023-03-15-preview",
50+
default_model="gpt-4o",
51+
api_version="2024-08-01-preview",
5252
endpoint=config.AZURE_OPENAI_ENDPOINT,
5353
)
5454
)

python/samples/04.ai.a.twentyQuestions/teamsapp.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,4 @@ deploy:
9797
# You can replace it with your existing Azure Resource id
9898
# or add it to your environment variable file.
9999
resourceId: ${{BOT_AZURE_APP_SERVICE_RESOURCE_ID}}
100+
projectId: 38b5ad68-9f64-41b8-a503-4d3200655664

0 commit comments

Comments (0)