[VLM] Implement merged multimodal processor for Mllama #11427

Draft · wants to merge 21 commits into base: main
56 changes: 51 additions & 5 deletions vllm/inputs/preprocess.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Tuple, Union, cast

from typing_extensions import assert_never

@@ -485,6 +485,40 @@
decoder=decoder_inputs,
)

def _handle_multimodal_enc_dec_inputs(
self,
inputs: SingletonInputs,
) -> Tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
Split multimodal processor outputs (MultiModalEncDecInputs) or
text-only token inputs into separate encoder and decoder inputs.
"""
encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)
inputs = cast(MultiModalEncDecInputs, inputs)

Check failure on line 502 (GitHub Actions / pre-commit), Ruff (F821): vllm/inputs/preprocess.py:502:27: Undefined name `MultiModalEncDecInputs`
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
decoder_inputs = MultiModalInputsV2(

Check failure on line 507 (GitHub Actions / pre-commit), Ruff (F821): vllm/inputs/preprocess.py:507:30: Undefined name `MultiModalInputsV2`
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
elif inputs["type"] == "token":
# Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = inputs
else:
raise AssertionError("This line should be unreachable.")
return encoder_inputs, decoder_inputs
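
As an illustration (not part of the changed files), a minimal sketch of how this helper routes fields, using plain dicts in place of the actual SingletonInputs/MultiModalEncDecInputs types, with the field subset shown above:

def _split_enc_dec_sketch(inputs: dict) -> tuple:
    if inputs["type"] == "multimodal":
        # Multimodal inputs: the encoder gets the image-token prompt emitted
        # by the processor; the decoder keeps the original prompt plus the
        # multimodal kwargs and placeholders.
        encoder = {
            "type": "token",
            "prompt": inputs["encoder_prompt"],
            "prompt_token_ids": inputs["encoder_prompt_token_ids"],
        }
        decoder = {
            "type": "multimodal",
            "prompt": inputs["prompt"],
            "prompt_token_ids": inputs["prompt_token_ids"],
            "mm_kwargs": inputs["mm_kwargs"],
            "mm_placeholders": inputs["mm_placeholders"],
        }
        return encoder, decoder
    # Text-only inputs: the encoder prompt is empty.
    return {"type": "token", "prompt": "", "prompt_token_ids": []}, inputs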

def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
@@ -538,12 +572,18 @@
request_id=request_id,
)
else:
encoder_inputs = self._prompt_to_llm_inputs(
inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._handle_multimodal_enc_dec_inputs(inputs))
else:
encoder_inputs = inputs

decoder_inputs = None
decoder_inputs = None

return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

@@ -574,12 +614,18 @@
encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)
else:
encoder_inputs = await self._prompt_to_llm_inputs_async(
inputs = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
)
if self._can_process_multimodal():
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._handle_multimodal_enc_dec_inputs(inputs))
else:
encoder_inputs = inputs

decoder_inputs = None
decoder_inputs = None

return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
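
Downstream of this preprocessing path, callers still pass a single prompt plus image data; a hedged usage sketch (model name, dummy image, and sampling parameters are assumptions for illustration, not taken from this PR):

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.2-11B-Vision-Instruct",
          max_model_len=4096, enforce_eager=True)
prompt = "<|image|><|begin_of_text|>What is the content of this image?"
image = Image.new("RGB", (1024, 1024))  # placeholder image
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)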

232 changes: 78 additions & 154 deletions vllm/model_executor/models/mllama.py
@@ -25,21 +25,19 @@
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL import Image
from torch import nn
from transformers import BatchFeature, MllamaConfig
from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
get_cross_attention_token_mask)
MllamaProcessor, get_cross_attention_token_mask)

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs)
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -54,8 +52,9 @@
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SequenceData
from vllm.utils import is_list_of
from vllm.multimodal.processing import (EncDecMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)

from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
@@ -81,158 +80,86 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs


def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
num_images = 0
for token_id in prompt_token_ids[::-1]:
if token_id == MLLAMA_IMAGE_TOKEN_ID:
num_images += 1
elif num_images > 0:
break
return num_images


def input_processor_for_mllama(
ctx: InputContext,
inputs: EncoderDecoderInputs,
) -> EncoderDecoderInputs:
# Example input to processor:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000],
# },
# }

# move encoder prompt to decoder
dec_inputs = TokenInputs(**inputs["encoder"])

multi_modal_data = dec_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
# text-only
return EncoderDecoderInputs(
encoder=token_inputs([]),
decoder=dec_inputs,
)

image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_data = [image_data]

assert is_list_of(image_data, Image.Image)

num_image_tokens = dec_inputs['prompt_token_ids'].count(
MLLAMA_IMAGE_TOKEN_ID)
if num_image_tokens != len(image_data):
raise ValueError(
f"The number of image tokens ({num_image_tokens}) must be"
f" the same as the number of images ({len(image_data)})")

# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images = _get_num_image_in_last_group(
dec_inputs["prompt_token_ids"])

hf_config = ctx.model_config.hf_config
vision_config = hf_config.vision_config

num_tiles = 0
for image in image_data[::-1]:
width, height = image.size
tile_size = vision_config.image_size
canvas_height, canvas_width = get_optimal_tiled_canvas(
image_height=height,
image_width=width,
max_image_tiles=vision_config.max_num_tiles,
tile_size=tile_size,
)
num_tiles_height = canvas_height // tile_size
num_tiles_width = canvas_width // tile_size
num_tiles += num_tiles_height * num_tiles_width
num_decode_images -= 1
if num_decode_images == 0:
break

# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate correct number
# of slots for encoder tokens.
assert vision_config.image_size % 14 == 0, \
"chunk size should be multiple of 14"
token_per_chunk = (vision_config.image_size // 14)**2 + 1
num_tokens = num_tiles * token_per_chunk

# Example output from processor:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128256, 128256, ..., 128256],
# 'prompt': '<|image|><|image|>...<|image|>',
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# }
return EncoderDecoderInputs(
encoder=token_inputs(
prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
multi_modal_data=multi_modal_data,
),
decoder=dec_inputs,
)


def get_max_mllama_image_tokens(ctx: InputContext) -> int:
hf_config = ctx.model_config.hf_config
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
return hf_config.vision_config.max_num_tiles * token_per_chunk


def dummy_decoder_seq_data(seq_len: int, num_images: int):
# <|image|> * num_images + 0 * (seq_len - num_images)
assert seq_len >= num_images, \
"seq_len should be greater than or equal to num_images"

return SequenceData.from_prompt_token_counts(
(MLLAMA_IMAGE_TOKEN_ID, num_images),
(0, seq_len - num_images),
)

class MllamaMultiModalProcessor(EncDecMultiModalProcessor):

def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
num_tokens = get_max_mllama_image_tokens(ctx) * num_images

return SequenceData.from_prompt_token_counts(
(MLLAMA_IMAGE_TOKEN_ID, num_tokens))


def dummy_image(num_images: int, ):
width = height = 1024
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}


def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
return DummyData(dummy_decoder_seq_data(seq_len, num_images))
def _get_hf_processor(self) -> MllamaProcessor:
return self.ctx.get_hf_processor(MllamaProcessor)

def _call_hf_processor(
self,
hf_processor: MllamaProcessor,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# Calling MllamaProcessor directly drops `num_tiles` from the image
# processor output, but `num_tiles` is essential for the forward pass in
# the vLLM implementation. Therefore, we call the image processor and the
# tokenizer separately to keep `num_tiles`.
image_processor = hf_processor.image_processor
image_features = image_processor(**processor_data)

tokenizer = self._get_tokenizer()
encoding = tokenizer(prompt,
add_special_tokens=False,
return_tensors="pt")
data = dict(**encoding, **image_features)

return BatchFeature(data=data, tensor_type="pt")
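
A hedged sketch of the merge performed above, with stand-in tensors (the shapes and the exact image-processor field names are assumptions for illustration):

import torch
from transformers import BatchFeature

encoding = {
    "input_ids": torch.tensor([[128256, 128000, 3923]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
}
image_features = {
    "pixel_values": torch.zeros(1, 1, 4, 3, 560, 560),  # stand-in tensor
    "num_tiles": [[4]],  # preserved by calling the image processor directly
}
batch = BatchFeature(data={**encoding, **image_features})
print(sorted(batch.keys()))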

def _create_encoder_prompt(self, prompt: str):
hf_processor = self._get_hf_processor()
image_token = hf_processor.image_token
num_images = prompt.count(image_token)
return image_token * num_images
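
For example (illustration only), a prompt that references two images yields an encoder prompt that is just the image token repeated twice:

prompt = "<|image|><|image|><|begin_of_text|>Compare the two images."
image_token = "<|image|>"
encoder_prompt = image_token * prompt.count(image_token)
assert encoder_prompt == "<|image|><|image|>"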

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
vision_config = self.ctx.get_hf_config(MllamaConfig).vision_config
assert vision_config.image_size % 14 == 0, (
"chunk size should be multiple of 14")
token_per_chunk = (vision_config.image_size // 14)**2 + 1
hf_processor = self._get_hf_processor()
image_token_id = hf_processor.image_token_id

def get_replacement_mllama(item_idx):
num_tile = hf_inputs["num_tiles"][0][item_idx]
num_tokens = num_tile * token_per_chunk
return [image_token_id] * num_tokens

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_mllama,
)
]
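
A worked example of the replacement length, assuming the Llama 3.2 Vision defaults (vision_config.image_size = 560, 14x14 patches, up to 4 tiles); these numbers are illustrative rather than read from a real config:

image_size = 560                                # assumed tile size
token_per_chunk = (image_size // 14) ** 2 + 1   # 40 * 40 patches + 1 = 1601
for num_tiles in (1, 2, 4):
    print(num_tiles, num_tiles * token_per_chunk)   # 1601, 3202, 6404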

def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
return DummyData(dummy_encoder_seq_data(ctx, num_images),
dummy_image(num_images))
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts["image"]
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token

width = height = 1024
image = Image.new("RGB", (width, height), color=0)

return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data={"image": image},
mm_processor_kwargs={},
)


def _prepare_aspect_ratio_attention_mask(
@@ -1107,11 +1034,8 @@ def forward(
return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
16 changes: 16 additions & 0 deletions vllm/multimodal/inputs.py
@@ -534,3 +534,19 @@
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
"""


class MultiModalEncDecInputs(MultiModalInputsV2):

Check failure on line 539 (GitHub Actions / pre-commit), Ruff (F821): vllm/multimodal/inputs.py:539:30: Undefined name `MultiModalInputsV2`
"""
Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
ready to be passed to vLLM internals.
"""

encoder_prompt: str
"""The processed encoder prompt text."""

encoder_prompt_token_ids: List[int]

Check failure on line 548 (GitHub Actions / pre-commit), Ruff (F821): vllm/multimodal/inputs.py:548:31: Undefined name `List`
"""The processed token IDs of the encoder prompt."""

encoder_token_type_ids: NotRequired[List[int]]

Check failure on line 551 (GitHub Actions / pre-commit), Ruff (F821): vllm/multimodal/inputs.py:551:41: Undefined name `List`
"""The token type IDs of the encoder prompt."""