Skip to content

Commit 29de071

Browse files
committed
fix: only send function name on first stream event
1 parent 3b09662 commit 29de071

File tree

7 files changed

+48
-12
lines changed

7 files changed

+48
-12
lines changed

integration-tests/models/__snapshots__/test_openai_llama_tools/test_openai_llama_tools.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"logprobs": null
2121
}
2222
],
23 -     "created": 1739799458,
23 +     "created": 1739910558,
2424
"id": "",
2525
"model": "meta-llama/Llama-3.1-8B-Instruct",
2626
"object": "chat.completion.chunk",

integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"logprobs": null
2121
}
2222
],
23 -     "created": 1739797595,
23 +     "created": 1739910826,
2424
"id": "",
2525
"model": "meta-llama/Llama-3.1-8B-Instruct",
2626
"object": "chat.completion.chunk",

integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"logprobs": null
2222
}
2323
],
24 -     "created": 1739456930,
24 +     "created": 1739910816,
2525
"id": "",
2626
"model": "meta-llama/Llama-3.1-8B-Instruct",
2727
"object": "chat.completion.chunk",

integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"logprobs": null
2222
}
2323
],
24 -     "created": 1739367874,
24 +     "created": 1739910803,
2525
"id": "",
2626
"model": "meta-llama/Llama-3.1-8B-Instruct",
2727
"object": "chat.completion.chunk",

integration-tests/models/test_openai_llama_tools.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ async def test_openai_llama_tools(openai_llama_tools, response_snapshot):
101101

102102
tool_call_string = ""
103103
for chunk in chat_completion:
104 -         tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments
104 +         function_call = chunk.choices[0].delta.tool_calls[0].function
105 +         if function_call:
106 +             tool_call_string += function_call.arguments
105107
last_chunk = chunk.to_dict()
106108

107109
assert (

integration-tests/models/test_tools_llama.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ async def test_flash_llama_grammar_tools_stream(
216216
assert response.choices[0].delta.content is None
217217

218218
assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
219 -     assert count == 16
219 +     assert count == 17
220220
assert last_response == response_snapshot
221221

222222

@@ -360,7 +360,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
360360
)
361361
last_response = response
362362

363 -     assert count == 23
363 +     assert count == 24
364364
assert (
365365
tool_calls_generated
366366
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
@@ -458,7 +458,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
458458
tool_calls_generated += tool_call["function"]["arguments"]
459459
last_response = response
460460

461 -     assert count == 25
461 +     assert count == 26
462462
assert (
463463
tool_calls_generated
464464
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'

router/src/server.rs

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,6 +1305,40 @@ pub(crate) async fn chat_completions(
13051305
state = StreamState::Content {
13061306
skip_close_quote: false,
13071307
};
1308 +     let event = Event::default();
1309 +     let current_time = std::time::SystemTime::now()
1310 +         .duration_since(std::time::UNIX_EPOCH)
1311 +         .unwrap_or_else(|_| std::time::Duration::from_secs(0))
1312 +         .as_secs();
1313 +     let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta {
1314 +         role: "assistant".to_string(),
1315 +         tool_calls: vec![DeltaToolCall {
1316 +             index: 0,
1317 +             id: String::new(),
1318 +             r#type: "function".to_string(),
1319 +             function: Function {
1320 +                 name: Some(global_function_name.clone()),
1321 +                 arguments: "".to_string(),
1322 +             },
1323 +         }],
1324 +     });
1325 +     let chat_complete =
1326 +         CompletionType::ChatCompletionChunk(ChatCompletionChunk{
1327 +             id: String::new(),
1328 +             created: current_time,
1329 +             model: model_id.clone(),
1330 +             system_fingerprint: system_fingerprint.clone(),
1331 +             choices: vec![ChatCompletionChoice {
1332 +                 index: 0,
1333 +                 delta: tool_delta_start,
1334 +                 logprobs: None,
1335 +                 finish_reason: None,
1336 +             }],
1337 +             usage: None,
1338 +         });
1339 +     yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
1340 +         InferError::StreamSerializationError(e.to_string()).into()
1341 +     }));
13081342
buffer.drain(1..); // only keep the first token (opening '{')
13091343
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
13101344
}
@@ -1341,7 +1375,7 @@ pub(crate) async fn chat_completions(
13411375
None,
13421376
None,
13431377
None,
1344 -         Some(global_function_name.clone()),
1378 +         None,
13451379
));
13461380
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
13471381
InferError::StreamSerializationError(e.to_string()).into()
@@ -1370,7 +1404,7 @@ pub(crate) async fn chat_completions(
13701404
response_as_tool,
13711405
system_fingerprint.clone(),
13721406
model_id.clone(),
1373 -         Some(global_function_name.clone()),
1407 +         None,
13741408
);
13751409

13761410
yield Ok::<Event, Infallible>(event);
@@ -1394,7 +1428,7 @@ pub(crate) async fn chat_completions(
13941428
response_as_tool,
13951429
system_fingerprint.clone(),
13961430
model_id.clone(),
1397 -         Some(global_function_name.clone()),
1431 +         None,
13981432
);
13991433
yield Ok::<Event, Infallible>(event);
14001434
} else {
@@ -1407,7 +1441,7 @@ pub(crate) async fn chat_completions(
14071441
response_as_tool,
14081442
system_fingerprint.clone(),
14091443
model_id.clone(),
1410 -         Some(global_function_name.clone()),
1444 +         None,
14111445
);
14121446
yield Ok::<Event, Infallible>(event);
14131447
}

0 commit comments

Comments (0)