import dataclasses
import random
import string
+ import warnings

from contextlib import ExitStack
from typing import (
import numpy as np
import numpy.typing as npt

- import llama_cpp.llama as llama
- import llama_cpp.llama_types as llama_types
- import llama_cpp.llama_grammar as llama_grammar
+ from llama_cpp import llama, llama_grammar, llama_types

from ._logger import logger
from ._utils import suppress_stdout_stderr, Singleton
@@ -3373,6 +3372,155 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
)


+ def _accumulate_chunks(
+     chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse],
+     chunks_list: List[llama_types.CreateCompletionStreamResponse],
+ ) -> Iterator[llama_types.CreateCompletionStreamResponse]:
+     for chunk in chunks_iterator:
+         chunks_list.append(chunk)
+         yield chunk
+
+
+ def _convert_chunks_to_completion(
+     chunks: List[llama_types.CreateCompletionStreamResponse],
+ ) -> llama_types.CreateCompletionResponse:
+     """Convert a list of completion chunks to a completion."""
+     # Accumulate completion response values
+     text: str = ""
+     finish_reason: Optional[str] = None
+     logprobs: Optional[llama_types.CompletionLogprobs] = None
+     prompt_tokens = 0
+     completion_tokens = 0
+     total_tokens = 0
+     completion_id: Optional[str] = None
+     completion_model: Optional[str] = None
+     completion_created: Optional[int] = None
+     for chunk in chunks:
+         # Extract the id, model, and created values from the first chunk
+         if completion_id is None:
+             completion_id = chunk["id"]
+             completion_model = chunk["model"]
+             completion_created = chunk["created"]
+         # Extract the usage if present in the chunk
+         usage = chunk.get("usage")
+         if usage:
+             prompt_tokens += usage.get("prompt_tokens", 0)
+             completion_tokens += usage.get("completion_tokens", 0)
+             total_tokens += usage.get("total_tokens", 0)
+         # Accumulate the chunk text
+         choice = chunk["choices"][0]
+         text += choice.get("text", "")
+         # Extract the finish_reason and logprobs if present in the chunk
+         if choice.get("finish_reason"):
+             finish_reason = choice["finish_reason"]
+         if choice.get("logprobs"):
+             logprobs = choice["logprobs"]
+     # Create the completion response
+     completion: llama_types.CreateCompletionResponse = {
+         "id": completion_id or "unknown_id",
+         "object": "text_completion",
+         "created": completion_created or 0,
+         "model": completion_model or "unknown_model",
+         "choices": [
+             {
+                 "text": text,
+                 "index": 0,
+                 "logprobs": logprobs,  # TODO: Improve accumulation of logprobs
+                 "finish_reason": finish_reason,  # type: ignore[typeddict-item]
+             }
+         ],
+     }
+     # Add usage section if present in the chunks
+     if (prompt_tokens + completion_tokens + total_tokens) > 0:
+         completion["usage"] = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+         }
+     return completion
+
+
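A minimal usage sketch of these two helpers (illustrative only; assumes a hypothetical llama_cpp.Llama instance named llm): chunks are forwarded to the caller as they arrive while being recorded, then folded back into a single completion afterwards.

# Hypothetical sketch, not part of the diff
chunks = []
for chunk in _accumulate_chunks(llm.create_completion("Hello", stream=True), chunks):
    print(chunk["choices"][0]["text"], end="")  # stream text to the caller as it arrives
completion = _convert_chunks_to_completion(chunks)  # single response with full text, finish_reason, usage
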
+ def _stream_tool_calls(
+     llama: llama.Llama,
+     prompt: str,
+     tools: List[llama_types.ChatCompletionTool],
+     tool_name: str,
+     completion_kwargs: dict[str, Any],
+     follow_up_gbnf_tool_grammar: str,
+ ) -> Iterator[llama_types.CreateChatCompletionStreamResponse]:
+     # Generate the tool call completions
+     tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+     completions: List[llama_types.CreateCompletionResponse] = []
+     completions_tool_name: List[str] = []
+     finish_reason_chat_chunk = None
+     while tool is not None:
+         # Generate the parameter values for the selected tool
+         prompt += f"functions.{tool_name}:\n"
+         try:
+             grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                 json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
+             )
+         except Exception as e:
+             warnings.warn(
+                 f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
+                 category=RuntimeWarning,
+                 stacklevel=2,
+             )
+             grammar = llama_grammar.LlamaGrammar.from_string(
+                 llama_grammar.JSON_GBNF, verbose=llama.verbose
+             )
+         completion_or_chunks = llama.create_completion(
+             prompt=prompt,
+             **{
+                 **completion_kwargs,
+                 "max_tokens": None,
+                 "grammar": grammar,
+             },
+         )
+         chunks: List[llama_types.CreateCompletionResponse] = []
+         chat_chunks = _convert_completion_to_chat_function(
+             tool_name,
+             _accumulate_chunks(completion_or_chunks, chunks),  # type: ignore[arg-type]
+             stream=True,
+         )
+         for chat_chunk in chat_chunks:
+             # Don't return the finish_reason chunk
+             if chat_chunk["choices"] and chat_chunk["choices"][0].get("finish_reason"):
+                 finish_reason_chat_chunk = chat_chunk
+                 break
+             # Update this tool call's index
+             if chat_chunk["choices"] and chat_chunk["choices"][0]["delta"].get("tool_calls"):
+                 chat_chunk["choices"][0]["delta"]["tool_calls"][0]["index"] = len(completions)
+             yield chat_chunk
+         completion = _convert_chunks_to_completion(chunks)
+         completions.append(completion)
+         completions_tool_name.append(tool_name)
+         prompt += completion["choices"][0]["text"]
+         prompt += "\n"
+         # Determine whether to call another tool or stop
+         response = cast(
+             llama_types.CreateCompletionResponse,
+             llama.create_completion(
+                 prompt=prompt,
+                 **{
+                     **completion_kwargs,
+                     "temperature": 0,
+                     "stream": False,
+                     "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],
+                     "max_tokens": None,
+                     "grammar": llama_grammar.LlamaGrammar.from_string(
+                         follow_up_gbnf_tool_grammar, verbose=llama.verbose
+                     ),
+                 },
+             ),
+         )
+         tool_name = response["choices"][0]["text"][len("functions."):]
+         tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+     # Yield the finish_reason chunk
+     if finish_reason_chat_chunk is not None:
+         yield finish_reason_chat_chunk
+
+
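A minimal sketch of how a caller might reassemble the streamed tool calls (illustrative only; chat_chunks stands for the iterator returned by _stream_tool_calls). The "index" set above is what lets argument fragments be grouped per call.

# Hypothetical client-side reassembly, not part of the diff
calls = {}
for chunk in chat_chunks:
    delta = chunk["choices"][0]["delta"]
    for tool_call in delta.get("tool_calls") or []:
        entry = calls.setdefault(tool_call["index"], {"name": "", "arguments": ""})
        function = tool_call.get("function") or {}
        entry["name"] = function.get("name") or entry["name"]
        entry["arguments"] += function.get("arguments") or ""
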
@register_chat_completion_handler("chatml-function-calling")
def chatml_function_calling(
    llama: llama.Llama,
@@ -3402,7 +3550,7 @@ def chatml_function_calling(
    grammar: Optional[llama.LlamaGrammar] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
-     **kwargs,  # type: ignore
+     **kwargs: Any,
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
@@ -3416,18 +3564,21 @@ def chatml_function_calling(
        "{% if tool_calls %}"
        "\n\nYou have access to the following functions:\n"
        "{% for tool in tools %}"
+         '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}'
        "\nfunctions.{{ tool.function.name }}:\n"
        "{{ tool.function.parameters | tojson }}"
        "\n{% endfor %}"
-         "\n\nYou can respond to users messages with either a single message or one or more function calls."
-         "\n\nTo respond with a message begin the message with 'message:', use the following format:"
+         "\nYou must respond to user messages with either a single message or with one or more function calls."
+         "\n\nTo respond with a message use the following format:"
        "\n\nmessage:"
        "\n<message>"
-         "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
-         "\n\nfunctions.<function_name>:"
+         "\n\nTo respond with one or more function calls use the following format:"
+         "\n\n<function_calls>"
+         "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
+         "\n</function_calls>"
        "{% endif %}"
        "<|im_end|>\n"
        "{% endif %}"
@@ -3438,7 +3589,7 @@ def chatml_function_calling(
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
-         ## Reglar message
+         ## Regular message
        "{% if message.content and message.content | length > 0 %}"
        "{% if tool_calls %}"
        "message:\n"
@@ -3465,352 +3616,235 @@ def chatml_function_calling(
    # Convert legacy functions to tools
    if functions is not None:
-         tools = [
-             {
-                 "type": "function",
-                 "function": function,
-             }
-             for function in functions
-         ]
+         tools = [{"type": "function", "function": function} for function in functions]

    # Convert legacy function_call to tool_choice
    if function_call is not None:
-         if isinstance(function_call, str) and (
-             function_call == "none" or function_call == "auto"
-         ):
+         if isinstance(function_call, str) and (function_call in ("none", "auto")):
            tool_choice = function_call
        if isinstance(function_call, dict) and "name" in function_call:
-             tool_choice = {
-                 "type": "function",
-                 "function": {
-                     "name": function_call["name"],
-                 },
-             }
+             tool_choice = {"type": "function", "function": {"name": function_call["name"]}}

+     # Collect the llama.create_completion keyword arguments so we don't have to repeat these with
+     # each completion call
    stop = (
        [stop, "<|im_end|>"]
        if isinstance(stop, str)
-         else stop + ["<|im_end|>"] if stop else ["<|im_end|>"]
+         else [*stop, "<|im_end|>"]
+         if stop
+         else ["<|im_end|>"]
    )
+     grammar = (  # It is assumed the grammar applies to messages only, not tool calls
+         grammar
+         if grammar is not None
+         else (
+             _grammar_for_response_format(response_format)
+             if response_format is not None and response_format["type"] == "json_object"
+             else None
+         )
+     )
+     completion_kwargs = {
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "min_p": min_p,
+         "typical_p": typical_p,
+         "stream": stream,
+         "stop": stop,
+         "max_tokens": max_tokens,
+         "presence_penalty": presence_penalty,
+         "frequency_penalty": frequency_penalty,
+         "repeat_penalty": repeat_penalty,
+         "tfs_z": tfs_z,
+         "mirostat_mode": mirostat_mode,
+         "mirostat_tau": mirostat_tau,
+         "mirostat_eta": mirostat_eta,
+         "model": model,
+         "logits_processor": logits_processor,
+         "grammar": grammar,
+     }

-     # Case 1: No tool choice by user
+     # Case 1: No tool use
    if (
        tool_choice is None
        or (isinstance(tool_choice, str) and tool_choice == "none")
        or tools is None
        or len(tools) == 0
    ):
        prompt = template_renderer.render(
-             messages=messages,
-             tools=[],
-             tool_calls=None,
-             add_generation_prompt=True,
+             messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
        )
-
-         if response_format is not None and response_format["type"] == "json_object":
-             grammar = _grammar_for_response_format(response_format)
-
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
-                 temperature=temperature,
-                 top_p=top_p,
-                 top_k=top_k,
-                 min_p=min_p,
-                 typical_p=typical_p,
-                 stream=stream,
-                 stop=stop,
-                 max_tokens=max_tokens,
-                 presence_penalty=presence_penalty,
-                 frequency_penalty=frequency_penalty,
-                 repeat_penalty=repeat_penalty,
-                 tfs_z=tfs_z,
-                 mirostat_mode=mirostat_mode,
-                 mirostat_tau=mirostat_tau,
-                 mirostat_eta=mirostat_eta,
-                 model=model,
-                 logits_processor=logits_processor,
-                 grammar=grammar,
+                 **completion_kwargs,  # type: ignore[arg-type]
                logprobs=top_logprobs if logprobs else None,
            ),
            stream=stream,
        )

-     # Case 2: Tool choice by user
-     if isinstance(tool_choice, dict):
-         tool_name = tool_choice["function"]["name"]
-         tool = next(
-             (tool for tool in tools if tool["function"]["name"] == tool_name), None
-         )
-         if tool is None:
-             raise ValueError(f"Tool with name '{tool_name}' not found in tools")
-         prompt = template_renderer.render(
-             messages=messages,
-             tools=tools,
-             tool_calls=True,
-             add_generation_prompt=True,
-         )
-         prompt += f"functions.{tool_name}:\n"
-         try:
-             grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                 json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
-             )
-         except Exception as e:
-             grammar = llama_grammar.LlamaGrammar.from_string(
-                 llama_grammar.JSON_GBNF, verbose=llama.verbose
-             )
-             if llama.verbose:
-                 print(
-                     "Failed to parse function body as JSON schema, falling back to default grammar"
-                 )
-                 print(e)
-         completion_or_chunks = llama.create_completion(
-             prompt=prompt,
-             temperature=temperature,
-             top_p=top_p,
-             top_k=top_k,
-             min_p=min_p,
-             typical_p=typical_p,
-             stream=stream,
-             stop=stop,
-             max_tokens=max_tokens,
-             presence_penalty=presence_penalty,
-             frequency_penalty=frequency_penalty,
-             repeat_penalty=repeat_penalty,
-             tfs_z=tfs_z,
-             mirostat_mode=mirostat_mode,
-             mirostat_tau=mirostat_tau,
-             mirostat_eta=mirostat_eta,
-             model=model,
-             logits_processor=logits_processor,
-             grammar=grammar,
-         )
-         return _convert_completion_to_chat_function(
-             tool_name, completion_or_chunks, stream
-         )
+     # Ensure there is a system prompt to attach the tool metadata to
+     if not any(message["role"] == "system" for message in messages):
+         messages = [*messages, {"role": "system", "content": ""}]

-     # Case 3: Automatic tool choice
-     assert isinstance(tool_choice, str) and tool_choice == "auto"
-     function_names = " | ".join(
-         [f'''"functions.{tool['function']['name']}:"''' for tool in tools]
+     # Case 2: Automatic or fixed tool choice
+     # Case 2 step 1: Determine whether to respond with a message or a tool call
+     assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict)
+     if isinstance(tool_choice, dict):
+         tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]]
+     assert tools
+     function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools])
+     prompt = template_renderer.render(
+         messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True
    )
    initial_gbnf_tool_grammar = (
-         """root ::= functions | "message:"\n"""
-         f"""functions ::= {function_names}\n"""
-     )
-     follow_up_gbnf_tool_grammar = (
-         """root ::= functions | "<|im_end|>"\n"""
-         f"""functions ::= {function_names}\n"""
-     )
-     prompt = template_renderer.render(
-         messages=messages,
-         tools=tools,
-         tool_calls=True,
-         add_generation_prompt=True,
+         (
+             'root ::= "<function_calls>" "\\n" functions | "message:"\n'
+             f"functions ::= {function_names}\n"
+         )
+         if tool_choice == "auto"
+         else f'root ::= "<function_calls>" "\\n" functions\nfunctions ::= {function_names}\n'
    )
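As an illustration (not part of the diff), for a single hypothetical tool named "get_weather" with tool_choice="auto", the expression above evaluates to the following GBNF string:

# initial_gbnf_tool_grammar would contain:
#   root ::= "<function_calls>" "\n" functions | "message:"
#   functions ::= "functions.get_weather:"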
-     completion_or_chunks = llama.create_completion(
-         prompt=prompt,
-         temperature=0,
-         top_p=top_p,
-         top_k=top_k,
-         min_p=min_p,
-         typical_p=typical_p,
-         stream=False,
-         stop=[":"],
-         max_tokens=None,
-         presence_penalty=presence_penalty,
-         frequency_penalty=frequency_penalty,
-         repeat_penalty=repeat_penalty,
-         tfs_z=tfs_z,
-         mirostat_mode=mirostat_mode,
-         mirostat_tau=mirostat_tau,
-         mirostat_eta=mirostat_eta,
-         model=model,
-         logits_processor=logits_processor,
-         grammar=llama_grammar.LlamaGrammar.from_string(
-             initial_gbnf_tool_grammar, verbose=llama.verbose
+     completion = cast(
+         llama_types.CreateCompletionResponse,
+         llama.create_completion(
+             prompt=prompt,
+             **{  # type: ignore[arg-type]
+                 **completion_kwargs,
+                 "temperature": 0,
+                 "stream": False,
+                 "stop": [":"],
+                 "max_tokens": None,
+                 "grammar": llama_grammar.LlamaGrammar.from_string(
+                     initial_gbnf_tool_grammar, verbose=llama.verbose
+                 ),
+             },
        ),
    )
-     completion: llama_types.CreateCompletionResponse = completion_or_chunks  # type: ignore
    text = completion["choices"][0]["text"]
-     if "message" in text:
+     tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions."):]
+
+     # Case 2 step 2A: Respond with a message
+     if tool_name is None:
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt + "message:\n",
-                 temperature=temperature,
-                 top_p=top_p,
-                 top_k=top_k,
-                 min_p=min_p,
-                 typical_p=typical_p,
-                 stream=stream,
-                 stop=["<|im_end|>"],
+                 **completion_kwargs,  # type: ignore[arg-type]
                logprobs=top_logprobs if logprobs else None,
-                 max_tokens=None,
-                 presence_penalty=presence_penalty,
-                 frequency_penalty=frequency_penalty,
-                 repeat_penalty=repeat_penalty,
-                 tfs_z=tfs_z,
-                 mirostat_mode=mirostat_mode,
-                 mirostat_tau=mirostat_tau,
-                 mirostat_eta=mirostat_eta,
-                 model=model,
-                 logits_processor=logits_processor,
-                 grammar=llama_grammar.LlamaGrammar.from_string(
-                     follow_up_gbnf_tool_grammar, verbose=llama.verbose
-                 ),
            ),
            stream=stream,
        )

-     # One or more function calls
-     tool_name = text[len("functions."):]
+     # Case 2 step 2B: One or more function calls
+     follow_up_gbnf_tool_grammar = (
+         'root ::= functions | "</function_calls>" | "<|im_end|>"\n'
+         f"functions ::= {function_names}\n"
+     )
+     prompt += "<function_calls>\n"
+     if stream:
+         return _stream_tool_calls(
+             llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar
+         )
    tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
-     if not stream:
-         completions: List[llama_types.CreateCompletionResponse] = []
-         completions_tool_name: List[str] = []
-         while tool is not None:
-             prompt += f"functions.{tool_name}:\n"
-             try:
-                 grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                     json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
-                 )
-             except Exception as e:
-                 grammar = llama_grammar.LlamaGrammar.from_string(
-                     llama_grammar.JSON_GBNF, verbose=llama.verbose
-                 )
-                 if llama.verbose:
-                     print(
-                         "Failed to parse function body as JSON schema, falling back to default grammar"
-                     )
-                     print(e)
-             completion_or_chunks = llama.create_completion(
-                 prompt=prompt,
-                 temperature=temperature,
-                 top_p=top_p,
-                 top_k=top_k,
-                 min_p=min_p,
-                 typical_p=typical_p,
-                 stream=False,
-                 stop=stop,
-                 max_tokens=None,
-                 presence_penalty=presence_penalty,
-                 frequency_penalty=frequency_penalty,
-                 repeat_penalty=repeat_penalty,
-                 tfs_z=tfs_z,
-                 mirostat_mode=mirostat_mode,
-                 mirostat_tau=mirostat_tau,
-                 mirostat_eta=mirostat_eta,
-                 model=model,
-                 logits_processor=logits_processor,
-                 grammar=grammar,
-             )
-             completion_or_chunks = cast(
-                 llama_types.CreateCompletionResponse, completion_or_chunks
+     completions: List[llama_types.CreateCompletionResponse] = []
+     completions_tool_name: List[str] = []
+     while tool is not None:
+         # Generate the parameter values for the selected tool
+         prompt += f"functions.{tool_name}:\n"
+         try:
+             grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                 json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
            )
-             completions.append(completion_or_chunks)
-             completions_tool_name.append(tool_name)
-             prompt += completion_or_chunks["choices"][0]["text"]
-             prompt += "\n"
-
-             response = llama.create_completion(
-                 prompt=prompt,
-                 temperature=temperature,
-                 top_p=top_p,
-                 top_k=top_k,
-                 min_p=min_p,
-                 typical_p=typical_p,
-                 stream=False,
-                 stop=stop,
-                 max_tokens=None,
-                 presence_penalty=presence_penalty,
-                 frequency_penalty=frequency_penalty,
-                 repeat_penalty=repeat_penalty,
-                 tfs_z=tfs_z,
-                 mirostat_mode=mirostat_mode,
-                 mirostat_tau=mirostat_tau,
-                 mirostat_eta=mirostat_eta,
-                 model=model,
-                 logits_processor=logits_processor,
-                 grammar=llama_grammar.LlamaGrammar.from_string(
-                     follow_up_gbnf_tool_grammar, verbose=llama.verbose
-                 ),
+         except Exception as e:
+             warnings.warn(
+                 f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
+                 category=RuntimeWarning,
+                 stacklevel=2,
            )
-             response = cast(llama_types.CreateCompletionResponse, response)
-
-             tool_name = response["choices"][0]["text"][len("functions."):]
-             tool = next(
-                 (tool for tool in tools if tool["function"]["name"] == tool_name), None
+             grammar = llama_grammar.LlamaGrammar.from_string(
+                 llama_grammar.JSON_GBNF, verbose=llama.verbose
            )
-
-         # Merge completions
-         function_call_dict: Union[
-             Dict[str, str],
-             Dict[
-                 Literal["function_call"],
-                 llama_types.ChatCompletionRequestAssistantMessageFunctionCall,
-             ],
-         ] = (
+         completion_or_chunks = llama.create_completion(
+             prompt=prompt,
+             **{  # type: ignore[arg-type]
+                 **completion_kwargs,
+                 "max_tokens": None,
+                 "grammar": grammar,
+             },
+         )
+         completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
+         completions.append(completion)
+         completions_tool_name.append(tool_name)
+         prompt += completion["choices"][0]["text"]
+         prompt += "\n"
+         # Determine whether to call another tool or stop
+         response = cast(
+             llama_types.CreateCompletionResponse,
+             llama.create_completion(
+                 prompt=prompt,
+                 **{  # type: ignore[arg-type]
+                     **completion_kwargs,
+                     "temperature": 0,
+                     "stream": False,
+                     "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],  # type: ignore[misc]
+                     "max_tokens": None,
+                     "grammar": llama_grammar.LlamaGrammar.from_string(
+                         follow_up_gbnf_tool_grammar, verbose=llama.verbose
+                     ),
+                 },
+             ),
+         )
+         tool_name = response["choices"][0]["text"][len("functions."):]
+         tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+     # Merge the completions into a single chat completion
+     chat_completion: llama_types.CreateChatCompletionResponse = {
+         "id": "chat" + completion["id"],
+         "object": "chat.completion",
+         "created": completion["created"],
+         "model": completion["model"],
+         "choices": [
            {
-                 "function_call": {
-                     "name": tool_name,
-                     "arguments": completions[0]["choices"][0]["text"],
-                 }
+                 "finish_reason": "tool_calls",
+                 "index": 0,
+                 "logprobs": completion["choices"][0]["logprobs"],
+                 "message": {
+                     "role": "assistant",
+                     "content": None,
+                     "tool_calls": [
+                         {
+                             "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"],
+                             "type": "function",
+                             "function": {
+                                 "name": tool_name,
+                                 "arguments": completion["choices"][0]["text"],
+                             },
+                         }
+                         for i, (tool_name, completion) in enumerate(
+                             zip(completions_tool_name, completions, strict=True)
+                         )
+                     ],
+                 },
            }
-             if len(completions) == 1
-             else {}
-         )
-         return {
-             "id": "chat" + completion["id"],
-             "object": "chat.completion",
-             "created": completion["created"],
-             "model": completion["model"],
-             "choices": [
-                 {
-                     "finish_reason": "tool_calls",
-                     "index": 0,
-                     "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
-                     "message": {
-                         "role": "assistant",
-                         "content": None,
-                         "tool_calls": [
-                             {
-                                 "id": "call_"
-                                 + f"_{i}_"
-                                 + tool_name
-                                 + "_"
-                                 + completion["id"],
-                                 "type": "function",
-                                 "function": {
-                                     "name": tool_name,
-                                     "arguments": completion["choices"][0]["text"],
-                                 },
-                             }
-                             for i, (tool_name, completion) in enumerate(
-                                 zip(completions_tool_name, completions)
-                             )
-                         ],
-                         **function_call_dict,
-                     },
-                 }
-             ],
-             "usage": {
-                 "completion_tokens": sum(
-                     (
-                         completion["usage"]["completion_tokens"]
-                         if "usage" in completion
-                         else 0
-                     )
-                     for completion in completions
-                 ),
-                 "prompt_tokens": sum(
-                     completion["usage"]["prompt_tokens"] if "usage" in completion else 0
-                     for completion in completions
-                 ),
-                 "total_tokens": sum(
-                     completion["usage"]["total_tokens"] if "usage" in completion else 0
-                     for completion in completions
-                 ),
-             },
+         ],
+         "usage": {
+             "completion_tokens": sum(
+                 (completion["usage"]["completion_tokens"] if "usage" in completion else 0)
+                 for completion in completions
+             ),
+             "prompt_tokens": sum(
+                 completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+                 for completion in completions
+             ),
+             "total_tokens": sum(
+                 completion["usage"]["total_tokens"] if "usage" in completion else 0
+                 for completion in completions
+             ),
+         },
+     }
+     if len(completions) == 1:
+         single_function_call: llama_types.ChatCompletionResponseFunctionCall = {
+             "name": tool_name,
+             "arguments": completions[0]["choices"][0]["text"],
        }
-
-     raise ValueError("Automatic streaming tool choice is not supported")
+         chat_completion["choices"][0]["message"]["function_call"] = single_function_call
+     return chat_completion
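
A minimal end-to-end usage sketch of the updated handler (illustrative only; the model path and tool definition are hypothetical):

from llama_cpp import Llama

llm = Llama(model_path="model.gguf", chat_format="chatml-function-calling")
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice="auto",
    stream=False,
)
print(response["choices"][0]["message"].get("tool_calls"))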