This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a08a754

Committed Mar 14, 2025
feat: add streaming tool use
1 parent 37eb5f0 commit a08a754
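
For orientation, the following is a minimal sketch of what this commit enables, mirroring the test added below in tests/test_llama_chat_format.py. The model repository, GGUF filename, and tool schema are illustrative assumptions, not requirements of the handler.

# Hedged sketch: streaming tool use with the upgraded "chatml-function-calling" handler.
# The repo_id, filename, and tool schema below are illustrative assumptions.
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="bartowski/Llama-3.2-3B-Instruct-GGUF",
    filename="*Q4_K_M.gguf",
    n_ctx=4096,
    verbose=False,
    chat_format="chatml-function-calling",
)
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather for a location.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string", "description": "A city name."}},
            },
        },
    }
]
# With stream=True, the handler now yields chat completion chunks whose
# choices[0]["delta"]["tool_calls"] entries carry an "index" per generated tool call.
for chunk in llm.create_chat_completion(
    messages=[{"role": "user", "content": "What's the weather like in Paris today?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
):
    delta = chunk["choices"][0]["delta"]
    if delta.get("tool_calls"):
        print(delta["tool_calls"])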

File tree

3 files changed: +473 -313 lines changed

llama_cpp/llama_chat_format.py

Lines changed: 343 additions & 309 deletions
@@ -7,6 +7,7 @@
 import dataclasses
 import random
 import string
+import warnings
 
 from contextlib import ExitStack
 from typing import (
@@ -28,9 +29,7 @@
 import numpy as np
 import numpy.typing as npt
 
-import llama_cpp.llama as llama
-import llama_cpp.llama_types as llama_types
-import llama_cpp.llama_grammar as llama_grammar
+from llama_cpp import llama, llama_grammar, llama_types
 
 from ._logger import logger
 from ._utils import suppress_stdout_stderr, Singleton
@@ -3373,6 +3372,155 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
     )
 
 
+def _accumulate_chunks(
+    chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse],
+    chunks_list: List[llama_types.CreateCompletionStreamResponse],
+) -> Iterator[llama_types.CreateCompletionStreamResponse]:
+    for chunk in chunks_iterator:
+        chunks_list.append(chunk)
+        yield chunk
+
+
+def _convert_chunks_to_completion(
+    chunks: List[llama_types.CreateCompletionStreamResponse],
+) -> llama_types.CreateCompletionResponse:
+    """Convert a list of completion chunks to a completion."""
+    # Accumulate completion response values
+    text: str = ""
+    finish_reason: Optional[str] = None
+    logprobs: Optional[llama_types.CompletionLogprobs] = None
+    prompt_tokens = 0
+    completion_tokens = 0
+    total_tokens = 0
+    completion_id: Optional[str] = None
+    completion_model: Optional[str] = None
+    completion_created: Optional[int] = None
+    for chunk in chunks:
+        # Extract the id, model, and created values from the first chunk
+        if completion_id is None:
+            completion_id = chunk["id"]
+            completion_model = chunk["model"]
+            completion_created = chunk["created"]
+        # Extract the usage if present in the chunk
+        usage = chunk.get("usage")
+        if usage:
+            prompt_tokens += usage.get("prompt_tokens", 0)
+            completion_tokens += usage.get("completion_tokens", 0)
+            total_tokens += usage.get("total_tokens", 0)
+        # Accumulate the chunk text
+        choice = chunk["choices"][0]
+        text += choice.get("text", "")
+        # Extract the finish_reason and logprobs if present in the chunk
+        if choice.get("finish_reason"):
+            finish_reason = choice["finish_reason"]
+        if choice.get("logprobs"):
+            logprobs = choice["logprobs"]
+    # Create the completion response
+    completion: llama_types.CreateCompletionResponse = {
+        "id": completion_id or "unknown_id",
+        "object": "text_completion",
+        "created": completion_created or 0,
+        "model": completion_model or "unknown_model",
+        "choices": [
+            {
+                "text": text,
+                "index": 0,
+                "logprobs": logprobs,  # TODO: Improve accumulation of logprobs
+                "finish_reason": finish_reason,  # type: ignore[typeddict-item]
+            }
+        ],
+    }
+    # Add usage section if present in the chunks
+    if (prompt_tokens + completion_tokens + total_tokens) > 0:
+        completion["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+        }
+    return completion
+
+
+def _stream_tool_calls(
+    llama: llama.Llama,
+    prompt: str,
+    tools: List[llama_types.ChatCompletionTool],
+    tool_name: str,
+    completion_kwargs: dict[str, Any],
+    follow_up_gbnf_tool_grammar: str,
+) -> Iterator[llama_types.CreateChatCompletionStreamResponse]:
+    # Generate a tool call completions
+    tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+    completions: List[llama_types.CreateCompletionResponse] = []
+    completions_tool_name: List[str] = []
+    finish_reason_chat_chunk = None
+    while tool is not None:
+        # Generate the parameter values for the selected tool
+        prompt += f"functions.{tool_name}:\n"
+        try:
+            grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
+            )
+        except Exception as e:
+            warnings.warn(
+                f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
+                category=RuntimeWarning,
+                stacklevel=2,
+            )
+            grammar = llama_grammar.LlamaGrammar.from_string(
+                llama_grammar.JSON_GBNF, verbose=llama.verbose
+            )
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            **{
+                **completion_kwargs,
+                "max_tokens": None,
+                "grammar": grammar,
+            },
+        )
+        chunks: List[llama_types.CreateCompletionResponse] = []
+        chat_chunks = _convert_completion_to_chat_function(
+            tool_name,
+            _accumulate_chunks(completion_or_chunks, chunks),  # type: ignore[arg-type]
+            stream=True,
+        )
+        for chat_chunk in chat_chunks:
+            # Don't return the finish_reason chunk
+            if chat_chunk["choices"] and chat_chunk["choices"][0].get("finish_reason"):
+                finish_reason_chat_chunk = chat_chunk
+                break
+            # Update this tool call's index
+            if chat_chunk["choices"] and chat_chunk["choices"][0]["delta"].get("tool_calls"):
+                chat_chunk["choices"][0]["delta"]["tool_calls"][0]["index"] = len(completions)
+            yield chat_chunk
+        completion = _convert_chunks_to_completion(chunks)
+        completions.append(completion)
+        completions_tool_name.append(tool_name)
+        prompt += completion["choices"][0]["text"]
+        prompt += "\n"
+        # Determine whether to call another tool or stop
+        response = cast(
+            llama_types.CreateCompletionResponse,
+            llama.create_completion(
+                prompt=prompt,
+                **{
+                    **completion_kwargs,
+                    "temperature": 0,
+                    "stream": False,
+                    "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],
+                    "max_tokens": None,
+                    "grammar": llama_grammar.LlamaGrammar.from_string(
+                        follow_up_gbnf_tool_grammar, verbose=llama.verbose
+                    ),
+                },
+            ),
+        )
+        tool_name = response["choices"][0]["text"][len("functions.") :]
+        tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+    # Yield the finish_reason chunk
+    if finish_reason_chat_chunk is not None:
+        yield finish_reason_chat_chunk
+
+
 @register_chat_completion_handler("chatml-function-calling")
 def chatml_function_calling(
     llama: llama.Llama,
@@ -3402,7 +3550,7 @@ def chatml_function_calling(
     grammar: Optional[llama.LlamaGrammar] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
-    **kwargs,  # type: ignore
+    **kwargs: Any,
 ) -> Union[
     llama_types.CreateChatCompletionResponse,
     Iterator[llama_types.CreateChatCompletionStreamResponse],
@@ -3416,18 +3564,21 @@
         "{% if tool_calls %}"
         "\n\nYou have access to the following functions:\n"
         "{% for tool in tools %}"
+        '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}'
         "\nfunctions.{{ tool.function.name }}:\n"
         "{{ tool.function.parameters | tojson }}"
         "\n{% endfor %}"
-        "\n\nYou can respond to users messages with either a single message or one or more function calls."
-        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
+        "\nYou must respond to user messages with either a single message or with one or more function calls."
+        "\n\nTo respond with a message use the following format:"
         "\n\nmessage:"
         "\n<message>"
-        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
-        "\n\nfunctions.<function_name>:"
+        "\n\nTo respond with one or more function calls use the following format:"
+        "\n\n<function_calls>"
+        "\nfunctions.<function_name>:"
         '\n{ "arg1": "value1", "arg2": "value2" }'
         "\nfunctions.<function_name>:"
         '\n{ "arg1": "value1", "arg2": "value2" }'
+        "\n</function_calls>"
         "{% endif %}"
         "<|im_end|>\n"
         "{% endif %}"
@@ -3438,7 +3589,7 @@ def chatml_function_calling(
         "{% endif %}"
         # Assistant message
         "{% if message.role == 'assistant' %}"
-        ## Reglar message
+        ## Regular message
         "{% if message.content and message.content | length > 0 %}"
         "{% if tool_calls %}"
         "message:\n"
@@ -3465,352 +3616,235 @@
 
     # Convert legacy functions to tools
     if functions is not None:
-        tools = [
-            {
-                "type": "function",
-                "function": function,
-            }
-            for function in functions
-        ]
+        tools = [{"type": "function", "function": function} for function in functions]
 
     # Convert legacy function_call to tool_choice
     if function_call is not None:
-        if isinstance(function_call, str) and (
-            function_call == "none" or function_call == "auto"
-        ):
+        if isinstance(function_call, str) and (function_call in ("none", "auto")):
             tool_choice = function_call
         if isinstance(function_call, dict) and "name" in function_call:
-            tool_choice = {
-                "type": "function",
-                "function": {
-                    "name": function_call["name"],
-                },
-            }
+            tool_choice = {"type": "function", "function": {"name": function_call["name"]}}
 
+    # Collect the llama.create_completion keyword arguments so we don't have to repeat these with
+    # each completion call
     stop = (
         [stop, "<|im_end|>"]
         if isinstance(stop, str)
-        else stop + ["<|im_end|>"] if stop else ["<|im_end|>"]
+        else [*stop, "<|im_end|>"]
+        if stop
+        else ["<|im_end|>"]
     )
+    grammar = (  # It is assumed the grammar applies to messages only, not tool calls
+        grammar
+        if grammar is not None
+        else (
+            _grammar_for_response_format(response_format)
+            if response_format is not None and response_format["type"] == "json_object"
+            else None
+        )
+    )
+    completion_kwargs = {
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "min_p": min_p,
+        "typical_p": typical_p,
+        "stream": stream,
+        "stop": stop,
+        "max_tokens": max_tokens,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
+        "repeat_penalty": repeat_penalty,
+        "tfs_z": tfs_z,
+        "mirostat_mode": mirostat_mode,
+        "mirostat_tau": mirostat_tau,
+        "mirostat_eta": mirostat_eta,
+        "model": model,
+        "logits_processor": logits_processor,
+        "grammar": grammar,
+    }
 
-    # Case 1: No tool choice by user
+    # Case 1: No tool use
     if (
         tool_choice is None
         or (isinstance(tool_choice, str) and tool_choice == "none")
         or tools is None
         or len(tools) == 0
     ):
         prompt = template_renderer.render(
-            messages=messages,
-            tools=[],
-            tool_calls=None,
-            add_generation_prompt=True,
+            messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
         )
-
-        if response_format is not None and response_format["type"] == "json_object":
-            grammar = _grammar_for_response_format(response_format)
-
         return _convert_completion_to_chat(
             llama.create_completion(
                 prompt=prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                min_p=min_p,
-                typical_p=typical_p,
-                stream=stream,
-                stop=stop,
-                max_tokens=max_tokens,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                repeat_penalty=repeat_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                model=model,
-                logits_processor=logits_processor,
-                grammar=grammar,
+                **completion_kwargs,  # type: ignore[arg-type]
                 logprobs=top_logprobs if logprobs else None,
             ),
             stream=stream,
         )
 
-    # Case 2: Tool choice by user
-    if isinstance(tool_choice, dict):
-        tool_name = tool_choice["function"]["name"]
-        tool = next(
-            (tool for tool in tools if tool["function"]["name"] == tool_name), None
-        )
-        if tool is None:
-            raise ValueError(f"Tool with name '{tool_name}' not found in tools")
-        prompt = template_renderer.render(
-            messages=messages,
-            tools=tools,
-            tool_calls=True,
-            add_generation_prompt=True,
-        )
-        prompt += f"functions.{tool_name}:\n"
-        try:
-            grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
-            )
-        except Exception as e:
-            grammar = llama_grammar.LlamaGrammar.from_string(
-                llama_grammar.JSON_GBNF, verbose=llama.verbose
-            )
-            if llama.verbose:
-                print(
-                    "Failed to parse function body as JSON schema, falling back to default grammar"
-                )
-                print(e)
-        completion_or_chunks = llama.create_completion(
-            prompt=prompt,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            min_p=min_p,
-            typical_p=typical_p,
-            stream=stream,
-            stop=stop,
-            max_tokens=max_tokens,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            repeat_penalty=repeat_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        )
-        return _convert_completion_to_chat_function(
-            tool_name, completion_or_chunks, stream
-        )
+    # Ensure there is a system prompt to attach the tool metadata to
+    if not any(message["role"] == "system" for message in messages):
+        messages = [*messages, {"role": "system", "content": ""}]
 
-    # Case 3: Automatic tool choice
-    assert isinstance(tool_choice, str) and tool_choice == "auto"
-    function_names = " | ".join(
-        [f'''"functions.{tool['function']['name']}:"''' for tool in tools]
+    # Case 2: Automatic or fixed tool choice
+    # Case 2 step 1: Determine whether to respond with a message or a tool call
+    assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict)
+    if isinstance(tool_choice, dict):
+        tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]]
+    assert tools
+    function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools])
+    prompt = template_renderer.render(
+        messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True
     )
     initial_gbnf_tool_grammar = (
-        """root ::= functions | "message:"\n"""
-        f"""functions ::= {function_names}\n"""
-    )
-    follow_up_gbnf_tool_grammar = (
-        """root ::= functions | "<|im_end|>"\n"""
-        f"""functions ::= {function_names}\n"""
-    )
-    prompt = template_renderer.render(
-        messages=messages,
-        tools=tools,
-        tool_calls=True,
-        add_generation_prompt=True,
+        (
+            'root ::= "<function_calls>" "\\n" functions | "message:"\n'
+            f"functions ::= {function_names}\n"
+        )
+        if tool_choice == "auto"
+        else f'root ::= "<function_calls>" "\\n" functions\nfunctions ::= {function_names}\n'
     )
-    completion_or_chunks = llama.create_completion(
-        prompt=prompt,
-        temperature=0,
-        top_p=top_p,
-        top_k=top_k,
-        min_p=min_p,
-        typical_p=typical_p,
-        stream=False,
-        stop=[":"],
-        max_tokens=None,
-        presence_penalty=presence_penalty,
-        frequency_penalty=frequency_penalty,
-        repeat_penalty=repeat_penalty,
-        tfs_z=tfs_z,
-        mirostat_mode=mirostat_mode,
-        mirostat_tau=mirostat_tau,
-        mirostat_eta=mirostat_eta,
-        model=model,
-        logits_processor=logits_processor,
-        grammar=llama_grammar.LlamaGrammar.from_string(
-            initial_gbnf_tool_grammar, verbose=llama.verbose
+    completion = cast(
+        llama_types.CreateCompletionResponse,
+        llama.create_completion(
+            prompt=prompt,
+            **{  # type: ignore[arg-type]
+                **completion_kwargs,
+                "temperature": 0,
+                "stream": False,
+                "stop": [":"],
+                "max_tokens": None,
+                "grammar": llama_grammar.LlamaGrammar.from_string(
+                    initial_gbnf_tool_grammar, verbose=llama.verbose
+                ),
+            },
         ),
     )
-    completion: llama_types.CreateCompletionResponse = completion_or_chunks  # type: ignore
     text = completion["choices"][0]["text"]
-    if "message" in text:
+    tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :]
+
+    # Case 2 step 2A: Respond with a message
+    if tool_name is None:
         return _convert_completion_to_chat(
             llama.create_completion(
                 prompt=prompt + "message:\n",
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                min_p=min_p,
-                typical_p=typical_p,
-                stream=stream,
-                stop=["<|im_end|>"],
+                **completion_kwargs,  # type: ignore[arg-type]
                 logprobs=top_logprobs if logprobs else None,
-                max_tokens=None,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                repeat_penalty=repeat_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                model=model,
-                logits_processor=logits_processor,
-                grammar=llama_grammar.LlamaGrammar.from_string(
-                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
-                ),
             ),
             stream=stream,
         )
 
-    # One or more function calls
-    tool_name = text[len("functions.") :]
+    # Case 2 step 2B: One or more function calls
+    follow_up_gbnf_tool_grammar = (
+        'root ::= functions | "</function_calls>" | "<|im_end|>"\n'
+        f"functions ::= {function_names}\n"
+    )
+    prompt += "<function_calls>\n"
+    if stream:
+        return _stream_tool_calls(
+            llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar
+        )
     tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
-    if not stream:
-        completions: List[llama_types.CreateCompletionResponse] = []
-        completions_tool_name: List[str] = []
-        while tool is not None:
-            prompt += f"functions.{tool_name}:\n"
-            try:
-                grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                    json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
-                )
-            except Exception as e:
-                grammar = llama_grammar.LlamaGrammar.from_string(
-                    llama_grammar.JSON_GBNF, verbose=llama.verbose
-                )
-                if llama.verbose:
-                    print(
-                        "Failed to parse function body as JSON schema, falling back to default grammar"
-                    )
-                    print(e)
-            completion_or_chunks = llama.create_completion(
-                prompt=prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                min_p=min_p,
-                typical_p=typical_p,
-                stream=False,
-                stop=stop,
-                max_tokens=None,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                repeat_penalty=repeat_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                model=model,
-                logits_processor=logits_processor,
-                grammar=grammar,
-            )
-            completion_or_chunks = cast(
-                llama_types.CreateCompletionResponse, completion_or_chunks
+    completions: List[llama_types.CreateCompletionResponse] = []
+    completions_tool_name: List[str] = []
+    while tool is not None:
+        # Generate the parameter values for the selected tool
+        prompt += f"functions.{tool_name}:\n"
+        try:
+            grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
            )
-            completions.append(completion_or_chunks)
-            completions_tool_name.append(tool_name)
-            prompt += completion_or_chunks["choices"][0]["text"]
-            prompt += "\n"
-
-            response = llama.create_completion(
-                prompt=prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                min_p=min_p,
-                typical_p=typical_p,
-                stream=False,
-                stop=stop,
-                max_tokens=None,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                repeat_penalty=repeat_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                model=model,
-                logits_processor=logits_processor,
-                grammar=llama_grammar.LlamaGrammar.from_string(
-                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
-                ),
+        except Exception as e:
+            warnings.warn(
+                f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
+                category=RuntimeWarning,
+                stacklevel=2,
             )
-            response = cast(llama_types.CreateCompletionResponse, response)
-
-            tool_name = response["choices"][0]["text"][len("functions.") :]
-            tool = next(
-                (tool for tool in tools if tool["function"]["name"] == tool_name), None
+            grammar = llama_grammar.LlamaGrammar.from_string(
+                llama_grammar.JSON_GBNF, verbose=llama.verbose
             )
-
-        # Merge completions
-        function_call_dict: Union[
-            Dict[str, str],
-            Dict[
-                Literal["function_call"],
-                llama_types.ChatCompletionRequestAssistantMessageFunctionCall,
-            ],
-        ] = (
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            **{  # type: ignore[arg-type]
+                **completion_kwargs,
+                "max_tokens": None,
+                "grammar": grammar,
+            },
+        )
+        completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
+        completions.append(completion)
+        completions_tool_name.append(tool_name)
+        prompt += completion["choices"][0]["text"]
+        prompt += "\n"
+        # Determine whether to call another tool or stop
+        response = cast(
+            llama_types.CreateCompletionResponse,
+            llama.create_completion(
+                prompt=prompt,
+                **{  # type: ignore[arg-type]
+                    **completion_kwargs,
+                    "temperature": 0,
+                    "stream": False,
+                    "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],  # type: ignore[misc]
+                    "max_tokens": None,
+                    "grammar": llama_grammar.LlamaGrammar.from_string(
+                        follow_up_gbnf_tool_grammar, verbose=llama.verbose
+                    ),
+                },
+            ),
+        )
+        tool_name = response["choices"][0]["text"][len("functions.") :]
+        tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+    # Merge the completions into a single chat completion
+    chat_completion: llama_types.CreateChatCompletionResponse = {
+        "id": "chat" + completion["id"],
+        "object": "chat.completion",
+        "created": completion["created"],
+        "model": completion["model"],
+        "choices": [
             {
-                "function_call": {
-                    "name": tool_name,
-                    "arguments": completions[0]["choices"][0]["text"],
-                }
+                "finish_reason": "tool_calls",
+                "index": 0,
+                "logprobs": completion["choices"][0]["logprobs"],
+                "message": {
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [
+                        {
+                            "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"],
+                            "type": "function",
+                            "function": {
+                                "name": tool_name,
+                                "arguments": completion["choices"][0]["text"],
+                            },
+                        }
+                        for i, (tool_name, completion) in enumerate(
+                            zip(completions_tool_name, completions, strict=True)
+                        )
+                    ],
+                },
             }
-            if len(completions) == 1
-            else {}
-        )
-        return {
-            "id": "chat" + completion["id"],
-            "object": "chat.completion",
-            "created": completion["created"],
-            "model": completion["model"],
-            "choices": [
-                {
-                    "finish_reason": "tool_calls",
-                    "index": 0,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
-                    "message": {
-                        "role": "assistant",
-                        "content": None,
-                        "tool_calls": [
-                            {
-                                "id": "call_"
-                                + f"_{i}_"
-                                + tool_name
-                                + "_"
-                                + completion["id"],
-                                "type": "function",
-                                "function": {
-                                    "name": tool_name,
-                                    "arguments": completion["choices"][0]["text"],
-                                },
-                            }
-                            for i, (tool_name, completion) in enumerate(
-                                zip(completions_tool_name, completions)
-                            )
-                        ],
-                        **function_call_dict,
-                    },
-                }
-            ],
-            "usage": {
-                "completion_tokens": sum(
-                    (
-                        completion["usage"]["completion_tokens"]
-                        if "usage" in completion
-                        else 0
-                    )
-                    for completion in completions
-                ),
-                "prompt_tokens": sum(
-                    completion["usage"]["prompt_tokens"] if "usage" in completion else 0
-                    for completion in completions
-                ),
-                "total_tokens": sum(
-                    completion["usage"]["total_tokens"] if "usage" in completion else 0
-                    for completion in completions
-                ),
-            },
+        ],
+        "usage": {
+            "completion_tokens": sum(
+                (completion["usage"]["completion_tokens"] if "usage" in completion else 0)
+                for completion in completions
+            ),
+            "prompt_tokens": sum(
+                completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+                for completion in completions
+            ),
+            "total_tokens": sum(
+                completion["usage"]["total_tokens"] if "usage" in completion else 0
+                for completion in completions
+            ),
+        },
+    }
+    if len(completions) == 1:
+        single_function_call: llama_types.ChatCompletionResponseFunctionCall = {
+            "name": tool_name,
+            "arguments": completions[0]["choices"][0]["text"],
         }
-
-    raise ValueError("Automatic streaming tool choice is not supported")
+        chat_completion["choices"][0]["message"]["function_call"] = single_function_call
+    return chat_completion
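
The streamed chunks produced by the new _stream_tool_calls path can be folded back into complete tool calls on the client side. The sketch below is not part of this commit: it assumes the OpenAI-style delta layout (the function name arrives once, the JSON arguments arrive as text fragments) and keys the accumulation on the "index" field that _stream_tool_calls sets on each delta.

# Hedged sketch (not part of this commit): reassemble streamed tool-call deltas into
# complete calls, keyed by the "index" that _stream_tool_calls assigns per tool call.
from typing import Any, Dict, Iterator, List


def collect_tool_calls(chunks: Iterator[Dict[str, Any]]) -> List[Dict[str, str]]:
    calls: Dict[int, Dict[str, str]] = {}
    for chunk in chunks:
        delta = chunk["choices"][0].get("delta", {})
        for tool_call in delta.get("tool_calls") or []:
            entry = calls.setdefault(tool_call["index"], {"name": "", "arguments": ""})
            function = tool_call.get("function") or {}
            if function.get("name"):
                entry["name"] = function["name"]
            entry["arguments"] += function.get("arguments") or ""
    return [calls[index] for index in sorted(calls)]


# Usage: tool_calls = collect_tool_calls(llm.create_chat_completion(..., stream=True))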

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -45,7 +45,8 @@ test = [
     "sse-starlette>=1.6.1",
     "starlette-context>=0.3.6,<0.4",
     "pydantic-settings>=2.0.1",
-    "huggingface-hub>=0.23.0"
+    "huggingface-hub>=0.23.0",
+    "typeguard>=4.2.1",
 ]
 dev = [
     "black>=23.3.0",

tests/test_llama_chat_format.py

Lines changed: 128 additions & 3 deletions
@@ -1,14 +1,28 @@
 import json
+import os
+from collections.abc import Iterator
+from typing import cast
 
+import pytest
 import jinja2
+from typeguard import ForwardRefPolicy, check_type
 
 from llama_cpp import (
     ChatCompletionRequestUserMessage,
+    Llama,
+    llama_chat_format,
+    llama_supports_gpu_offload,
+    llama_types
 )
-import llama_cpp.llama_types as llama_types
-import llama_cpp.llama_chat_format as llama_chat_format
-
 from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
+from llama_cpp.llama_types import (
+    ChatCompletionRequestMessage,
+    ChatCompletionTool,
+    ChatCompletionToolChoiceOption,
+    CreateChatCompletionResponse,
+    CreateChatCompletionStreamResponse,
+)
+
 
 def test_mistral_instruct():
     chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
@@ -87,3 +101,114 @@ def test_hf_tokenizer_config_str_to_chat_formatter():
     )
 
     assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>" "")
+
+
+def is_accelerator_available() -> bool:
+    """Check if an accelerator is available."""
+    return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8
+
+
+@pytest.mark.parametrize(
+    "stream",
+    [
+        pytest.param(True, id="stream=True"),
+        pytest.param(False, id="stream=False"),
+    ],
+)
+@pytest.mark.parametrize(
+    "tool_choice",
+    [
+        pytest.param("none", id="tool_choice=none"),
+        pytest.param("auto", id="tool_choice=auto"),
+        pytest.param(
+            {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed"
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "user_prompt_expected_tool_calls",
+    [
+        pytest.param(
+            ("Is 7 a prime number?", 0),
+            id="expected_tool_calls=0",
+        ),
+        pytest.param(
+            ("What's the weather like in Paris today?", 1),
+            id="expected_tool_calls=1",
+        ),
+        pytest.param(
+            ("What's the weather like in Paris today? What about New York?", 2),
+            id="expected_tool_calls=2",
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "llm_repo_id",
+    [
+        pytest.param("bartowski/Llama-3.2-3B-Instruct-GGUF", id="llama_3.2_3B"),
+        pytest.param(
+            "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
+            id="llama_3.1_8B",
+            marks=pytest.mark.skipif(
+                not is_accelerator_available(), reason="Accelerator not available"
+            ),
+        ),
+    ],
+)
+def test_llama_cpp_python_tool_use(
+    llm_repo_id: str,
+    user_prompt_expected_tool_calls: tuple[str, int],
+    tool_choice: ChatCompletionToolChoiceOption,
+    stream: bool,
+) -> None:
+    """Test the upgraded chatml-function-calling llama-cpp-python chat handler."""
+    user_prompt, expected_tool_calls = user_prompt_expected_tool_calls
+    if isinstance(tool_choice, dict) and expected_tool_calls == 0:
+        pytest.skip("Nonsensical")
+    llm = Llama.from_pretrained(
+        repo_id=llm_repo_id,
+        filename="*Q4_K_M.gguf",
+        n_ctx=4096,
+        n_gpu_layers=-1,
+        verbose=False,
+        chat_format="chatml-function-calling",
+    )
+    messages: list[ChatCompletionRequestMessage] = [{"role": "user", "content": user_prompt}]
+    tools: list[ChatCompletionTool] = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get the weather for a location.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string", "description": "A city name."}},
+                },
+            },
+        }
+    ]
+    response = llm.create_chat_completion(
+        messages=messages, tools=tools, tool_choice=tool_choice, stream=stream
+    )
+    if stream:
+        response = cast(Iterator[CreateChatCompletionStreamResponse], response)
+        num_tool_calls = 0
+        for chunk in response:
+            check_type(chunk, CreateChatCompletionStreamResponse)
+            tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
+            if isinstance(tool_calls, list):
+                num_tool_calls = max(tool_call["index"] for tool_call in tool_calls) + 1
+        assert num_tool_calls == (expected_tool_calls if tool_choice != "none" else 0)
+    else:
+        response = cast(CreateChatCompletionResponse, response)
+        check_type(
+            response, CreateChatCompletionResponse, forward_ref_policy=ForwardRefPolicy.IGNORE
+        )
+        if expected_tool_calls == 0 or tool_choice == "none":
+            assert response["choices"][0]["message"].get("tool_calls") is None
+        else:
+            assert len(response["choices"][0]["message"]["tool_calls"]) == expected_tool_calls
+            assert all(
+                tool_call["function"]["name"] == tools[0]["function"]["name"]
+                for tool_call in response["choices"][0]["message"]["tool_calls"]
+            )
