
Commit 680e72b

VLM e2e test pipeline
Signed-off-by: Ann <[email protected]>
1 parent fc89e8b commit 680e72b

6 files changed (+760, -2 lines)


QEfficient/utils/__init__.py (+2)

@@ -22,4 +22,6 @@
     onnx_exists,
     padding_check_and_fix,
     qpc_exists,
+    get_padding_shape_vlm,
+    get_num_layers_vlm,
 )

QEfficient/utils/_utils.py (+44)

@@ -352,6 +352,50 @@ def get_num_layers_from_config(config):
     return n_layer
 
 
+def get_num_layers_vlm(config):
+    """
+    Gets the number of layers from the model config of a VLM.
+    --------
+
+    :config: AutoConfig from the pretrained model.
+
+    Return:
+        Tuple of the number of layers of the text and vision parts.
+    """
+
+    if hasattr(config, "llm_config") and hasattr(config, "vision_config"):  # Intern
+        n_layers_text = config.llm_config.num_hidden_layers
+        n_layers_vision = config.vision_config.num_hidden_layers
+    elif hasattr(config, "text_config") and hasattr(config, "vision_config"):  # Llava, Mllama
+        n_layers_text = config.text_config.num_hidden_layers
+        n_layers_vision = config.vision_config.num_hidden_layers
+    else:
+        raise NotImplementedError(f"Unsupported VLM config: {type(config).__name__}")
+
+    return (n_layers_text, n_layers_vision)
+
+
+def get_padding_shape_vlm(config, ctx_len, batch_size=1):
+    """
+    Gets the padding dims for VLM models (number of KV heads and d_head)
+    and returns the padding shape (batch_size, number of KV heads, ctx_len, d_head)
+    required for initialization of past_key_values.
+    --------
+
+    :config: AutoConfig from the pretrained model.
+    :ctx_len: int. Context length to run the model for.
+    :batch_size: int. Number of input prompts used to create inputs.
+
+    Return:
+        List[int, int, int, int]
+    """
+    if hasattr(config, "architectures") and "LlavaForConditionalGeneration" in config.architectures:
+        n_heads = config.text_config.num_key_value_heads
+        d_head = config.text_config.hidden_size // config.text_config.num_attention_heads
+        padding_shape = [batch_size, n_heads, ctx_len, d_head]
+    elif hasattr(config, "architectures") and "MllamaForConditionalGeneration" in config.architectures:
+        padding_shape = []
+    else:
+        raise NotImplementedError(f"Unsupported VLM architecture: {getattr(config, 'architectures', None)}")
+    return padding_shape
+
+
 def execute_command(process: str, command: str, output_file_path: Optional[str] = None):
     """
     Executes the given command using subprocess.
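For reference, a minimal usage sketch of the two new helpers (the checkpoint id and the commented values are illustrative assumptions, not part of this commit):

from transformers import AutoConfig

from QEfficient.utils import get_num_layers_vlm, get_padding_shape_vlm

# Any Llava-style checkpoint works; llava-hf/llava-1.5-7b-hf is just an example.
config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
n_layer = get_num_layers_vlm(config)  # -> (text_layers, vision_layers)
padding_shape = get_padding_shape_vlm(config, ctx_len=1024, batch_size=1)
# -> [1, num_key_value_heads, 1024, head_dim], used to zero-initialize past_key_values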

QEfficient/utils/generate_inputs.py (+78, -1)

@@ -8,7 +8,12 @@
 import numpy as np
 import torch
 
-from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix
+from QEfficient.utils import (
+    get_num_layers_from_config,
+    get_padding_shape_from_config,
+    padding_check_and_fix,
+    get_padding_shape_vlm,
+)
 
 
 class InputHandler:

@@ -198,3 +203,75 @@ def update_ort_outputs(self, ort_outputs):
         outputs["logits"] = ort_outputs["logits"]
 
         return outputs
+
+
+class InputHandlerVLM:
+    def __init__(self, batch_size, config, image, conversation, processor, prompt, ctx_len, n_layer):
+        self.ctx_len = ctx_len
+        self.config = config
+        self.image = image
+        self.prompt = prompt
+        self.batch_size = batch_size
+        self.padding_shape = get_padding_shape_vlm(config, ctx_len, batch_size)
+        self.n_layer = n_layer
+        self.processor = processor
+        self.conversation = conversation
+
+    def prepare_vlm_ort_inputs(self):
+        """
+        Function responsible for creating Prefill stage inputs for the ONNX model to be run on ONNXRT.
+
+        Return:
+            :Dict: input_ids, position_ids, pixel_values and zero-initialized past_key_values
+        """
+        inputs = self.processor(images=self.image, text=self.prompt, return_tensors="np")
+        if "attention_mask" in inputs:
+            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
+        inputs["past_key_values"] = []
+        for i in range(self.n_layer[0]):
+            inputs[f"past_key.{i}"] = np.zeros(self.padding_shape, dtype=np.float32)
+            inputs[f"past_value.{i}"] = np.zeros(self.padding_shape, dtype=np.float32)
+
+        return inputs
+
+    def update_vlm_ort_outputs(self, ort_outputs):
+        """
+        Function responsible for updating ONNXRT session outputs.
+
+        ``Mandatory`` Args:
+            :ort_outputs (Dict): NumPy outputs of the ONNX model from the current iteration
+
+        Return:
+            updated_outputs (Dict): Updated past_key_values, logits, pixel_values
+        """
+        present_key_values = []
+        for i in range(self.n_layer[0]):
+            if f"past_key.{i}_RetainedState" in ort_outputs:
+                present_key_values.append(ort_outputs[f"past_key.{i}_RetainedState"])
+            if f"past_value.{i}_RetainedState" in ort_outputs:
+                present_key_values.append(ort_outputs[f"past_value.{i}_RetainedState"])
+
+        outputs = {}
+        outputs["past_key_values"] = present_key_values
+        outputs["logits"] = ort_outputs["logits"]
+        outputs["pixel_values_RetainedState"] = (
+            ort_outputs["pixel_values_RetainedState"] if "pixel_values_RetainedState" in ort_outputs else None
+        )
+        return outputs
+
+    def update_vlm_ort_inputs(self, inputs, ort_outputs):
+        """
+        Function responsible for updating Prefill stage inputs to create decode stage inputs for the ONNX model to be run on ONNXRT.
+
+        ``Mandatory`` Args:
+            :inputs (Dict): NumPy inputs of the ONNX model from the previous iteration
+            :ort_outputs (Dict): NumPy outputs of the ONNX model from the previous iteration
+
+        Return:
+            :Dict: Updated input_ids, position_ids, pixel_values and past_key_values
+        """
+        updated_inputs = {}
+        updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1)
+        updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1
+        for i in range(self.n_layer[0]):
+            updated_inputs[f"past_key.{i}"] = ort_outputs["past_key_values"][i * 2]
+            updated_inputs[f"past_value.{i}"] = ort_outputs["past_key_values"][i * 2 + 1]
+        if "pixel_values_RetainedState" in ort_outputs:
+            updated_inputs["pixel_values"] = ort_outputs["pixel_values_RetainedState"]
+        return updated_inputs
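For reference, the position-id bookkeeping used above works as follows (a minimal NumPy sketch with illustrative values):

import numpy as np

mask = np.array([[1, 1, 1, 1]])                    # prefill: 4 unpadded tokens
pos = mask.cumsum(1)                               # -> [[1, 2, 3, 4]]
next_pos = np.max(pos, axis=1, keepdims=True) + 1  # -> [[5]], first decode position

Each subsequent decode step advances position_ids by one while past_key_values grow through the RetainedState outputs.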

QEfficient/utils/run_utils.py (+125, -1)

@@ -12,8 +12,10 @@
 import onnxruntime
 import torch
 
+from transformers import TextStreamer
 from QEfficient.generation.text_generation_inference import TextGeneration
-from QEfficient.utils.generate_inputs import InputHandler
+from QEfficient.utils.generate_inputs import InputHandler, InputHandlerVLM
+from QEfficient.utils._utils import get_padding_shape_vlm
 
 
 # TODO: Deprecate this class and encourage the use of `QeffAutoModel...` classes

@@ -243,3 +245,125 @@ def run_kv_model_on_cloud_ai_100(self, qpc_path, device_group=None):
     print("Prompt:", repr(self.input_handler.prompt))
     print("Completion:", repr(predicted_string))
     return execinfo.generated_ids
+
+
+class ApiRunnerVlm:
+    """
+    ApiRunnerVlm class is responsible for running vision-language models on:
+    ---------
+
+    1. HuggingFace ``PyTorch`` model
+    2. Transformed KV PyTorch model
+    3. ``ONNX`` model on ONNXRT
+    4. ``ONNX`` model on Cloud AI 100
+    """
+
+    def __init__(self, batch_size, processor, config, image, conversation, prompt, ctx_len, n_layer):
+        self.input_handler_vlm = InputHandlerVLM(
+            batch_size=batch_size,
+            ctx_len=ctx_len,
+            config=config,
+            image=image,
+            conversation=conversation,
+            processor=processor,
+            n_layer=n_layer,
+            prompt=prompt,
+        )
+        self.processor = processor
+        self.ctx_len = ctx_len
+        self.batch_size = batch_size
+        self.config = config
+        self.gen_len = 20
+
+    def run_vlm_hf_model_on_pytorch(self, model, inputs):
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        py_output = self.processor.tokenizer.decode(output[0, inputs["input_ids"].shape[1] :]).strip()
+        print("Original HF Model Outputs (Torch CPU):")
+        print("Completion:", repr(py_output))
+        return py_output
+
+    def run_vlm_kv_model_on_pytorch(self, model, inputs):
+        padding_shape = get_padding_shape_vlm(model.config, self.ctx_len, self.batch_size)
+        generation_len = self.ctx_len - inputs["input_ids"].shape[1]
+        generated_ids = torch.full((self.batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id)
+        inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
+        inputs["past_key_values"] = []
+        for _ in range(model.config.text_config.num_hidden_layers):
+            inputs["past_key_values"].append(
+                (
+                    torch.zeros(padding_shape, dtype=torch.float32),
+                    torch.zeros(padding_shape, dtype=torch.float32),
+                )
+            )
+        # Prefill: the first generated token lands at index 0.
+        outputs = model(**inputs)
+        inputs["input_ids"] = outputs[0].argmax(2)
+        generated_ids[:, 0] = inputs["input_ids"].squeeze(1)
+        finished_sequences = inputs["input_ids"] == self.processor.tokenizer.eos_token_id
+        inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1
+        streamer = TextStreamer(self.processor.tokenizer)
+        streamer.put(inputs["input_ids"])
+        for num_token in range(self.gen_len):
+            outputs = model(**inputs)
+            inputs["input_ids"] = outputs[0].argmax(2)
+            inputs["position_ids"] += 1
+            streamer.put(inputs["input_ids"])
+            # Offset by one so the prefill token at index 0 is not overwritten.
+            generated_ids[:, num_token + 1] = inputs["input_ids"].squeeze(1)
+            finished_sequences |= inputs["input_ids"] == self.processor.tokenizer.eos_token_id
+            if finished_sequences.all():
+                break
+        streamer.end()
+        return generated_ids[0]
+
+    def run_ort_session(self, inputs, session) -> dict:
+        """
+        Function responsible for running an ONNXRT session with the given inputs and
+        passing retained-state outputs on as the next iteration's inputs.
+
+        ``Mandatory`` Args:
+            :inputs (Dict): NumPy inputs for the ONNX model
+            :session (onnxruntime.InferenceSession): session to run
+
+        Return:
+            :Dict: NumPy outputs of the ONNX model
+        """
+        output_names = [x.name for x in session.get_outputs()]
+        session_input_names = [x.name for x in session.get_inputs()]
+        session_inputs = {}
+        for inp_name in session_input_names:
+            if inp_name in inputs:
+                session_inputs[inp_name] = inputs[inp_name]
+        outputs_data = session.run(output_names, session_inputs)
+        ort_outputs = dict(zip(output_names, outputs_data))
+        return ort_outputs
+
+    def run_vlm_kv_model_on_ort(self, model_path):
+        m = onnx.load(model_path, load_external_data=False)
+        # NOTE: OrtValue objects should be kept around until the session is run, hence this dict is required
+        added_initializers = {}
+        for node in m.graph.node:
+            if node.op_type == "Constant":
+                np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t, os.path.dirname(model_path))
+                # Override scalar INT32_MAX (2147483647) constants in the graph with 0 via session initializers.
+                if len(np_tensor.shape) == 0 and np_tensor.item() == 2147483647:
+                    added_initializers[node.output[0]] = onnxruntime.OrtValue.ortvalue_from_numpy(
+                        np.array(0, np_tensor.dtype)
+                    )
+        session_options = onnxruntime.SessionOptions()
+        for name, value in added_initializers.items():
+            session_options.add_initializer(name, value)
+        session = onnxruntime.InferenceSession(model_path, session_options)
+        generated_ids = []
+        # Prefill, then decode: the loop starts at 1 because prefill produces the first token.
+        inputs = self.input_handler_vlm.prepare_vlm_ort_inputs()
+        ort_outputs = self.run_ort_session(inputs, session=session)
+        ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
+        for _ in range(1, self.gen_len):
+            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs)
+            ort_outputs = self.run_ort_session(inputs, session)
+            ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
+        generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+        generated_ids = np.concatenate(generated_ids, axis=1)
+        predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        print("Completion:", repr(predicted_string))
+        return generated_ids
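For context, a hypothetical end-to-end usage of ApiRunnerVlm (the checkpoint id, image path, prompt, and conversation structure are illustrative assumptions, not taken from this commit):

from PIL import Image
from transformers import AutoConfig, AutoProcessor

from QEfficient.utils import get_num_layers_vlm
from QEfficient.utils.run_utils import ApiRunnerVlm

model_id = "llava-hf/llava-1.5-7b-hf"          # example Llava checkpoint
config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
image = Image.open("tests/data/sample.png")    # hypothetical local test image
prompt = "USER: <image>\nDescribe the picture. ASSISTANT:"
conversation = [{"role": "user", "content": "Describe the picture."}]

runner = ApiRunnerVlm(
    batch_size=1,
    processor=processor,
    config=config,
    image=image,
    conversation=conversation,
    prompt=prompt,
    ctx_len=1024,
    n_layer=get_num_layers_vlm(config),
)
ort_ids = runner.run_vlm_kv_model_on_ort("path/to/model.onnx")  # KV ONNX model on ONNXRT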
