
Conversation

@mgazz mgazz (Contributor) commented Jul 24, 2025

This PR builds on top of #20072 and enables running Prithvi in online serving mode.

This is achieved by extending support for models that skip tokenizer initialisation.

A longer description of what we are trying to achieve is available in #20234.

This supersedes #20307
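
Because the server runs without a tokenizer in this mode, the client has to supply token IDs itself via additional_data. A minimal sketch of the extra request fields (the single placeholder ID mirrors the full example script further down):

# Hedged sketch: with --skip-tokenizer-init the request must carry prompt
# token IDs directly, since the server cannot tokenise text.
extra_fields = {
    "additional_data": {
        "prompt_token_ids": [1],  # placeholder ID, as in the example script below
    },
}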

Test Plan

The PR can be tested as follows:

pytest tests/entrypoints/openai/test_skip_tokenizer.py

Additional information

Prithvi in serving mode can be started with the following command:

vllm serve --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' --task embed --trust-remote-code --dtype float16 --skip-tokenizer-init --enforce-eager
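
Before sending requests it can help to wait for the server to come up. A minimal readiness probe (a sketch, assuming the default port 8000 and the server's /health endpoint):

import requests

# One-shot readiness probe (adjust host/port if needed); raises if the
# server is not yet healthy.
requests.get("http://localhost:8000/health", timeout=5).raise_for_status()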

The following script provides an example prompt that can be used to perform inference (the same prompt is used in the test linked above):

import base64
import requests
import torch
import io
import numpy as np

torch.set_default_dtype(torch.float16)


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client","Content-Type": "application/json"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def decompress(output):
    # The server returns a base64-encoded float32 buffer; decode it and
    # reshape it to the model's output grid.
    np_result = np.frombuffer(
        base64.b64decode(output), dtype=np.float32)
    return np_result.reshape(1, 2, 512, 512)


def main():
    api_url = f"http://localhost:8000/pooling"
    model_name = 'christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'

    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
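    # Serialise each tensor with torch.save into an in-memory buffer and
    # base64-encode the bytes so they can be embedded in the JSON request.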

    buffer_tiff = io.BytesIO()
    torch.save(pixel_values, buffer_tiff)
    buffer_tiff.seek(0)
    binary_data = buffer_tiff.read()
    base64_tensor_embedding = base64.b64encode(binary_data).decode('utf-8')

    buffer_coord = io.BytesIO()
    torch.save(location_coords, buffer_coord)
    buffer_coord.seek(0)
    binary_data = buffer_coord.read()
    base64_coord_embedding = base64.b64encode(binary_data).decode('utf-8')

    prompt = {
        "model": model_name,
        "additional_data": {
            "prompt_token_ids": [1]
        },
        "encoding_format": "base64",
        "messages": [{
            "role": "user",
            "content": [{
                "type": "image_embeds",
                "image_embeds": {
                    "pixel_values": base64_tensor_embedding,
                    "location_coords": base64_coord_embedding,
                },
            }],
        }],
    }

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    numpy_data = decompress(pooling_response.json()["data"][0]["data"])
    print(f"Returned result: {numpy_data}")


if __name__ == "__main__":
    main()
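
To exercise the endpoint end to end, start the server with the command above, wait for it to report ready, and then run this script with python; the decoded result is a float32 array of shape (1, 2, 512, 512), matching the reshape in decompress.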

@DarkLight1337 @maxdebayser @njhill @christian-pinto


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the frontend label Jul 24, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This PR introduces support for Prithvi in online serving mode by allowing models to skip tokenizer initialization. The changes include modifications to the client, serving engine, serving pooling, and the Prithvi model itself. It's important to ensure that the tokenizer is checked for None before being used and that exceptions provide sufficient context for debugging.

            model_config=self.model_config,
            scheduler_config=engine_config.scheduler_config,
            lora_config=engine_config.lora_config)

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer)

critical

If self.tokenizer is None, this will raise an AttributeError. It's critical to ensure self.tokenizer is checked for None before being used here to prevent a crash. Consider adding a condition to skip this line if self.tokenizer is None.

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer if self.tokenizer else None)

Comment on lines 917 to 920
if "prompt_token_ids" not in request.additional_data:
raise Exception("Request must contain "
"additional_data['prompt_token_ids'] "
"when the tokenizer is not initialised")

high

This exception lacks context about why the tokenizer is not initialized. Add more details to the exception message to aid debugging. For example, include the model name or a hint to check the --skip-tokenizer-init flag.

                raise Exception("Request must contain "
                                "additional_data['prompt_token_ids'] "
                                "when the tokenizer is not initialised. Check if '--skip-tokenizer-init' flag was used correctly for model {}".format(request.model))

@DarkLight1337 DarkLight1337 left a comment

Please also fix pre-commit

mgazz and others added 3 commits July 24, 2025 13:59
Improve multimodal input handling logic

Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Michele Gazzetti <[email protected]>
@mgazz mgazz force-pushed the online_prithvi_no_tokenizer branch from e5dec1e to c2339d1 on July 24, 2025 14:08
mgazz added 3 commits July 24, 2025 14:22
Signed-off-by: Michele Gazzetti <[email protected]>
Signed-off-by: Michele Gazzetti <[email protected]>
Signed-off-by: Michele Gazzetti <[email protected]>
@DarkLight1337 DarkLight1337 left a comment

LGTM, thanks for adding support!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) July 24, 2025 15:17
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 24, 2025
@mgazz mgazz (Contributor, Author) commented Jul 24, 2025

@DarkLight1337 thank you for the support during the review process.

@DarkLight1337 DarkLight1337 (Member)

PTAL at the failing entrypoints test

auto-merge was automatically disabled July 25, 2025 08:05

Head branch was pushed to by a user without write access

@mgazz mgazz (Contributor, Author) commented Jul 25, 2025

Apologies about the test. I will keep an eye on it.

Signed-off-by: Michele Gazzetti <[email protected]>
@vllm-bot vllm-bot merged commit e189b50 into vllm-project:main Jul 25, 2025
66 of 68 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
wenscarl pushed a commit to wenscarl/vllm that referenced this pull request Aug 4, 2025
Signed-off-by: Michele Gazzetti <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: shuw <[email protected]>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Signed-off-by: Michele Gazzetti <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: x22x22 <[email protected]>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: Michele Gazzetti <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: Michele Gazzetti <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Aug 14, 2025
Signed-off-by: Michele Gazzetti <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: Michele Gazzetti <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Labels: frontend, ready (ONLY add when PR is ready to merge/full CI is needed)
4 participants