from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,

class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
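As a rough illustration of the replacement annotation (the `ExamplePixelInputs` class, the tensor shapes, and the stacked-vs-ragged split below are illustrative assumptions, not part of this change), a single stacked tensor covers the case where every image yields the same number of patches, while a list of per-image tensors covers mixed resolutions:

from typing import List, Literal, TypedDict, Union

import torch


class ExamplePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]


# Uniform patch counts: one (batch_size, 1 + num_patches, C, H, W) tensor.
stacked: ExamplePixelInputs = {
    "type": "pixel_values",
    "data": torch.zeros(2, 5, 3, 336, 336),
}

# Mixed patch counts: one (1 + num_patches, C, H, W) tensor per image.
ragged: ExamplePixelInputs = {
    "type": "pixel_values",
    "data": [torch.zeros(5, 3, 336, 336), torch.zeros(7, 3, 336, 336)],
}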
@@ -62,31 +62,26 @@ class LlavaNextImagePixelInputs(TypedDict):
LlavaNextImageInputs = LlavaNextImagePixelInputs


-# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
-# NOTE: new_height and new_width are further incremented to properly invert the
-# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
-    height: int,
-    width: int,
+    original_height: int,
+    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width
-    current_height = torch.tensor(current_height).to("cuda")
-    current_width = torch.tensor(current_width).to("cuda")

-    aspect_ratio: float = width / height
-    current_aspect_ratio: float = current_width / current_height
+    aspect_ratio = original_width / original_height
+    current_aspect_ratio = current_width / current_height
+
    if aspect_ratio > current_aspect_ratio:
-        scale_factor = current_width / width
-        new_height = int(height * scale_factor)
+        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
    else:
-        scale_factor = current_height / height
-        new_width = int(width * scale_factor)
+        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width -= padding * 2
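A quick worked check of the integer arithmetic above, using made-up numbers (the 768x1024 image, the 24-patch grid side, and the trailing `unpadded_features`/`newline_features` bookkeeping are illustrative assumptions, not part of this hunk):

# Hypothetical inputs: a 768x1024 image mapped onto a 2x2 grid of 24x24-token patches.
original_height, original_width = 768, 1024
npatches, num_patch_height, num_patch_width = 24, 2, 2

current_height = npatches * num_patch_height  # 48
current_width = npatches * num_patch_width    # 48

aspect_ratio = original_width / original_height        # ~1.33
current_aspect_ratio = current_width / current_height  # 1.0

if aspect_ratio > current_aspect_ratio:
    # Wider than the padded grid: shrink the height with pure integer math,
    # then strip the symmetric padding rows.
    new_height = (original_height * current_width) // original_width  # 36
    padding = (current_height - new_height) // 2                      # 6
    current_height -= padding * 2                                     # 36
else:
    new_width = (original_width * current_height) // original_height
    padding = (current_width - new_width) // 2
    current_width -= padding * 2

# Assuming the usual LLaVA-NeXT bookkeeping of one newline token per row:
unpadded_features = current_height * current_width  # 36 * 48 = 1728
newline_features = current_height                   # 36
print(unpadded_features, newline_features)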
@@ -95,7 +90,7 @@ def _get_llava_next_num_unpadded_features(
    return (unpadded_features, newline_features)


-# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
@@ -111,9 +106,7 @@ def get_llava_next_image_feature_size(
    )
    base_feature_size = num_patches * num_patches

-    # Note: We follow the "wrong" width/height order
-    # [ref: PR huggingface/transformers#31588]
-    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
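A small sanity check of the corrected unpacking order (the pinpoints, patch size, and image size below are assumptions for illustration; `get_anyres_image_grid_shape` is the transformers helper this module already calls): with a recent transformers release the first returned value is the number of patch rows, so a tall image should yield more rows than columns.

from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape)

# Hypothetical anyres pinpoints, given as (height, width) pairs.
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
    image_size=(900, 300),  # tall image: (height, width)
    grid_pinpoints=grid_pinpoints,
    patch_size=336,
)
print(num_patch_height, num_patch_width)  # expected: 3 1 (rows, columns)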
@@ -349,11 +342,12 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

+                # Move to CPU to avoid floating-point errors
+                orig_height, orig_width = image_size.tolist()
+
                # image_aspect_ratio == "anyres"
-                # Note: We follow the "wrong" width/height order
-                # [ref: PR huggingface/transformers#31588]
-                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                    image_size,
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
@@ -365,7 +359,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
-                                                     image_size)
+                                                     (orig_height, orig_width))
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
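A minimal sketch of the host-side size handling used above (the tensor values, embedding width, and padded grid size are made up; `unpad_image` is the transformers helper this code already calls): pulling the per-image size off the accelerator as plain Python ints keeps the aspect-ratio math exact and hands the helper an ordinary (height, width) tuple.

import torch
from transformers.models.llava_next.modeling_llava_next import unpad_image

# Hypothetical per-image size tensor as it arrives in the forward pass.
image_size = torch.tensor([768, 1024])         # (height, width)
orig_height, orig_width = image_size.tolist()  # plain Python ints

# A fake (embed_dim, grid_height, grid_width) feature map padded to a square grid.
other_patch_embeds = torch.randn(64, 48, 48)
unpadded = unpad_image(other_patch_embeds, (orig_height, orig_width))
print(unpadded.shape)  # padding rows removed, e.g. torch.Size([64, 36, 48])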
@@ -398,7 +392,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
-    ) -> BatchedTensors:
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]
@@ -425,7 +419,9 @@ def _process_image_pixels(
        ]

    def _process_image_input(
-            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
+        self,
+        image_input: LlavaNextImageInputs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")