Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Generate valid tool call IDs when using tokenizer-mode=mistral #12332

Open
wants to merge 16 commits into
base: main
Choose a base branch
from

Conversation

rafvasq
Copy link
Contributor

@rafvasq rafvasq commented Jan 22, 2025

This PR fixes two cases when using tokenizer-mode=mistral to do with tool call IDs incompatible with Mistral.

  • When a request includes the generated tool call message with an id that isn't length 9, results in the error

    mistral_common.exceptions.InvalidFunctionCallException: Tool call id was chatcmpl-tool-e5add885dbb342de950be95dd89b71e7 but must be a-z, A-Z, 0-9, with a length of 9.
    
  • When a request is sent with tool_choice set to request a specific function, the request returns an invalid tool_id:

    "tool_calls": [{
      "id": "chatcmpl-tool-64aa8ec82efa4007b5fbf1ea885dea00",
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "arguments": "{ \"city\": \"Dallas\", \"state\": \"TX\", \"unit\": \"celsius\" }"
      }
    }]
    

This PR introduces

  • Handling invalid tool_ids by truncating and validating them to the required 9 characters.
  • Generating a valid 9-character tool_id when tool_choice is set when using a Mistral model
  • Specify comments about mistral's ID requirement for exactly 9 characters.

Signed-off-by: Rafael Vasquez <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the frontend label Jan 22, 2025
Copy link
Contributor

@tjohnson31415 tjohnson31415 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking at this @rafvasq! I know it is still a WIP, but I left some suggestions on the current state of the PR.

vllm/utils.py Outdated
@@ -2206,3 +2211,8 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)

def generate_valid_mistral_tool_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mistral requirement is exactly 9 characters.

Suggested change
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# Mistral Tool Call Ids must be alphanumeric with a length of 9.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed this comment in a couple of places.

if isinstance(tokenizer, MistralTokenizer):
for tool_call in message.tool_calls:
tool_call.id = generate_valid_mistral_tool_id()
logger.warning(f"Assigned new tool_id: {tool_call.id} for tool: {tool_call}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the tool_choice: auto case, the MistralTokenizer should be generating valid tool ids itself and not need this to override it. So I think this logic can be moved to the branch of the condition above that checks for type(request.tool_choice) is ChatCompletionNamedToolChoiceParam since that is the case where a tool_id is generated that does not meet the conditions for Mistral.

@@ -61,6 +61,8 @@ def maybe_serialize_tool_calls(request: ChatCompletionRequest):
while True:
try:
tool_call = next(tool_calls_validator) # type: ignore
tool_call['id'] = generate_valid_mistral_tool_id()
logger.warning(f"Assigned new tool_id: {tool_call['id']} for tool: {tool_call}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this change will modify the tool_id of every tool_call message in the incoming request with a new random id such that the id will change with each subsequent step in the conversation. I don't think we want that to happen.

Instead, tool_ids that are valid can be passed-through as-is and tool_ids that are not valid for mistral should have a consistent mapping to one that is valid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me. I added a check here (using mistral common's validation) so that the id will only be generated/changed if it's invalid.

Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
@rafvasq rafvasq marked this pull request as ready for review January 24, 2025 20:02
@rafvasq rafvasq requested a review from tjohnson31415 January 24, 2025 20:02
vllm/utils.py Outdated
def generate_valid_mistral_tool_id():
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, we can just use the static function MistralToolCall.generate_random_id() instead of writing a new function.

@@ -668,6 +669,10 @@ async def chat_completion_full_generator(
arguments=output.text))
])

if isinstance(tokenizer, MistralTokenizer):
for tool_call in message.tool_calls:
tool_call.id = generate_valid_mistral_tool_id()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just noticed that there is a MistralToolCall class that overrides the id generation. It would be a bit cleaner to just use that class, eg. could make this change above:

                tool_call_class = MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        tool_call_class(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])

@@ -62,6 +62,8 @@ def maybe_serialize_tool_calls(request: ChatCompletionRequest):
try:
tool_call = next(tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
if not re.match(r"^[a-zA-Z0-9]{9}$", tool_call['id']):
tool_call['id'] = generate_valid_mistral_tool_id()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will update the tool_call entries id to a valid id, but if the chat history also includes a tool response message with an id, this code will not adjust it, eg.

  "messages":[
    {
      "role": "user",
      "content": "What is the weather in Dallas Texas?"
    },
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
          {
            "id": "chatcmpl-asdf",
            "type": "function",
            "function": {
              "name": "get_current_weather",
              "arguments": "{ \"city\": \"Dallas\", \"state\": \"TX\", \"unit\": \"celsius\" }"
            }
          }
        ]
    },
    {
      "role": "tool",
      "tool_call_id": "chatcmpl-asdf",
      "name": "get_current_weather",
      "content": "90 degrees, partly cloudy"
    }
  ],

Also, we'll need to make sure that the id is mapped to the same corrected id in both places. In the example chat template, the ids are truncated to the last 9 characters. For consistency, I think we should do the same when using the MistralTokenizer. Note that, then, a tool_id with less than 9 characters should still be rejected.

@rafvasq rafvasq requested a review from tjohnson31415 January 30, 2025 03:25
@rafvasq
Copy link
Contributor Author

rafvasq commented Jan 30, 2025

Thanks for the guidance @tjohnson31415, I made another attempt at it.

I'm truncating IDs (if len > 9) and checking that they're mistral-valid, raising an error if it still isn't. Truncating is the only adjustment, I didn't know whether to go as far as dealing with non-alphanumeric chars too (e.g. chatcmpl-asdf truncates to cmpl-asdf but it's not alphanumeric and still invalid) so it'll get rejected by the validation step.

Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
Copy link
Contributor

@tjohnson31415 tjohnson31415 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The maybe_serialize_tool_calls function is meant to work around an issue with how the request is validated by Pydantic; it also has a TODO to be removed after it is fixed in Pydantic. Because of that, I'm thinking that the truncation of the tool call ids should be moved to its own function.

Comment on lines 74 to 79
if not re.match(r"^[a-zA-Z0-9]{9}$", tool_call["id"]):
raise ValueError(
"Invalid tool_call ID: %s",
"(must be exactly 9 alphanumeric characters)",
tool_call["id"],
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Validation of the tool ids is also done in Mistral Common. We don't need to duplicate the check here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the validation checks.

Copy link

mergify bot commented Feb 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @rafvasq.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 4, 2025
@mergify mergify bot removed the needs-rebase label Feb 4, 2025
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants