Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General fixes for tool calling #2954

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1508,7 +1508,26 @@
}
}
},
"FunctionDefinition": {
"FunctionCall": {
"type": "object",
"required": [
"name",
"arguments"
],
"properties": {
"arguments": {
"type": "string"
},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
}
},
"FunctionDefinitionDeprecated": {
"type": "object",
"required": [
"name",
Expand All @@ -1525,6 +1544,23 @@
}
}
},
"FunctionDefinition": {
"type": "object",
"required": [
"name",
"parameters"
],
"properties": {
"parameters": {},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
}
},
"FunctionName": {
"type": "object",
"required": [
Expand Down Expand Up @@ -2227,7 +2263,14 @@
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionDefinition"
"oneOf": [
{
"$ref": "#/components/schemas/FunctionDefinition"
},
{
"$ref": "#/components/schemas/FunctionDefinitionDeprecated"
}
]
},
"type": {
"type": "string",
Expand All @@ -2244,7 +2287,7 @@
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionDefinition"
"$ref": "#/components/schemas/FunctionCall"
},
"id": {
"type": "string"
Expand Down Expand Up @@ -2370,4 +2413,4 @@
"description": "Hugging Face Text Generation Inference API"
}
]
}
}
2 changes: 1 addition & 1 deletion docs/source/basic_tutorials/using_guidance.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ chat = client.chat_completion(
)

print(chat.choices[0].message.tool_calls)
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')]
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')]

```

Expand Down
41 changes: 39 additions & 2 deletions router/src/infer/chat_template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,44 @@ mod tests {
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}

#[test]
fn test_chat_template_with_default_tool_template_arguments_deprecated() {
let ct = ChatTemplate::new(
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);

// convert TextMessage to Message
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"I'd like to show off how chat templating works!".to_string(),
),
},
Message {
name: None,
role: "assistant".to_string(),
content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
},
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText("Just testing".to_string()),
},
];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","arguments": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}

Expand Down Expand Up @@ -937,7 +974,7 @@ mod tests {
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
}
4 changes: 2 additions & 2 deletions router/src/infer/tool_grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl ToolGrammar {
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
arguments: json!({
parameters: json!({
"type": "object",
"properties": {
"content": {
Expand Down Expand Up @@ -83,7 +83,7 @@ impl ToolGrammar {
}),
);

if let Value::Object(args) = func.arguments {
if let Value::Object(args) = func.parameters {
if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone());
}
Expand Down
37 changes: 20 additions & 17 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ pub(crate) struct ChatCompletionChoice {
pub struct ToolCallDelta {
#[schema(example = "assistant")]
role: String,
tool_calls: DeltaToolCall,
tool_calls: Vec<DeltaToolCall>,
}

#[derive(Clone, Debug, Serialize, ToSchema)]
Expand All @@ -742,11 +742,11 @@ pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
pub r#type: String,
pub function: Function,
pub function: FunctionCallChunk,
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct Function {
pub(crate) struct FunctionCallChunk {
pub name: Option<String>,
pub arguments: String,
}
Expand All @@ -757,7 +757,7 @@ impl ChatCompletionChunk {
model: String,
system_fingerprint: String,
delta: Option<String>,
tool_calls: Option<Vec<String>>,
tool_calls: Option<FunctionCallChunk>,
created: u64,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
Expand All @@ -770,15 +770,12 @@ impl ChatCompletionChunk {
}),
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: DeltaToolCall {
tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tool_calls[0].to_string(),
},
},
function: tool_calls,
}],
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
Expand Down Expand Up @@ -1133,8 +1130,14 @@ pub(crate) struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
#[serde(alias = "parameters")]
pub arguments: serde_json::Value,
#[serde(alias = "arguments")]
pub parameters: serde_json::Value,
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionCall {
pub name: String,
pub arguments: String,
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
Expand All @@ -1160,7 +1163,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
pub(crate) struct ToolCall {
pub id: String,
pub r#type: String,
pub function: FunctionDefinition,
pub function: FunctionCall,
}

#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
Expand Down Expand Up @@ -1679,19 +1682,19 @@ mod tests {
tool_calls: vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
function: FunctionCall {
name: "myfn".to_string(),
arguments: json!({
"format": "csv"
}),
})
.to_string(),
},
}],
});
let serialized = serde_json::to_string(&message).unwrap();
assert_eq!(
serialized,
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"#
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
);
}

Expand Down
Loading