
Commit be0f542

Vlm test (#339)
Adding vlm test script changes to the original PR

---------

Signed-off-by: quic-dhirajku <[email protected]>
1 parent 680e72b commit be0f542


7 files changed, +478 −266 lines


QEfficient/transformers/models/mllama/modeling_mllama.py

+5 −3
@@ -127,8 +127,8 @@ def forward(
         value_states_new = torch.index_put(value_states_old, indices, value_states)

         # Select old or new image KV states based on q_len
-        key_states = torch.where(q_len == 1, key_states_old, key_states_new)
-        value_states = torch.where(q_len == 1, value_states_old, value_states_new)
+        key_states = torch.where(torch.tensor(q_len == 1), key_states_old, key_states_new)
+        value_states = torch.where(torch.tensor(q_len == 1), value_states_old, value_states_new)

         # Update the image cache
         past_key_value.key_cache[self.layer_idx] = key_states
@@ -1113,7 +1113,7 @@ def forward(
             cache_position=cache_position,
             num_logits_to_keep=num_logits_to_keep,
         )
-
+        outputs["pixel_values"] = pixel_values
         return outputs

     def get_dummy_inputs(self, kv_offload: bool = False):
@@ -1281,6 +1281,8 @@ def get_output_names(self, kv_offload: bool = False):
             "logits",
             *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]],
         ]
+        if not kv_offload:
+            lang_output_names.append("pixel_values_RetainedState")

         output_names = {}
         if kv_offload:
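Note on the torch.where change in this file: wrapping the q_len comparison in torch.tensor(...) turns a Python bool into a tensor condition, presumably so the old/new image-KV selection stays a graph op when the model is traced for export instead of being folded into a constant. A minimal standalone sketch of the pattern, with illustrative shapes only (not taken from the model config):

import torch

q_len = 1
key_states_old = torch.zeros(1, 8, 6404, 64)   # hypothetical image-KV shape
key_states_new = torch.ones(1, 8, 6404, 64)

# A 0-dim bool tensor condition broadcasts over both branches and keeps
# the comparison inside the traced graph.
key_states = torch.where(torch.tensor(q_len == 1), key_states_old, key_states_new)
assert torch.equal(key_states, key_states_old)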

QEfficient/utils/__init__.py

+2 −2
@@ -13,15 +13,15 @@
     check_and_assign_cache_dir,
     dump_qconfig,
     get_num_layers_from_config,
+    get_num_layers_vlm,
     get_onnx_dir_name,
     get_padding_shape_from_config,
+    get_padding_shape_vlm,
     get_qpc_dir_path,
     hf_download,
     load_hf_tokenizer,
     login_and_download_hf_lm,
     onnx_exists,
     padding_check_and_fix,
     qpc_exists,
-    get_padding_shape_vlm,
-    get_num_layers_vlm,
 )

QEfficient/utils/_utils.py

+5 −3
@@ -387,12 +387,14 @@ def get_padding_shape_vlm(config, ctx_len, batch_size=1):
     Return:
         List[int, int, int, int]
     """
-    if hasattr(config, "architectures") and "LlavaForConditionalGeneration" in config.architectures:
+    if hasattr(config, "text_config"):
         n_heads = config.text_config.num_key_value_heads
         d_head = config.text_config.hidden_size // config.text_config.num_attention_heads
         padding_shape = [batch_size, n_heads, ctx_len, d_head]
-    elif hasattr(config, "architectures") and "MllamaForConditionalGeneration" in config.architectures:
-        padding_shape = []
+    elif hasattr(config, "llm_config"):
+        n_heads = config.llm_config.num_key_value_heads
+        d_head = config.llm_config.hidden_size // config.llm_config.num_attention_heads
+        padding_shape = [batch_size, n_heads, ctx_len, d_head]
     return padding_shape
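With this change get_padding_shape_vlm dispatches on the nested text/LLM config rather than the architectures list, and the llm_config branch (e.g. InternVL-style configs) now returns a real KV padding shape instead of an empty list. A condensed sketch of the resulting behaviour, using a hypothetical config stub:

from types import SimpleNamespace

def padding_shape_vlm_sketch(config, ctx_len, batch_size=1):
    # Same dispatch as the patched helper: prefer text_config, fall back to llm_config.
    txt_cfg = config.text_config if hasattr(config, "text_config") else config.llm_config
    n_heads = txt_cfg.num_key_value_heads
    d_head = txt_cfg.hidden_size // txt_cfg.num_attention_heads
    return [batch_size, n_heads, ctx_len, d_head]

# Hypothetical values, for illustration only.
cfg = SimpleNamespace(
    text_config=SimpleNamespace(num_key_value_heads=8, hidden_size=4096, num_attention_heads=32)
)
print(padding_shape_vlm_sketch(cfg, ctx_len=2048))  # [1, 8, 2048, 128]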
QEfficient/utils/generate_inputs.py

+184 −10
@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+from typing import List

 import numpy as np
 import torch
@@ -12,7 +13,6 @@
     get_num_layers_from_config,
     get_padding_shape_from_config,
     padding_check_and_fix,
-    get_padding_shape_vlm,
 )

@@ -206,27 +206,108 @@ def update_ort_outputs(self, ort_outputs):


 class InputHandlerVLM:
-    def __init__(self, batch_size, config, image, conversation, processor, prompt, ctx_len, n_layer):
+    def __init__(
+        self, batch_size, config, image, conversation, processor, prompt, prompt_len, ctx_len, max_gen_len, n_layer
+    ):
         self.ctx_len = ctx_len
+        self.prompt_len = prompt_len
+        self.max_gen_len = max_gen_len
         self.config = config
         self.image = image
         self.prompt = prompt
         self.batch_size = batch_size
-        self.padding_shape = get_padding_shape_vlm(config, ctx_len, batch_size)
         self.n_layer = n_layer
         self.processor = processor
         self.conversation = conversation

+    def prepare_pytorch_inputs(self):
+        """
+        Function responsible for creating Prefill stage tensor inputs for PyTorch model.
+
+        Return:
+            :Dict: input_ids, position_ids, past_key_values
+        """
+        inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt")
+        if hasattr(self.config, "text_config"):
+            txt_cfg = self.config.text_config
+        else:
+            txt_cfg = self.config.llm_config
+
+        num_hidden_layers = txt_cfg.num_hidden_layers
+        num_key_value_heads = txt_cfg.num_key_value_heads
+        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        if hasattr(txt_cfg, "cross_attention_layers"):
+            cross_attention_layers = txt_cfg.cross_attention_layers
+
+        vis_cfg = self.config.vision_config
+        num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1
+        image_tokens_len = vis_cfg.max_num_tiles * num_patches
+
+        inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
+        inputs["past_key_values"] = []
+        for i in range(num_hidden_layers):
+            # Specific to mllama as of now
+            if hasattr(txt_cfg, "cross_attention_layers") and i in cross_attention_layers:
+                idx = cross_attention_layers.index(i)
+                assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
+                inputs["past_key_values"].append(
+                    (
+                        torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim),
+                        torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim),
+                    )
+                )
+            else:
+                inputs["past_key_values"].append(
+                    (
+                        torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
+                        torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
+                    )
+                )
+
+        return inputs
+
     def prepare_vlm_ort_inputs(self):
+        if hasattr(self.config, "text_config"):
+            txt_cfg = self.config.text_config
+        else:
+            txt_cfg = self.config.llm_config
+        num_hidden_layers = txt_cfg.num_hidden_layers
+        num_key_value_heads = txt_cfg.num_key_value_heads
+        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        if hasattr(txt_cfg, "cross_attention_layers"):
+            cross_attention_layers = txt_cfg.cross_attention_layers
+        vis_cfg = self.config.vision_config
+        num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1
+        image_tokens_len = vis_cfg.max_num_tiles * num_patches
+
         inputs = self.processor(images=self.image, text=self.prompt, return_tensors="np")
         if "attention_mask" in inputs.keys():
-            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
+            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
         inputs["past_key_values"] = []
-        for i in range(self.n_layer[0]):
-            inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
-            inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)

-        return inputs
+        vision_inputs = {
+            k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
+        }
+
+        for i in range(num_hidden_layers):
+            if hasattr(txt_cfg, "cross_attention_layers") and i in cross_attention_layers:
+                idx = cross_attention_layers.index(i)
+                assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
+                inputs["past_key." + str(i)] = np.zeros(
+                    (self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32
+                )
+                inputs["past_value." + str(i)] = np.zeros(
+                    (self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32
+                )
+            else:
+                inputs["past_key." + str(i)] = np.zeros(
+                    (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
+                )
+                inputs["past_value." + str(i)] = np.zeros(
+                    (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
+                )
+        lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
+        return vision_inputs, lang_inputs

     def update_vlm_ort_outputs(self, ort_outputs):
         """
@@ -238,7 +319,6 @@ def update_vlm_ort_outputs(self, ort_outputs):
         Return:
             updated_outputs (Dict): Updated past_key_values, logits, pixel_values
         """
-
         present_key_values = []
         for i in range(self.n_layer[0]):
             if "past_key." + str(i) + "_RetainedState" in ort_outputs:
@@ -252,6 +332,9 @@ def update_vlm_ort_outputs(self, ort_outputs):
         outputs["pixel_values_RetainedState"] = (
             ort_outputs["pixel_values_RetainedState"] if "pixel_values_RetainedState" in ort_outputs else None
         )
+        outputs["image_features_RetainedState"] = (
+            ort_outputs["image_features_RetainedState"] if "image_features_RetainedState" in ort_outputs else None
+        )
         return outputs

     def update_vlm_ort_inputs(self, inputs, ort_outputs):
@@ -265,7 +348,6 @@ def update_vlm_ort_inputs(self, inputs, ort_outputs):
         Return:
             :Dict: Updated input_ids, position_ids, pixel_values and past_key_values
         """
-
         updated_inputs = {}
         updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1)
         updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1
@@ -274,4 +356,96 @@ def update_vlm_ort_inputs(self, inputs, ort_outputs):
             updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1]
         if "pixel_values_RetainedState" in ort_outputs.keys():
             updated_inputs["pixel_values"] = ort_outputs["pixel_values_RetainedState"]
+        if "image_features_RetainedState" in ort_outputs.keys():
+            updated_inputs["image_features"] = ort_outputs["image_features_RetainedState"]
+
+        if "cross_attention_mask" in inputs.keys():
+            bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape
+            updated_inputs["cross_attention_mask"] = torch.ones(
+                (bs, 1, num_images, img_tiles), dtype=torch.int64
+            ).numpy()
+
+        for k, v in inputs.items():
+            if k not in updated_inputs.keys():
+                updated_inputs[k] = v
         return updated_inputs
+
+
+class InputHandlerInternVL(InputHandlerVLM):
+    def __init__(self, batch_size, config, image, processor, prompt, prompt_len, ctx_len, max_gen_len, n_layer):
+        self.ctx_len = ctx_len
+        self.prompt_len = prompt_len
+        self.max_gen_len = max_gen_len
+        self.config = config
+        self.image = image
+        self.prompt = prompt
+        self.batch_size = batch_size
+        self.n_layer = n_layer
+        self.processor = processor
+
+    def prepare_pytorch_inputs(self):
+        question = "<image>\n" + self.prompt
+        pixel_values = self.processor.load_image(self.image, max_num=12)
+        # Chat Template information for prompt preprocessing
+        messages: List[List[str]] = []
+        roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
+        prompt = self.processor(pixel_values, question, messages, roles)
+        inputs = self.processor.tokenizer(prompt, return_tensors="pt")
+        inputs["pixel_values"] = pixel_values.clone()
+
+        if hasattr(self.config, "text_config"):
+            txt_cfg = self.config.text_config
+        else:
+            txt_cfg = self.config.llm_config
+
+        num_hidden_layers = txt_cfg.num_hidden_layers
+        num_key_value_heads = txt_cfg.num_key_value_heads
+        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+
+        inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
+        inputs["past_key_values"] = []
+        for i in range(num_hidden_layers):
+            inputs["past_key_values"].append(
+                (
+                    torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
+                    torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
+                )
+            )
+
+        return inputs
+
+    def prepare_vlm_ort_inputs(self):
+        if hasattr(self.config, "text_config"):
+            txt_cfg = self.config.text_config
+        else:
+            txt_cfg = self.config.llm_config
+        num_hidden_layers = txt_cfg.num_hidden_layers
+        num_key_value_heads = txt_cfg.num_key_value_heads
+        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+
+        question = "<image>\n" + self.prompt
+        pixel_values = self.processor.load_image(self.image, max_num=12)
+        # Chat Template information for prompt preprocessing
+        messages: List[List[str]] = []
+        roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
+        prompt = self.processor(pixel_values, question, messages, roles)
+        inputs = self.processor.tokenizer(prompt, return_tensors="np")
+        inputs["pixel_values"] = pixel_values.numpy()
+
+        if "attention_mask" in inputs.keys():
+            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
+        inputs["past_key_values"] = []
+
+        vision_inputs = {
+            k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
+        }
+
+        for i in range(num_hidden_layers):
+            inputs["past_key." + str(i)] = np.zeros(
+                (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
+            )
+            inputs["past_value." + str(i)] = np.zeros(
+                (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
+            )
+        lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
+        return vision_inputs, lang_inputs
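Both handlers now build their ONNX Runtime KV placeholders directly from the config rather than from a precomputed padding_shape: for mllama the cross-attention layers are sized to the image token length while all other layers use ctx_len, and the InternVL handler uses ctx_len throughout. A short sketch of the resulting shapes, with assumed values (the real ones come from the model config):

import numpy as np

batch_size, num_key_value_heads, head_dim, ctx_len = 1, 8, 128, 4096
num_patches = (448 // 14) ** 2 + 1          # (image_size // patch_size) ** 2 + 1
image_tokens_len = 4 * num_patches          # max_num_tiles * num_patches
cross_attention_layers = [3, 8, 13, 18]     # hypothetical layer indices
num_hidden_layers = 20

past = {}
for i in range(num_hidden_layers):
    # Cross-attention layers cache image KV; all other layers cache text KV up to ctx_len.
    seq_len = image_tokens_len if i in cross_attention_layers else ctx_len
    past[f"past_key.{i}"] = np.zeros((batch_size, num_key_value_heads, seq_len, head_dim), dtype=np.float32)
    past[f"past_value.{i}"] = np.zeros((batch_size, num_key_value_heads, seq_len, head_dim), dtype=np.float32)

print(past["past_key.3"].shape, past["past_key.4"].shape)  # (1, 8, 4100, 128) (1, 8, 4096, 128)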
