Add support for Prithvi geospatial model in serving mode #20307

Draft · wants to merge 15 commits into main
3 changes: 2 additions & 1 deletion examples/offline_inference/prithvi_geospatial_mae.py
@@ -143,7 +143,8 @@ def __init__(self):
         self.model = LLM(
             model=os.path.join(os.path.dirname(__file__), "./model"),
             skip_tokenizer_init=True,
-            dtype="float32",
+            dtype="float16",
+            enforce_eager=True,
         )

     def run(self, input_data, location_coords):
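For orientation, here is a minimal sketch of what the updated example exercises end to end. The constructor arguments come from the diff above and the prompt layout mirrors the new test below; the local "./model" path and the output handling are assumptions for illustration, not the example's actual code.

import torch
from vllm import LLM

# Assumed local checkpoint path; the real example resolves it next to its own file.
llm = LLM(
    model="./model",
    skip_tokenizer_init=True,  # Prithvi ships no tokenizer
    dtype="float16",
    enforce_eager=True,
)

prompt = {
    "prompt_token_ids": [1],  # dummy token; the model consumes no text
    "multi_modal_data": {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    },
}

# encode() drives the pooling path; no autoregressive generation happens.
outputs = llm.encode(prompt)
print(outputs[0].outputs.data.shape)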
50 changes: 50 additions & 0 deletions tests/models/multimodal/pooling/test_prithvi_mae.py
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from ....conftest import VllmRunner


def generate_test_mm_data():
    mm_data = {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
    return mm_data


def _run_test(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:

    mm_data = generate_test_mm_data()
    prompt = {
        # This model takes no text input
        "prompt_token_ids": [1],
        "multi_modal_data": mm_data
    }
    with vllm_runner(model,
                     task="embed",
                     dtype=torch.float16,
                     enforce_eager=True,
                     skip_tokenizer_init=True) as vllm_model:
        vllm_model.encode(prompt)


MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]


@pytest.mark.parametrize("model", MODELS)
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
) -> None:
    _run_test(
        vllm_runner,
        model,
    )
13 changes: 11 additions & 2 deletions vllm/config.py
@@ -614,6 +614,8 @@ def __post_init__(self) -> None:
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
         self.multimodal_config = self._init_multimodal_config()
+        self.model_supports_multimodal_raw_input = (
+            self._init_model_supports_multimodal_raw_input())
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()

@@ -706,6 +708,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:

         return None

+    def _init_model_supports_multimodal_raw_input(self):
+        return self.registry.supports_multimodal_raw_input(self.architectures)
+
     def _get_encoder_config(self):
         return get_sentence_transformer_tokenizer_config(
             self.model, self.revision)

@@ -1100,10 +1105,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
         return self.get_hf_config_sliding_window()

     def get_vocab_size(self) -> int:
-        return self.hf_text_config.vocab_size
+        return getattr(self.hf_text_config, "vocab_size", 0)

     def get_hidden_size(self) -> int:
-        return self.hf_text_config.hidden_size
+        return getattr(self.hf_text_config, "hidden_size", 0)

     @property
     def is_deepseek_mla(self) -> bool:

@@ -1397,6 +1402,10 @@ def uses_mrope(self) -> bool:
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None

+    @property
+    def is_pooling_model(self) -> bool:
+        return self.registry.is_pooling_model(self.architectures)
+
     @property
     def is_cross_encoder(self) -> bool:
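These two flags are consumed elsewhere in the engine; a hedged illustration of the intended branching follows (the consuming call sites are assumptions, only the flags themselves come from this diff).

from vllm.config import ModelConfig


def route(model_config: ModelConfig) -> str:
    # Hedged illustration; the real consuming call sites are not shown in
    # this diff, only the two flags are.
    if model_config.model_supports_multimodal_raw_input:
        # multimodal kwargs (pixel_values, location_coords) reach the model
        # without the usual text-embedding merge
        pass
    if model_config.is_pooling_model:
        return "pooling"  # embed/encode path rather than generation
    return "generate"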
99 changes: 72 additions & 27 deletions vllm/entrypoints/chat_utils.py
@@ -75,6 +75,17 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
     """The type of the content part."""


+class ChatCompletionContentPartTensorsParam(TypedDict, total=False):
+    tensors: Required[Union[str, dict[str, str]]]
+    """
+    The tensors. It can be either:
+    - A single base64 string.
+    - A dictionary where each value is a base64 string.
+    """
+    type: Required[Literal["tensors"]]
+    """The type of the content part."""
+
+
 class VideoURL(TypedDict, total=False):
     url: Required[str]
     """
@@ -129,6 +140,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
     ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
     CustomChatCompletionContentSimpleImageParam,
     ChatCompletionContentPartImageEmbedsParam,
+    ChatCompletionContentPartTensorsParam,
     CustomChatCompletionContentSimpleAudioParam,
     CustomChatCompletionContentSimpleVideoParam, str]
@@ -468,7 +480,7 @@ def resolve_chat_template_content_format(


-ModalityStr = Literal["image", "audio", "video", "image_embeds"]
+ModalityStr = Literal["image", "audio", "video", "image_embeds", "tensors"]
 _T = TypeVar("_T")

@@ -572,6 +584,8 @@ def _placeholder_str(self, modality: ModalityStr,
                 return self._cached_token_str(self._tokenizer,
                                               hf_config.video_token_index)
             raise TypeError(f"Unknown {modality} model type: {model_type}")
+        elif modality == "tensors":
+            return None
         else:
             raise TypeError(f"Unknown modality: {modality}")
@@ -630,6 +644,13 @@ def all_mm_data(self) -> Optional[MultiModalDataDict]:
                 raise ValueError(
                     "Only one message can have {'type': 'image_embeds'}")
             mm_inputs["image"] = image_embeds_lst[0]
+
+        if "tensors" in items_by_modality:
+            tensors_lst = items_by_modality["tensors"]
+            if len(tensors_lst) > 1:
+                raise ValueError(
+                    "Only one message can have {'type': 'tensors'}")
+            mm_inputs["tensors"] = tensors_lst[0]
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio" in items_by_modality:
@@ -663,6 +684,12 @@ async def all_mm_data(self) -> Optional[MultiModalDataDict]:
                 raise ValueError(
                     "Only one message can have {'type': 'image_embeds'}")
             mm_inputs["image"] = image_embeds_lst[0]
+        if "tensors" in items_by_modality:
+            tensors_lst = items_by_modality["tensors"]
+            if len(tensors_lst) > 1:
+                raise ValueError(
+                    "Only one message can have {'type': 'tensors'}")
+            mm_inputs["tensors"] = tensors_lst[0]
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio" in items_by_modality:
@@ -695,8 +722,9 @@ def parse_image(self, image_url: str) -> None:
         raise NotImplementedError

     @abstractmethod
-    def parse_image_embeds(self,
-                           image_embeds: Union[str, dict[str, str]]) -> None:
+    def parse_tensors(self,
+                      tensor_encodings: Union[str, dict[str, str]],
+                      modality_str: ModalityStr) -> None:
         raise NotImplementedError

     @abstractmethod
@@ -729,18 +757,22 @@ def parse_image(self, image_url: str) -> None:
         placeholder = self._tracker.add("image", image)
         self._add_placeholder(placeholder)

-    def parse_image_embeds(self,
-                           image_embeds: Union[str, dict[str, str]]) -> None:
-        if isinstance(image_embeds, dict):
-            embeds = {
-                k: self._connector.fetch_image_embedding(v)
-                for k, v in image_embeds.items()
+    def parse_tensors(self,
+                      tensor_encodings: Union[str, dict[str, str]],
+                      modality_str: ModalityStr) -> None:
+        if modality_str not in ["image_embeds", "tensors"]:
+            raise Exception("tensors are acceptable only as part "
+                            "of 'image_embeds' or 'tensors' modalities.")
+        if isinstance(tensor_encodings, dict):
+            tensors = {
+                k: self._connector.fetch_tensor_encoding(v)
+                for k, v in tensor_encodings.items()
             }
-            placeholder = self._tracker.add("image_embeds", embeds)
+            placeholder = self._tracker.add(modality_str, tensors)

-        if isinstance(image_embeds, str):
-            embedding = self._connector.fetch_image_embedding(image_embeds)
-            placeholder = self._tracker.add("image_embeds", embedding)
+        if isinstance(tensor_encodings, str):
+            tensor = self._connector.fetch_tensor_encoding(tensor_encodings)
+            placeholder = self._tracker.add(modality_str, tensor)

         self._add_placeholder(placeholder)
Comment on lines +760 to 778 (Contributor, critical):

There's a potential UnboundLocalError in parse_tensors. If tensor_encodings is neither a dict nor a str, placeholder will not be assigned before being used in self._add_placeholder(placeholder).

To fix this, add an else clause that raises a TypeError for unsupported types and ensure placeholder is always assigned.

Suggested change:

    def parse_tensors(self,
                      tensor_encodings: Union[str, dict[str, str]],
                      modality_str: ModalityStr) -> None:
        if modality_str not in ["image_embeds", "tensors"]:
            raise Exception("tensors are acceptable only as part "
                            "of 'image_embeds' or 'tensors' modalities.")
        if isinstance(tensor_encodings, dict):
            tensors = {
                k: self._connector.fetch_tensor_encoding(v)
                for k, v in tensor_encodings.items()
            }
            placeholder = self._tracker.add(modality_str, tensors)
        elif isinstance(tensor_encodings, str):
            tensor = self._connector.fetch_tensor_encoding(tensor_encodings)
            placeholder = self._tracker.add(modality_str, tensor)
        else:
            raise TypeError(
                f"Unsupported type for tensor_encodings: {type(tensor_encodings)}")
        self._add_placeholder(placeholder)

@@ -780,23 +812,27 @@ def parse_image(self, image_url: str) -> None:
         placeholder = self._tracker.add("image", image_coro)
         self._add_placeholder(placeholder)

-    def parse_image_embeds(self,
-                           image_embeds: Union[str, dict[str, str]]) -> None:
+    def parse_tensors(self,
+                      tensor_encodings: Union[str, dict[str, str]],
+                      modality_str: ModalityStr) -> None:
         future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()

-        if isinstance(image_embeds, dict):
-            embeds = {
-                k: self._connector.fetch_image_embedding(v)
-                for k, v in image_embeds.items()
+        if modality_str not in ["image_embeds", "tensors"]:
+            raise Exception("tensors are acceptable only as part "
+                            "of 'image_embeds' or 'tensors' modalities.")
+        if isinstance(tensor_encodings, dict):
+            tensors = {
+                k: self._connector.fetch_tensor_encoding(v)
+                for k, v in tensor_encodings.items()
             }
-            future.set_result(embeds)
+            future.set_result(tensors)

-        if isinstance(image_embeds, str):
-            embedding = self._connector.\
-                fetch_image_embedding(image_embeds)
-            future.set_result(embedding)
+        if isinstance(tensors, str):
+            tensor = self._connector.\
+                fetch_tensor_encoding(tensor_encodings)
+            future.set_result(tensor)
Comment on lines 819 to +833 (Contributor, critical):

This asynchronous version of parse_tensors has a critical bug. On line 830, isinstance(tensors, str) will raise a NameError if tensor_encodings is a string, because tensors is only defined in the preceding if block. This should be isinstance(tensor_encodings, str).

Additionally, similar to the synchronous version, using if/elif/else would make the logic more robust.

Suggested change:

        if modality_str not in ["image_embeds", "tensors"]:
            raise Exception("tensors are acceptable only as part "
                            "of 'image_embeds' or 'tensors' modalities.")
        if isinstance(tensor_encodings, dict):
            tensors = {
                k: self._connector.fetch_tensor_encoding(v)
                for k, v in tensor_encodings.items()
            }
            future.set_result(tensors)
        elif isinstance(tensor_encodings, str):
            tensor = self._connector.\
                fetch_tensor_encoding(tensor_encodings)
            future.set_result(tensor)
        else:
            raise TypeError(
                f"Unsupported type for tensor_encodings: {type(tensor_encodings)}")


placeholder = self._tracker.add("image_embeds", future)
placeholder = self._tracker.add(modality_str, future)
self._add_placeholder(placeholder)
Comment on lines +815 to 836 (Contributor, high):

In parse_tensors, there's a potential UnboundLocalError. The tensors variable is only assigned within the if isinstance(tensor_encodings, dict): block. If tensor_encodings is a string, tensors will not be defined, leading to an error when used in isinstance(tensors, str).

To fix this, use tensor_encodings instead of tensors in the isinstance check.

        if isinstance(tensor_encodings, dict):
            tensors= {
                k: self._connector.fetch_tensor_encoding(v)
                for k, v in tensor_encodings.items()
            }
            future.set_result(tensors)

        if isinstance(tensor_encodings, str):
            tensor= self._connector.\
                fetch_tensor_encoding(tensor_encodings)
            future.set_result(tensor)


    def parse_audio(self, audio_url: str) -> None:

@@ -819,6 +855,8 @@ def parse_video(self, video_url: str) -> None:
         self._add_placeholder(placeholder)


+
+
 def validate_chat_template(chat_template: Optional[Union[Path, str]]):
     """Raises if the provided chat template appears invalid."""
     if chat_template is None:
@@ -915,6 +953,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
 # No need to validate using Pydantic again
 _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
+_TensorsParser = partial(cast, ChatCompletionContentPartTensorsParam)
 _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
 # Need to validate url objects
@@ -935,6 +974,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
     lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
     "image_embeds":
     lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
+    "tensors":
+    lambda part: _TensorsParser(part).get("tensors", None),
     "audio_url":
     lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
     "input_audio":
@@ -1004,7 +1045,7 @@ def _parse_chat_message_content_mm_part(


 VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
-                                       "image_embeds",
+                                       "image_embeds", "tensors",
                                        "audio_url", "input_audio", "video_url")

@@ -1081,8 +1122,12 @@ def _parse_chat_message_content_part(
         return {'type': 'image'} if wrap_dicts else None
     if part_type == "image_embeds":
         content = cast(Union[str, dict[str, str]], content)
-        mm_parser.parse_image_embeds(content)
+        mm_parser.parse_tensors(content, "image_embeds")
         return {'type': 'image'} if wrap_dicts else None
+    if part_type == "tensors":
+        content = cast(Union[str, dict[str, str]], content)
+        mm_parser.parse_tensors(content, "tensors")
+        return {'type': 'tensors'} if wrap_dicts else None
     if part_type == "audio_url":
         str_content = cast(str, content)
         mm_parser.parse_audio(str_content)
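Taken together, these parser changes let a chat-style request carry raw tensors. A hedged end-to-end sketch follows: the message schema comes from the new TypedDict and parser above, while the endpoint path, the tensor_to_b64 helper, and the wire format of each base64 payload are assumptions for illustration.

import base64
import io

import requests
import torch


def tensor_to_b64(t: torch.Tensor) -> str:
    # Assumed wire format; the server-side decoder may expect a different one.
    buf = io.BytesIO()
    torch.save(t, buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


payload = {
    "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    "encoding_format": "tensor",  # new option, see protocol.py below
    "messages": [{
        "role": "user",
        "content": [{
            "type": "tensors",
            "tensors": {
                "pixel_values": tensor_to_b64(
                    torch.full((6, 512, 512), 1.0, dtype=torch.float16)),
                "location_coords": tensor_to_b64(
                    torch.full((1, 2), 1.0, dtype=torch.float16)),
            },
        }],
    }],
}

resp = requests.post("http://localhost:8000/pooling", json=payload)
print(resp.status_code)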
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -1109,7 +1109,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
     model: Optional[str] = None
     messages: list[ChatCompletionMessageParam]

-    encoding_format: Literal["float", "base64"] = "float"
+    encoding_format: Literal["float", "base64", "tensor"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
14 changes: 13 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py
@@ -807,6 +807,8 @@ async def _preprocess_chat(
                 messages=messages,
                 **_chat_template_kwargs,
             )
+        elif tokenizer is None:
+            request_prompt = "placeholder"
         else:
             request_prompt = apply_hf_chat_template(
                 tokenizer=tokenizer,
@@ -831,7 +833,17 @@ async def _preprocess_chat(
             request = tool_parser(tokenizer).adjust_request(  # type: ignore
                 request=request)

-        if isinstance(request_prompt, str):
+        if tokenizer is None:
+            prompt_inputs = {}
+            if "prompt_token_ids" not in request.additional_data:
+                raise Exception("Request must contain "
+                                "additional_data['prompt_token_ids'] "
+                                "when the tokenizer is not initialised")
+
+            prompt_inputs["prompt_token_ids"] = request.additional_data[
+                "prompt_token_ids"]
+
+        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                 request,
                 tokenizer,
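A hedged sketch of the request body this branch expects for a tokenizer-less model, inferred from the check above; whether additional_data is accepted verbatim on the wire depends on the request model and is an assumption here.

# Assumption: the serving layer exposes additional_data as a request field,
# matching the check on additional_data['prompt_token_ids'] above.
payload = {
    "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    "additional_data": {"prompt_token_ids": [1]},  # dummy token; no text input
    "messages": [...],  # e.g. the "tensors" content part shown earlier
}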
14 changes: 11 additions & 3 deletions vllm/entrypoints/openai/serving_pooling.py
@@ -25,6 +25,7 @@
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
+from vllm.multimodal.image import ImageEmbeddingMediaIO
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.utils import merge_async_iterators
@@ -33,7 +34,7 @@

 def _get_data(
     output: PoolingOutput,
-    encoding_format: Literal["float", "base64"],
+    encoding_format: Literal["float", "base64", "tensors"],
 ) -> Union[list[float], str]:
     if encoding_format == "float":
         return output.data.tolist()
@@ -43,6 +44,9 @@ def _get_data(
         pt_float32 = output.data.to(dtype=torch.float32)
         pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
         return base64.b64encode(pooling_bytes).decode("utf-8")
+    elif encoding_format == "tensor":
+        tensor_encoding_io = ImageEmbeddingMediaIO()
+        return tensor_encoding_io.encode_base64(output.data)
Contributor:
I'm not sure I understand this. The pooler output is a torch.Tensor but ImageEmbeddingMediaIO.encode_base64 expects an image. How does this work?

Author:

I highlighted this in the RFC under "Additional features explored to enable tensors".

ImageMediaIO expects an image; ImageEmbeddingMediaIO expects tensors. I thought about reusing ImageEmbeddingMediaIO to encode the tensor output, but I had a hard time reconstructing the tensor on the user side when it was encoded via encode_base64. That is why I added encode_tensor.

If this is too confusing, we can add a dedicated class like TensorIO that performs the encoding, so that we keep a separation of concerns.

Author (@mgazz, Jul 2, 2025):
Update: I ran a couple of tests using ImageEmbeddingMediaIO and the behaviour is strange.

Here is an example:

from vllm.multimodal.image import ImageEmbeddingMediaIO
import torch

pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
image_embeds_media_io = ImageEmbeddingMediaIO()
encoded = image_embeds_media_io.encode_base64(pixel_values)
decoded = image_embeds_media_io.load_base64("", encoded)

Here is the error:

(myenv) mgazz@mgazz-vllm-devpod-6c47989df9-hstsz:~/vllm$ python test.py 
Traceback (most recent call last):
  File "/workspace/vllm/test.py", line 7, in <module>
    decoded = image_embeds_media_io.load_base64("",encoded)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/image.py", line 91, in load_base64
    return self.load_bytes(base64.b64decode(data))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/image.py", line 88, in load_bytes
    return torch.load(buffer, weights_only=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/myenv/lib/python3.12/site-packages/torch/serialization.py", line 1548, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 0

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

If we encode via encode_tensor, which is part of this PR, it works fine:

from vllm.multimodal.image import ImageEmbeddingMediaIO
import torch

pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
image_embeds_media_io = ImageEmbeddingMediaIO()
encoded = image_embeds_media_io.encode_tensor(pixel_values)
decoded = image_embeds_media_io.load_base64("", encoded)
print(type(decoded))

This looks like a bug in the current implementation. Maybe the solution is updating encode_base64 with the implementation used in encode_tensor.

Author:

I also see Gemini is suggesting against using tensor.save ... #20307 (comment)

Author:
I opened a separate issue on this: #20427
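For reference, a bare torch.save/torch.load round trip through base64 does survive weights_only=True for plain tensors, so an encode_base64 built that way would round-trip. A minimal sketch follows; this is not the PR's actual encode_tensor implementation, which is not shown here.

import base64
import io

import torch


def encode_tensor_b64(t: torch.Tensor) -> str:
    buf = io.BytesIO()
    torch.save(t, buf)  # bare tensors are weights_only-safe to reload
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def decode_tensor_b64(s: str) -> torch.Tensor:
    buf = io.BytesIO(base64.b64decode(s))
    return torch.load(buf, weights_only=True)


t = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
assert torch.equal(decode_tensor_b64(encode_tensor_b64(t)), t)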


     assert_never(encoding_format)

@@ -99,7 +103,11 @@ async def create_pooling(
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)

-        tokenizer = await self.engine_client.get_tokenizer(lora_request)
+        if not self.model_config.skip_tokenizer_init:
+            tokenizer = await self.engine_client.get_tokenizer(lora_request
+                                                               )
+        else:
+            tokenizer = None

         if prompt_adapter_request is not None:
             raise NotImplementedError("Prompt adapter is not supported "
@@ -205,7 +213,7 @@ def request_output_to_pooling_response(
     request_id: str,
     created_time: int,
     model_name: str,
-    encoding_format: Literal["float", "base64"],
+    encoding_format: Literal["float", "base64", "tensors"],
 ) -> PoolingResponse:
     items: list[PoolingResponseData] = []
     num_prompt_tokens = 0
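Finally, a hedged sketch of how a client might consume a response produced by the new encoding_format="tensor" branch of _get_data above, reusing the payload from the earlier request sketch. The response field layout and the wire format of the encoded tensor are assumptions (see the encode_base64/encode_tensor discussion above); adjust the decode step to match the server.

import base64
import io

import requests
import torch

resp = requests.post("http://localhost:8000/pooling", json=payload).json()
b64_tensor = resp["data"][0]["data"]  # assumed field layout
buf = io.BytesIO(base64.b64decode(b64_tensor))
embedding = torch.load(buf, weights_only=True)  # assumes a torch.save payload
print(embedding.shape)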