Commit a54ed80
[Model] Add mistral function calling format to all models loaded with "mistral" format (vllm-project#8515)
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 9855b99 commit a54ed80

File tree: 5 files changed, +219 / -9 lines

examples/offline_chat_with_tools.py (new file, +138 lines)

# ruff: noqa
import json
import random
import string

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for function calling
#
# If you want to run a server/client setup instead, use the following:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
#     "model": "mistralai/Mistral-7B-Instruct-v0.3",
#     "messages": [
#       {
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "Describe this image in detail please."},
#             {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
#             {"type": "text", "text": "and this one as well. Answer in French."},
#             {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
#         ]
#       }
#     ]
#   }'
# ```
#
# Usage:
#     python offline_chat_with_tools.py

model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
# or "mistralai/Mistral-Large-Instruct-2407"
# or any other mistral model with function calling ability

sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
llm = LLM(model=model_name,
          tokenizer_mode="mistral",
          config_format="mistral",
          load_format="mistral")


def generate_random_id(length=9):
    characters = string.ascii_letters + string.digits
    random_id = ''.join(random.choice(characters) for _ in range(length))
    return random_id


# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: str):
    return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
            "partly cloudy, with highs in the 90's.")


tool_functions = {"get_current_weather": get_current_weather}

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description":
                    "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type": "string",
                    "description":
                    "the two-letter abbreviation for the state that the city is"
                    " in, e.g. 'CA' which would mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]

messages = [{
    "role": "user",
    "content":
    "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
}]

outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
output = outputs[0].outputs[0].text.strip()

# append the assistant message
messages.append({
    "role": "assistant",
    "content": output,
})

# now parse and execute the model's output, simulating an API call with the
# function defined above
tool_calls = json.loads(output)
tool_answers = [
    tool_functions[call['name']](**call['arguments']) for call in tool_calls
]

# append the answer as a tool message and let the LLM give you an answer
messages.append({
    "role": "tool",
    "content": "\n\n".join(tool_answers),
    "tool_call_id": generate_random_id(),
})

outputs = llm.chat(messages, sampling_params, tools=tools)

print(outputs[0].outputs[0].text.strip())
# yields
#   'The weather in Dallas, TX is 85 degrees fahrenheit. '
#   'It is partly cloudy, with highs in the 90's.'

tests/models/decoder_only/language/test_mistral.py (+67 lines)

@@ -4,13 +4,61 @@
 """
 import pytest
 
+from vllm import SamplingParams
+
 from ...utils import check_logprobs_close
 
 MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.1",
     "mistralai/Mistral-7B-Instruct-v0.3",
+    # Mistral-Nemo is too big for CI, but passes locally
+    # "mistralai/Mistral-Nemo-Instruct-2407"
 ]
 
+SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+
+# for function calling
+TOOLS = [{
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {
+                    "type": "string",
+                    "description":
+                    "The city to find the weather for, e.g. 'San Francisco'"
+                },
+                "state": {
+                    "type": "string",
+                    "description":
+                    "the two-letter abbreviation for the state that the city is"
+                    " in, e.g. 'CA' which would mean 'California'"
+                },
+                "unit": {
+                    "type": "string",
+                    "description": "The unit to fetch the temperature in",
+                    "enum": ["celsius", "fahrenheit"]
+                }
+            },
+            "required": ["city", "state", "unit"]
+        }
+    }
+}]
+MSGS = [{
+    "role": "user",
+    "content": ("Can you tell me what the temperature"
+                " will be in Dallas, in fahrenheit?")
+}]
+EXPECTED_FUNC_CALL = (
+    '[{"name": "get_current_weather", "arguments": '
+    '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
+
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])

@@ -81,3 +129,22 @@ def test_mistral_format(
         name_0="hf",
         name_1="mistral",
     )
+
+
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("model", MODELS[1:])  # v1 can't do func calling
+def test_mistral_function_calling(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tokenizer_mode="mistral",
+                     config_format="mistral",
+                     load_format="mistral") as vllm_model:
+        outputs = vllm_model.model.chat(MSGS,
+                                        tools=TOOLS,
+                                        sampling_params=SAMPLING_PARAMS)
+
+        assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
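
The expected completion is the raw Mistral function-calling format: a JSON array of calls, which is exactly what the offline example above hands to json.loads. A standalone sanity check of that shape (not part of the test file):

import json

# EXPECTED_FUNC_CALL as defined in the test above
expected = ('[{"name": "get_current_weather", "arguments": '
            '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')

calls = json.loads(expected)
assert calls[0]["name"] == "get_current_weather"
assert calls[0]["arguments"] == {
    "city": "Dallas",
    "state": "TX",
    "unit": "fahrenheit"
}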

vllm/entrypoints/llm.py (+5, -1 lines)

@@ -1,5 +1,6 @@
 from contextlib import contextmanager
-from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
+                    overload)
 
 from tqdm import tqdm
 
@@ -357,6 +358,7 @@ def chat(
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
+        tools: Optional[List[Dict[str, Any]]] = None,
     ) -> List[RequestOutput]:
         """
         Generate responses for a chat conversation.
@@ -401,13 +403,15 @@ def chat(
                 messages=messages,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                tools=tools,
             )
         else:
             prompt = apply_hf_chat_template(
                 tokenizer,
                 conversation=conversation,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                tools=tools,
             )
 
         inputs: PromptInputs
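
The new tools argument is forwarded to both apply_mistral_chat_template and apply_hf_chat_template, so it is not limited to the mistral path. A rough sketch of the HF path; the model name is illustrative only and assumes its chat template knows how to render tools, which is not something this commit guarantees:

from vllm import LLM
from vllm.sampling_params import SamplingParams

# Hypothetical choice: any HF-format model whose chat template accepts tools.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city", "unit"],
        },
    },
}]

messages = [{
    "role": "user",
    "content": "What will the temperature be in Dallas, in fahrenheit?",
}]

# tools is passed through to the underlying chat template; how the model
# formats its tool call depends on that template, not on vLLM itself.
outputs = llm.chat(messages,
                   sampling_params=SamplingParams(max_tokens=256,
                                                  temperature=0.0),
                   tools=tools)
print(outputs[0].outputs[0].text)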

vllm/entrypoints/openai/serving_chat.py (+5, -4 lines)

@@ -123,7 +123,8 @@ async def create_chat_completion(
             ]
 
             prompt: Union[str, List[int]]
-            if isinstance(tokenizer, MistralTokenizer):
+            is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
+            if is_mistral_tokenizer:
                 prompt = apply_mistral_chat_template(
                     tokenizer,
                     messages=request.messages,
@@ -159,10 +160,10 @@ async def create_chat_completion(
                 return self.create_error_response(
                     "tool_choice = \"required\" is not supported!")
 
-            # "auto" tools requires --enable-auto-tool-choice
-            # and --tool-call-parser
-            if request.tool_choice == "auto" and not (
+            if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
                     self.enable_auto_tools and self.tool_parser is not None):
+                # for hf tokenizers, "auto" tools requires
+                # --enable-auto-tool-choice and --tool-call-parser
                 return self.create_error_response(
                     "\"auto\" tool choice requires "
                     "--enable-auto-tool-choice and --tool-call-parser to be set")

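With this guard relaxed, a server running a Mistral tokenizer (started as in the example's comments, with --tokenizer-mode mistral) no longer rejects tool_choice="auto" requests up front when --enable-auto-tool-choice and --tool-call-parser are unset. A minimal client sketch, assuming the openai Python package and a local server on port 8000 (base URL and API key are placeholders):

from openai import OpenAI

# Assumes: vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral \
#              --load-format mistral --config-format mistral
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city", "state", "unit"],
        },
    },
}]

response = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    messages=[{
        "role": "user",
        "content": "Can you tell me what the temperature will be in Dallas, "
                   "in fahrenheit?",
    }],
    tools=tools,
    tool_choice="auto",
)
# The exact response shape depends on the server-side tool handling;
# here we just print whatever message comes back.
print(response.choices[0].message)
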
vllm/transformers_utils/tokenizers/mistral.py (+4, -4 lines)

@@ -165,18 +165,18 @@ def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
                             **kwargs) -> List[int]:
-        assert tools is None, "`tools` are not yet supported."
 
-        request = ChatCompletionRequest(
-            messages=messages)  # type: ignore[type-var]
+        request = ChatCompletionRequest(messages=messages,
+                                        tools=tools)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)
 
         # encode-decode to get clean prompt
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         if isinstance(self.tokenizer, Tekkenizer):
-            return "".join(tokens)
+            return "".join(t for t in tokens
+                           if t not in self.tokenizer._all_special_tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
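
A rough sketch of exercising the updated method directly. MistralTokenizer.from_pretrained and acceptance of the OpenAI-style tool dict by mistral_common's ChatCompletionRequest are assumptions here, not shown in this diff:

from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

# Assumed constructor; the repo id matches the example above.
tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
messages = [{"role": "user", "content": "What is the weather in Dallas?"}]

# With this change, tools are forwarded into ChatCompletionRequest instead of
# being rejected by the old assert; the call returns prompt token ids.
token_ids = tokenizer.apply_chat_template(messages, tools=tools)
print(len(token_ids))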

0 commit comments