-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] Generate valid tool call IDs when using tokenizer-mode=mistral
#12332
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Rafael Vasquez <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for looking at this @rafvasq! I know it is still a WIP, but I left some suggestions on the current state of the PR.
vllm/utils.py
Outdated
@@ -2206,3 +2211,8 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any], | |||
else: | |||
func = partial(method, obj) # type: ignore | |||
return func(*args, **kwargs) | |||
|
|||
def generate_valid_mistral_tool_id(): | |||
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mistral requirement is exactly 9 characters.
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. | |
# Mistral Tool Call Ids must be alphanumeric with a length of 9. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed this comment in a couple of places.
if isinstance(tokenizer, MistralTokenizer): | ||
for tool_call in message.tool_calls: | ||
tool_call.id = generate_valid_mistral_tool_id() | ||
logger.warning(f"Assigned new tool_id: {tool_call.id} for tool: {tool_call}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the tool_choice: auto
case, the MistralTokenizer should be generating valid tool ids itself and not need this to override it. So I think this logic can be moved to the branch of the condition above that checks for type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
since that is the case where a tool_id is generated that does not meet the conditions for Mistral.
@@ -61,6 +61,8 @@ def maybe_serialize_tool_calls(request: ChatCompletionRequest): | |||
while True: | |||
try: | |||
tool_call = next(tool_calls_validator) # type: ignore | |||
tool_call['id'] = generate_valid_mistral_tool_id() | |||
logger.warning(f"Assigned new tool_id: {tool_call['id']} for tool: {tool_call}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this change will modify the tool_id of every tool_call message in the incoming request with a new random id such that the id will change with each subsequent step in the conversation. I don't think we want that to happen.
Instead, tool_ids that are valid can be passed-through as-is and tool_ids that are not valid for mistral should have a consistent mapping to one that is valid.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me. I added a check here (using mistral common's validation) so that the id will only be generated/changed if it's invalid.
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
vllm/utils.py
Outdated
def generate_valid_mistral_tool_id(): | ||
# Mistral Tool Call Ids must be alphanumeric with a length of 9. | ||
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 | ||
return "".join(choices(ALPHANUMERIC, k=9)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, we can just use the static function MistralToolCall.generate_random_id()
instead of writing a new function.
@@ -668,6 +669,10 @@ async def chat_completion_full_generator( | |||
arguments=output.text)) | |||
]) | |||
|
|||
if isinstance(tokenizer, MistralTokenizer): | |||
for tool_call in message.tool_calls: | |||
tool_call.id = generate_valid_mistral_tool_id() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just noticed that there is a MistralToolCall
class that overrides the id generation. It would be a bit cleaner to just use that class, eg. could make this change above:
tool_call_class = MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
message = ChatMessage(
role=role,
content="",
tool_calls=[
tool_call_class(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text))
])
@@ -62,6 +62,8 @@ def maybe_serialize_tool_calls(request: ChatCompletionRequest): | |||
try: | |||
tool_call = next(tool_calls_validator) # type: ignore | |||
validated_tool_calls.append(tool_call) | |||
if not re.match(r"^[a-zA-Z0-9]{9}$", tool_call['id']): | |||
tool_call['id'] = generate_valid_mistral_tool_id() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will update the tool_call
entries id
to a valid id, but if the chat history also includes a tool response message with an id
, this code will not adjust it, eg.
"messages":[
{
"role": "user",
"content": "What is the weather in Dallas Texas?"
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "chatcmpl-asdf",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{ \"city\": \"Dallas\", \"state\": \"TX\", \"unit\": \"celsius\" }"
}
}
]
},
{
"role": "tool",
"tool_call_id": "chatcmpl-asdf",
"name": "get_current_weather",
"content": "90 degrees, partly cloudy"
}
],
Also, we'll need to make sure that the id is mapped to the same corrected id in both places. In the example chat template, the ids are truncated to the last 9 characters. For consistency, I think we should do the same when using the MistralTokenizer. Note that, then, a tool_id with less than 9 characters should still be rejected.
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
Thanks for the guidance @tjohnson31415, I made another attempt at it. I'm truncating IDs (if len > 9) and checking that they're mistral-valid, raising an error if it still isn't. Truncating is the only adjustment, I didn't know whether to go as far as dealing with non-alphanumeric chars too (e.g. |
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The maybe_serialize_tool_calls
function is meant to work around an issue with how the request is validated by Pydantic; it also has a TODO to be removed after it is fixed in Pydantic. Because of that, I'm thinking that the truncation of the tool call ids should be moved to its own function.
if not re.match(r"^[a-zA-Z0-9]{9}$", tool_call["id"]): | ||
raise ValueError( | ||
"Invalid tool_call ID: %s", | ||
"(must be exactly 9 alphanumeric characters)", | ||
tool_call["id"], | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validation of the tool ids is also done in Mistral Common. We don't need to duplicate the check here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the validation checks.
Signed-off-by: Rafael Vasquez <[email protected]>
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
… into fix-mistral-tool-call
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
This PR fixes two cases when using
tokenizer-mode=mistral
to do with tool call IDs incompatible with Mistral.When a request includes the generated tool call message with an id that isn't length 9, results in the error
When a request is sent with
tool_choice
set to request a specific function, the request returns an invalid tool_id:This PR introduces
tool_choice
is set when using a Mistral model