[Frontend] Adding the "User Defined Custom Tool Calling" parser for the Llama models #12752

Open · wants to merge 2 commits into base: main
131 changes: 131 additions & 0 deletions examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja
@@ -0,0 +1,131 @@
{{- bos_token }}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = false %}
{%- endif %}
{%- if not date_string is defined %}
{%- set date_string = "26 Jul 2024" %}
{%- endif %}

{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}

{#- System message + builtin tools #}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if builtin_tools is defined or tools is not none %}
{{- "Environment: ipython\n" }}
{%- endif %}
{%- if builtin_tools is defined %}
{{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- "Today Date: " + date_string + "\n\n" }}

{%- if builtin_tools is defined %}
{{- "# Tool Instructions\n"}}
{{- "- Always execute python code in messages that you share.\n"}}
{{- "- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n"}}
{%- endif %}

{%- if tools is not none and not tools_in_user_message %}
{{- "You have access to the following functions:\n\n"}}

{%- for t in tools %}
{%- if t.function is defined %}
{%- set t = t.function %}
{%- endif -%}
{{- "Use the function '"+t.name+"' to: "+t.description+"\n"}}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- "If a you choose to call a function ONLY reply in the following format:\n"}}
{{- "<{start_tag}={function_name}>{parameters}{end_tag}\n" }}
{{- "where\n\n"}}
{{- "start_tag => `<function`\n" }}
{{- "parameters => a JSON dict with the function argument name as key and function argument value as value.\n"}}
{{- "end_tag => `</function>`" }}
{{- "\n\n" }}
{{- "Here is an example,\n"}}
{{- "<function=example_function_name>{\"example_name\": \"example_value\"}</function>"}}
{{- "\n\n" }}
{{- "Reminder:\n"}}
{{- "- Function calls MUST follow the specified format\n"}}
{{- "- Required parameters MUST be specified\n"}}
{{- "- Only call one function at a time\n"}}
{{- "- Put the entire function call reply on one line\n"}}
{{- "- Always use the information returned by the function to answer to the user\n"}}
{{- "- If there is no relevant function available, do NOT call any function: respond directly to the user\n\n"}}

{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}

{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- set first_user_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
{{- "Given the following functions, please respond with a JSON for a function call " }}
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson }}
{{- "\n\n" }}
{%- endfor %}
{{- first_user_message + "<|eot_id|>"}}
{%- endif %}

{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
{%- elif 'tool_calls' in message %}
{%- if not message.tool_calls|length == 1 %}
{{- raise_exception("This model only supports single tool-calls at once!") }}
{%- endif %}
{%- set tool_call = message.tool_calls[0].function %}
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{{- "<|python_tag|>" + tool_call.name + ".call(" }}
{%- for arg_name, arg_val in tool_call.arguments | items %}
{{- arg_name + '="' + arg_val + '"' }}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- ")" }}
{%- else %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{{- '<function=' + tool_call.name + '>' + tool_call.arguments + '</function>'}}
{%- endif %}
{%- if builtin_tools is defined or tools is not none%}
{#- This means we're in ipython mode #}
{{- "<|eom_id|>" }}
{%- else %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
{%- if message.content is mapping or message.content is iterable %}
{{- message.content | tojson }}
{%- else %}
{{- message.content }}
{%- endif %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
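For a quick offline check, the template can be rendered with the Hugging Face tokenizer before serving. A minimal sketch; the get_weather tool, its arguments, and the user prompt are illustrative, not part of this PR:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
with open("examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja") as f:
    template = f.read()

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool for illustration
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
messages = [{"role": "user", "content": "What is the weather in Paris?"}]

# Passing chat_template= overrides the tokenizer's built-in template.
prompt = tokenizer.apply_chat_template(messages, tools=tools,
                                       chat_template=template,
                                       add_generation_prompt=True,
                                       tokenize=False)
print(prompt)

A compliant completion would then look like <function=get_weather>{"city": "Paris"}</function>, which is exactly the shape the parser below is written to recognize.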
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -9,10 +9,11 @@
from .llama_tool_parser import Llama3JsonToolParser
from .llama_usr_defined_tool_parser import Llama3UserDefinedCustomToolParser
from .mistral_tool_parser import MistralToolParser
from .pythonic_tool_parser import PythonicToolParser

__all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"PythonicToolParser"
"PythonicToolParser", "Llama3UserDefinedCustomToolParser"
]
248 changes: 248 additions & 0 deletions vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0

import json
import re
from typing import Dict, List, Sequence, Union

import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

def _count_substring(string, substring):
    """
    Counts the number of non-overlapping occurrences of a substring
    in a string.

    Args:
        string (str): The string to search in.
        substring (str): The substring to search for.

    Returns:
        int: The number of non-overlapping occurrences of the
            substring in the string.
    """
    count = 0
    start = 0
    while True:
        start = string.find(substring, start)
        if start == -1:
            break
        count += 1
        start += len(substring)
    return count
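
# Note: for non-overlapping occurrences this helper matches the built-in
# str.count, e.g. _count_substring("abcabc", "abc") == 2 ==
# "abcabc".count("abc").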

@ToolParserManager.register_module("llama3_user_defined_custom")
class Llama3UserDefinedCustomToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        if isinstance(self.model_tokenizer, MistralTokenizer):
            logger.error("Detected Mistral tokenizer when using a Llama model")
            self.model_tokenizer = self.model_tokenizer.tokenizer

        self.prev_tool_call_arr: List[Dict] = []
        self.streamed_args_for_tool: List[str] = []
        self.is_parsing_toolcall = False

        self.nb_tool_calls = 0
        self.current_tool_name = ""
        self.current_tool_call_uuid = ""
        self.is_current_tool_name_sent = False
        self.tool_call_start_token: str = "<function"
        self.tool_call_precall_token: str = '>{"'
        self.tool_call_end_token: str = "</function>"
        self.bot_token = "<|python_tag|>"

        self.tool_call_start_token_id = tokenizer.encode(
            self.tool_call_start_token, add_special_tokens=False)
        self.tool_call_end_token_id = tokenizer.encode(
            self.tool_call_end_token, add_special_tokens=False)
        self.tool_call_preargs_token_id = tokenizer.encode(
            self.tool_call_precall_token, add_special_tokens=False)
        self.bot_token_id = tokenizer.encode(self.bot_token,
                                             add_special_tokens=False)

        self.tool_call_regex = re.compile(
            r"<function=([^>]+)>\{([^}]+)\}(?:</function>|>)?")
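        # Illustration (hypothetical values): on the canonical output
        #   '<function=get_weather>{"city": "Paris"}</function>'
        # findall() yields [('get_weather', '"city": "Paris"')]; the braces
        # are re-added around the second capture before json.loads below.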

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:

        # sanity check; avoid unnecessary processing
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        else:
            try:
                # each match is a (function name, argument body) tuple; the
                # argument body is captured without its surrounding braces
                function_call_tuples = self.tool_call_regex.findall(
                    model_output)

                logger.info("function_call_tuples: %s", function_call_tuples)

                # load the JSON, and then use it to build the Function and
                # Tool Call
                raw_function_calls = [{
                    "name": match[0],
                    "arguments": json.loads("{" + match[1] + "}")
                } for match in function_call_tuples]
                tool_calls = [
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # function call args are JSON but as a string
                            arguments=json.dumps(function_call["arguments"],
                                                 ensure_ascii=False)))
                    for function_call in raw_function_calls
                ]

                content = model_output[:model_output.
                                       find(self.tool_call_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None)

            except Exception:
                logger.exception(
                    "Error in extracting tool call from response.")
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)
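
    # Illustration (hypothetical values): given
    #   model_output = 'Sure.<function=get_weather>{"city": "Paris"}</function>'
    # the method above returns tools_called=True, one ToolCall with
    # function.name "get_weather" and arguments '{"city": "Paris"}', and
    # content "Sure." (the text before the first "<function").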


    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Extract tool calls from a streaming response.
        Handles the format: <function=function_name>{arguments}</function>
        Returns a DeltaMessage with either tool_calls or content.
        """
logger.debug("\n" + "="*50)
logger.debug("STREAMING FUNCTION CALLED")

Check failure on line 161 in vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G003)

vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py:161:22: G003 Logging statement uses `+`
logger.debug("Tool call start token id IDs:", self.tool_call_start_token_id)
logger.debug("Tool call precall token id IDs:", self.tool_call_preargs_token_id)
logger.debug("Tool call end token id IDs:", self.tool_call_end_token_id)
logger.debug("Previous text:", previous_text)
logger.debug("Current text:", current_text)
logger.debug("Delta text:", delta_text)
logger.debug("Previous token IDs:", previous_token_ids)
logger.debug("Current token IDs:", current_token_ids)
logger.debug("Delta token IDs:", delta_token_ids)
logger.debug("Current tool name sent:", self.is_current_tool_name_sent)
logger.debug("-"*50 + "\n")
flags = Allow.ALL if self.is_current_tool_name_sent \
else Allow.ALL & ~Allow.STR

logger.debug(f"{delta_token_ids[0] in self.tool_call_start_token_id=}")

        if delta_token_ids[0] in self.tool_call_start_token_id:
            # We possibly have a tool call (not sure yet); don't stream yet

            logger.debug(
                "start-token occurrences in current text: %s",
                _count_substring(current_text, self.tool_call_start_token))

            if _count_substring(current_text, self.tool_call_start_token) \
                    > self.nb_tool_calls and not self.is_parsing_toolcall:

                self.is_parsing_toolcall = True
                self.nb_tool_calls += 1  # will serve as the tool-call index

                self.current_tool_call_uuid = random_uuid()
                logger.debug("New tool call detected, id: %s",
                             self.nb_tool_calls - 1)

                return None  # going to the next iteration
            else:
                logger.debug("Tool call already parsed, id: %s",
                             self.nb_tool_calls - 1)

        if self.is_parsing_toolcall and not self.is_current_tool_name_sent:
            logger.debug("Parsing tool call, id: %s", self.nb_tool_calls - 1)
            # We are parsing a tool call; we need to parse the tool name first
            if delta_token_ids != self.tool_call_preargs_token_id:
                self.current_tool_name += delta_text
                logger.debug("current_tool_name: %s", self.current_tool_name)
                return None  # moving on to the next iteration
            else:
                # drop the '=' accumulated right after "<function"
                self.current_tool_name = self.current_tool_name.lstrip('=')
                self.is_current_tool_name_sent = True
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.nb_tool_calls - 1,
                        type="function",
                        id=f"chatcmpl-tool-{self.current_tool_call_uuid}",
                        function=DeltaFunctionCall(
                            name=self.current_tool_name))
                ])

        if self.is_current_tool_name_sent:
            logger.debug("Parsed tool name: %s", self.current_tool_name)

            if _count_substring(
                    current_text,
                    self.tool_call_end_token) < self.nb_tool_calls:
                self.streamed_args_for_tool.append(delta_text)
                return None  # moving on to the next iteration
            else:
                # re-add the leading '{"' (it was consumed together with the
                # '>{"' marker) so the accumulated string is valid JSON
                arguments = '{"' + ''.join(self.streamed_args_for_tool)
                # removesuffix() drops the end token if present; rstrip()
                # would strip *characters*, not the suffix, and could eat
                # legitimate trailing text
                arguments = arguments.removesuffix(self.tool_call_end_token)
                logger.debug("Concatenated tool call arguments: %s", arguments)

                current_tool_args = partial_json_parser.loads(
                    arguments or "{}",
                    flags) if self.streamed_args_for_tool else None

                logger.debug("Parsed tool call arguments: %s",
                             current_tool_args)

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.nb_tool_calls - 1,
                        type="function",
                        id=f"chatcmpl-tool-{self.current_tool_call_uuid}",
                        function=DeltaFunctionCall(
                            name=self.current_tool_name,
                            arguments=json.dumps(current_tool_args)))
                ])

                self.reset_state()

                return delta
        else:
            logger.debug("No tool call detected; returning delta text: %s",
                         delta_text)
            return DeltaMessage(content=delta_text)

    def reset_state(self):
        self.current_tool_name = ''
        self.is_parsing_toolcall = False
        self.is_current_tool_name_sent = False
        self.streamed_args_for_tool = []
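
With the parser registered under the name above, serving would use the existing vLLM tool-calling flags; a sketch of the invocation (the model name is illustrative, while the parser name and template path come from this PR):

vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --enable-auto-tool-choice \
    --tool-call-parser llama3_user_defined_custom \
    --chat-template examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja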