
Commit fc6a952

mgazz authored and DarkLight1337 committed
Add support for Prithvi in Online serving mode (vllm-project#21518)
Signed-off-by: Michele Gazzetti <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
1 parent 79ae2c6 commit fc6a952

5 files changed: +128 −10 lines changed
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
import io

import numpy as np
import pytest
import requests
import torch

from ...utils import RemoteOpenAIServer

MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
DTYPE = "float16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def server():
    args = [
        "--task",
        "embed",
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        DTYPE,
        "--enforce-eager",
        "--trust-remote-code",
        "--skip-tokenizer-init",
        "--max-num-seqs",
        "32"
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_request(server: RemoteOpenAIServer, model_name: str):

    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)

    buffer_tiff = io.BytesIO()
    torch.save(pixel_values, buffer_tiff)
    buffer_tiff.seek(0)
    binary_data = buffer_tiff.read()
    base64_tensor_embedding = base64.b64encode(binary_data).decode('utf-8')

    buffer_coord = io.BytesIO()
    torch.save(location_coords, buffer_coord)
    buffer_coord.seek(0)
    binary_data = buffer_coord.read()
    base64_coord_embedding = base64.b64encode(binary_data).decode('utf-8')

    prompt = {
        "model": model_name,
        "additional_data": {
            "prompt_token_ids": [1]
        },
        "encoding_format": "base64",
        "messages": [{
            "role": "user",
            "content": [{
                "type": "image_embeds",
                "image_embeds": {
                    "pixel_values": base64_tensor_embedding,
                    "location_coords": base64_coord_embedding,
                },
            }],
        }]
    }

    # test single pooling
    response = requests.post(server.url_for("pooling"), json=prompt)
    response.raise_for_status()

    output = response.json()["data"][0]['data']

    np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32)

    assert len(np_response) == 524288
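Not part of the commit: because the request sets encoding_format to base64, the pooling payload asserted on above can be decoded back into a tensor on the client side. A minimal sketch, assuming `response` is the requests.Response obtained in test_single_request:

# Minimal client-side sketch (not in the commit): decode the base64 pooling
# payload back into a float32 tensor. Assumes `response` is the
# requests.Response obtained in test_single_request above.
import base64

import numpy as np
import torch

raw = base64.b64decode(response.json()["data"][0]["data"])
embedding = torch.from_numpy(np.frombuffer(raw, dtype=np.float32).copy())
assert embedding.numel() == 524288  # matches the assertion in the test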

vllm/engine/multiprocessing/client.py

Lines changed: 14 additions & 6 deletions
@@ -97,11 +97,16 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
         self.model_config = engine_config.model_config
         self.decoding_config = engine_config.decoding_config

-        # Create the tokenizer group.
-        self.tokenizer = init_tokenizer_from_configs(
-            model_config=self.model_config,
-            scheduler_config=engine_config.scheduler_config,
-            lora_config=engine_config.lora_config)
+        if self.vllm_config.model_config.skip_tokenizer_init:
+            self.tokenizer = None
+
+        else:
+            # Create the tokenizer group.
+            self.tokenizer = init_tokenizer_from_configs(
+                model_config=self.model_config,
+                scheduler_config=engine_config.scheduler_config,
+                lora_config=engine_config.lora_config)
+
         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer)

@@ -375,7 +380,10 @@ async def get_input_preprocessor(self) -> InputPreprocessor:
         return self.input_preprocessor

     async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
-        return await self.tokenizer.get_lora_tokenizer_async(lora_request)
+        if self.tokenizer is None:
+            return None
+        else:
+            return await self.tokenizer.get_lora_tokenizer_async(lora_request)

     async def get_vllm_config(self) -> VllmConfig:
         return self.vllm_config
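Because get_tokenizer() can now return None, code that talks to the engine with --skip-tokenizer-init has to handle the missing tokenizer itself. A minimal sketch of that guard, using illustrative names that are not part of this commit:

# Illustrative only, not from this commit: handling the now-Optional tokenizer
# when the engine runs with skip_tokenizer_init.
from typing import Optional


async def resolve_prompt_ids(engine_client, prompt_text: str,
                             lora_request: Optional["LoRARequest"] = None) -> list[int]:
    tokenizer = await engine_client.get_tokenizer(lora_request)
    if tokenizer is None:
        # No tokenizer was initialised: fall back to a dummy token ID,
        # mirroring what the serving layer does in this commit.
        return [1]
    return tokenizer.encode(prompt_text)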

vllm/entrypoints/openai/serving_engine.py

Lines changed: 12 additions & 2 deletions
@@ -880,7 +880,10 @@ async def _preprocess_chat(
         _chat_template_kwargs.update(chat_template_kwargs or {})

         request_prompt: Union[str, list[int]]
-        if isinstance(tokenizer, MistralTokenizer):
+
+        if tokenizer is None:
+            request_prompt = "placeholder"
+        elif isinstance(tokenizer, MistralTokenizer):
             request_prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,

@@ -910,7 +913,14 @@ async def _preprocess_chat(
             request = tool_parser(tokenizer).adjust_request(  # type: ignore
                 request=request)

-        if isinstance(request_prompt, str):
+        if tokenizer is None:
+            assert isinstance(request_prompt, str), (
+                "Prompt has to be a string", \
+                "when the tokenizer is not initialised"
+            )
+            prompt_inputs = TextTokensPrompt(prompt=request_prompt,
+                                             prompt_token_ids=[1])
+        elif isinstance(request_prompt, str):
             prompt_inputs = await self._tokenize_prompt_input_async(
                 request,
                 tokenizer,
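For reference, the tokenizer-less branch above skips chat templating entirely and hands the engine a fixed placeholder prompt. A minimal sketch of the resulting prompt object, assuming TextTokensPrompt can be imported from this module:

# Not part of the commit: what the tokenizer-less branch above produces.
from vllm.entrypoints.openai.serving_engine import TextTokensPrompt  # import path assumed

prompt_inputs = TextTokensPrompt(prompt="placeholder", prompt_token_ids=[1])
# The real model inputs (pixel_values, location_coords) travel as multimodal
# data; the single dummy token only satisfies the text-prompt plumbing.
print(prompt_inputs)  # {'prompt': 'placeholder', 'prompt_token_ids': [1]}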

vllm/entrypoints/openai/serving_pooling.py

Lines changed: 5 additions & 1 deletion
@@ -96,7 +96,11 @@ async def create_pooling(
             self.max_model_len, truncate_prompt_tokens)
         lora_request = self._maybe_get_adapters(request)

-        tokenizer = await self.engine_client.get_tokenizer(lora_request)
+        if self.model_config.skip_tokenizer_init:
+            tokenizer = None
+        else:
+            tokenizer = await self.engine_client.get_tokenizer(lora_request
+                                                               )

         if isinstance(request, PoolingChatRequest):
             (

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 4 additions & 1 deletion
@@ -103,7 +103,10 @@ def apply(
         mm_kwargs = {}

         for k, v in mm_data.items():
-            mm_kwargs[k] = v
+            if isinstance(v, dict) and k == "image":
+                mm_kwargs.update(v)
+            else:
+                mm_kwargs[k] = v
         mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

         # This model receives in input a multi-dimensional tensor representing
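The new branch spreads a nested dict arriving under the "image" key into top-level model kwargs. A small self-contained illustration, with tensor shapes assumed from the new test rather than taken from this file:

# Illustration only (shapes assumed from the new test above): how the loop
# spreads a nested "image" dict into top-level model kwargs.
import torch

mm_data = {
    "image": {
        "pixel_values": torch.full((6, 512, 512), 1.0),
        "location_coords": torch.full((1, 2), 1.0),
    }
}
mm_kwargs = {}
for k, v in mm_data.items():
    if isinstance(v, dict) and k == "image":
        mm_kwargs.update(v)  # flatten: keys become pixel_values, location_coords
    else:
        mm_kwargs[k] = v

assert set(mm_kwargs) == {"pixel_values", "location_coords"}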

0 commit comments

Comments
 (0)