@@ -225,77 +225,15 @@ async def complete_prompt(
225
225
if self ._options .logger is not None :
226
226
self ._options .logger .debug (f"PROMPT:\n { res .output } " )
227
227
228
- messages : List [chat .ChatCompletionMessageParam ] = []
229
-
230
- for msg in res .output :
231
- param : Union [
232
- chat .ChatCompletionUserMessageParam ,
233
- chat .ChatCompletionAssistantMessageParam ,
234
- chat .ChatCompletionSystemMessageParam ,
235
- chat .ChatCompletionToolMessageParam ,
236
- ] = chat .ChatCompletionUserMessageParam (
237
- role = "user" ,
238
- content = msg .content if msg .content is not None else "" ,
239
- )
240
-
241
- if msg .name :
242
- setattr (param , "name" , msg .name )
243
-
244
- if msg .role == "assistant" :
245
- param = chat .ChatCompletionAssistantMessageParam (
246
- role = "assistant" ,
247
- content = msg .content if msg .content is not None else "" ,
248
- )
249
-
250
- tool_call_params : List [chat .ChatCompletionMessageToolCallParam ] = []
251
-
252
- if msg .action_calls and len (msg .action_calls ) > 0 :
253
- for tool_call in msg .action_calls :
254
- tool_call_params .append (
255
- chat .ChatCompletionMessageToolCallParam (
256
- id = tool_call .id ,
257
- function = Function (
258
- name = tool_call .function .name ,
259
- arguments = tool_call .function .arguments ,
260
- ),
261
- type = tool_call .type ,
262
- )
263
- )
264
- param ["content" ] = None
265
- param ["tool_calls" ] = tool_call_params
266
-
267
- if msg .name :
268
- param ["name" ] = msg .name
269
-
270
- elif msg .role == "tool" :
271
- param = chat .ChatCompletionToolMessageParam (
272
- role = "tool" ,
273
- tool_call_id = msg .action_call_id if msg .action_call_id else "" ,
274
- content = msg .content if msg .content else "" ,
275
- )
276
- elif msg .role == "system" :
277
- # o1 models do not support system messages
278
- if is_o1_model :
279
- param = chat .ChatCompletionUserMessageParam (
280
- role = "user" ,
281
- content = msg .content if msg .content is not None else "" ,
282
- )
283
- else :
284
- param = chat .ChatCompletionSystemMessageParam (
285
- role = "system" ,
286
- content = msg .content if msg .content is not None else "" ,
287
- )
288
-
289
- if msg .name :
290
- param ["name" ] = msg .name
291
-
292
- messages .append (param )
228
+ messages : List [chat .ChatCompletionMessageParam ]
229
+ messages = self ._map_messages (res .output , is_o1_model )
293
230
294
231
try :
295
232
extra_body = {}
296
233
if template .config .completion .data_sources is not None :
297
234
extra_body ["data_sources" ] = template .config .completion .data_sources
298
235
236
+ max_tokens = template .config .completion .max_tokens
299
237
completion = await self ._client .chat .completions .create (
300
238
messages = messages ,
301
239
model = model ,
@@ -305,7 +243,8 @@ async def complete_prompt(
305
243
frequency_penalty = template .config .completion .frequency_penalty ,
306
244
top_p = template .config .completion .top_p if not is_o1_model else 1 ,
307
245
temperature = template .config .completion .temperature if not is_o1_model else 1 ,
308
- max_completion_tokens = template .config .completion .max_tokens ,
246
+ max_tokens = max_tokens if not is_o1_model else NOT_GIVEN ,
247
+ max_completion_tokens = max_tokens if is_o1_model else NOT_GIVEN ,
309
248
tools = tools if len (tools ) > 0 else NOT_GIVEN ,
310
249
tool_choice = tool_choice if len (tools ) > 0 else NOT_GIVEN ,
311
250
parallel_tool_calls = parallel_tool_calls if len (tools ) > 0 else NOT_GIVEN ,
@@ -436,3 +375,70 @@ async def complete_prompt(
436
375
status of { err .code } : { err .message }
437
376
""" ,
438
377
)
378
+
379
+ def _map_messages (self , msgs : List [Message ], is_o1_model : bool ):
380
+ output = []
381
+ for msg in msgs :
382
+ param : Union [
383
+ chat .ChatCompletionUserMessageParam ,
384
+ chat .ChatCompletionAssistantMessageParam ,
385
+ chat .ChatCompletionSystemMessageParam ,
386
+ chat .ChatCompletionToolMessageParam ,
387
+ ] = chat .ChatCompletionUserMessageParam (
388
+ role = "user" ,
389
+ content = msg .content if msg .content is not None else "" ,
390
+ )
391
+
392
+ if msg .name :
393
+ setattr (param , "name" , msg .name )
394
+
395
+ if msg .role == "assistant" :
396
+ param = chat .ChatCompletionAssistantMessageParam (
397
+ role = "assistant" ,
398
+ content = msg .content if msg .content is not None else "" ,
399
+ )
400
+
401
+ tool_call_params : List [chat .ChatCompletionMessageToolCallParam ] = []
402
+
403
+ if msg .action_calls and len (msg .action_calls ) > 0 :
404
+ for tool_call in msg .action_calls :
405
+ tool_call_params .append (
406
+ chat .ChatCompletionMessageToolCallParam (
407
+ id = tool_call .id ,
408
+ function = Function (
409
+ name = tool_call .function .name ,
410
+ arguments = tool_call .function .arguments ,
411
+ ),
412
+ type = tool_call .type ,
413
+ )
414
+ )
415
+ param ["content" ] = None
416
+ param ["tool_calls" ] = tool_call_params
417
+
418
+ if msg .name :
419
+ param ["name" ] = msg .name
420
+
421
+ elif msg .role == "tool" :
422
+ param = chat .ChatCompletionToolMessageParam (
423
+ role = "tool" ,
424
+ tool_call_id = msg .action_call_id if msg .action_call_id else "" ,
425
+ content = msg .content if msg .content else "" ,
426
+ )
427
+ elif msg .role == "system" :
428
+ # o1 models do not support system messages
429
+ if is_o1_model :
430
+ param = chat .ChatCompletionUserMessageParam (
431
+ role = "user" ,
432
+ content = msg .content if msg .content is not None else "" ,
433
+ )
434
+ else :
435
+ param = chat .ChatCompletionSystemMessageParam (
436
+ role = "system" ,
437
+ content = msg .content if msg .content is not None else "" ,
438
+ )
439
+
440
+ if msg .name :
441
+ param ["name" ] = msg .name
442
+
443
+ output .append (param )
444
+ return output
0 commit comments