VLM Test pipeline added #304

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/modeling_auto.py
@@ -895,7 +895,7 @@ def export(
inputs = self.model.get_dummy_inputs()
dynamic_axes = self.model.get_onnx_dynamic_axes()
output_names = self.model.get_output_names()
self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)

def compile(
self,
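The only functional change in this file is that `export` now propagates the return value of `_export`. Assuming `_export` returns the path of the generated ONNX artifact (as it does for the other auto classes), a caller such as the new VLM test pipeline can capture it directly; `qeff_model` and `runner` below are hypothetical names used only for illustration:

```python
# Hypothetical usage sketch: capture the ONNX path produced by export()
# and feed it to the ORT-based VLM runner added in this PR.
onnx_path = qeff_model.export(export_dir="./vlm_onnx")
ort_tokens = runner.run_vlm_kv_model_on_ort(onnx_path)
```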
2 changes: 2 additions & 0 deletions QEfficient/utils/__init__.py
@@ -5,21 +5,23 @@
#
# -----------------------------------------------------------------------------

from QEfficient.transformers.quantizers.auto import ( # noqa: F401
replace_transformers_quantizers,
undo_transformers_quantizers,
)
from QEfficient.utils._utils import ( # noqa: F401
check_and_assign_cache_dir,
dump_qconfig,
get_num_layers_from_config,
get_onnx_dir_name,
get_padding_shape_from_config,
get_qpc_dir_path,
hf_download,
load_hf_tokenizer,
login_and_download_hf_lm,
onnx_exists,
padding_check_and_fix,
qpc_exists,
get_padding_shape_vlm,
get_num_layers_vlm,
)

Check failure on line 27 in QEfficient/utils/__init__.py (GitHub Actions / lint): Ruff (I001) QEfficient/utils/__init__.py:8:1: I001 Import block is un-sorted or un-formatted
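For reference, a version of the `QEfficient.utils._utils` import with the two new helpers in sorted position, which is roughly what Ruff's import sorter would produce for the I001 rule (assuming default isort-style ordering); shown only for illustration, not as part of the diff:

```python
from QEfficient.utils._utils import (  # noqa: F401
    check_and_assign_cache_dir,
    dump_qconfig,
    get_num_layers_from_config,
    get_num_layers_vlm,
    get_onnx_dir_name,
    get_padding_shape_from_config,
    get_padding_shape_vlm,
    get_qpc_dir_path,
    hf_download,
    load_hf_tokenizer,
    login_and_download_hf_lm,
    onnx_exists,
    padding_check_and_fix,
    qpc_exists,
)
```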
41 changes: 41 additions & 0 deletions QEfficient/utils/_utils.py
@@ -351,6 +351,47 @@ def get_num_layers_from_config(config):

return n_layer

def get_num_layers_vlm(config):
    """
    Gets the number of layers from the model config of a VLM.
    --------

    :config: AutoConfig from pretrained model.

    Return:
        Tuple of (number of text layers, number of vision layers)
    """

    if hasattr(config, "llm_config") and hasattr(config, "vision_config"):  # Intern
        n_layers_text = config.llm_config.num_hidden_layers
        n_layers_vision = config.vision_config.num_hidden_layers
    elif hasattr(config, "text_config") and hasattr(config, "vision_config"):  # Llava, Mllama
        n_layers_text = config.text_config.num_hidden_layers
        n_layers_vision = config.vision_config.num_hidden_layers
    else:
        raise ValueError("Unsupported VLM config: expected `llm_config` or `text_config` along with `vision_config`.")

    return (n_layers_text, n_layers_vision)

def get_padding_shape_vlm(config, ctx_len, batch_size=1):
    """
    Gets padding dims for VLM models (number of KV heads and d_head) and returns
    the padding shape (batch_size, number of KV heads, ctx_len, d_head) required
    for initialization of past_key_values.
    --------

    :config: AutoConfig from pretrained model.
    :batch_size: int. Number of input prompts used to create inputs.
    :ctx_len: int. Context length to run the model for.

    Return:
        List[int, int, int, int]
    """
    if hasattr(config, "architectures") and "LlavaForConditionalGeneration" in config.architectures:
        n_heads = config.text_config.num_key_value_heads
        d_head = config.text_config.hidden_size // config.text_config.num_attention_heads
        padding_shape = [batch_size, n_heads, ctx_len, d_head]
    elif hasattr(config, "architectures") and "MllamaForConditionalGeneration" in config.architectures:
        padding_shape = []
    else:
        raise ValueError("Unsupported VLM architecture for padding-shape computation.")
    return padding_shape

def execute_command(process: str, command: str, output_file_path: Optional[str] = None):
"""
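A minimal sketch of how these two helpers fit together when seeding `past_key_values` for a Llava-style model; the model id, context length, and batch size below are illustrative assumptions, not values taken from this PR:

```python
import numpy as np
from transformers import AutoConfig

from QEfficient.utils import get_num_layers_vlm, get_padding_shape_vlm

# Illustrative model id and context length (assumptions, not from the PR).
config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
ctx_len = 1024

n_layers_text, n_layers_vision = get_num_layers_vlm(config)  # e.g. 32 text layers, 24 vision layers
padding_shape = get_padding_shape_vlm(config, ctx_len, batch_size=1)

# One zero-filled (key, value) pair per text layer, each shaped
# (batch_size, num_kv_heads, ctx_len, d_head).
past_key_values = [
    (np.zeros(padding_shape, dtype=np.float32), np.zeros(padding_shape, dtype=np.float32))
    for _ in range(n_layers_text)
]
```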
72 changes: 71 additions & 1 deletion QEfficient/utils/generate_inputs.py
@@ -5,10 +5,10 @@
#
# -----------------------------------------------------------------------------

import numpy as np
import torch

from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix
from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix, get_padding_shape_vlm

Check failure on line 11 in QEfficient/utils/generate_inputs.py (GitHub Actions / lint): Ruff (I001) QEfficient/utils/generate_inputs.py:8:1: I001 Import block is un-sorted or un-formatted


class InputHandler:
@@ -198,3 +198,73 @@
outputs["logits"] = ort_outputs["logits"]

return outputs

class InputHandlerVLM:

def __init__(self, batch_size, config, image, conversation, processor, prompt, ctx_len, n_layer):
self.ctx_len = ctx_len
self.config = config
self.image = image
self.prompt = prompt
self.batch_size = batch_size
self.padding_shape = get_padding_shape_vlm(config, ctx_len, batch_size)
self.n_layer = n_layer
self.processor = processor
self.conversation = conversation

    def prepare_vlm_ort_inputs(self):
        """
        Creates prefill-stage ONNXRT inputs: processor outputs, position_ids derived
        from the attention mask, and zero-initialized past_key_values per text layer.

        Return:
            :Dict: Numpy inputs of Onnx model
        """
        inputs = self.processor(images=self.image, text=self.prompt, return_tensors="np")
        if "attention_mask" in inputs.keys():
            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
        inputs["past_key_values"] = []
        for i in range(self.n_layer[0]):
            inputs["past_key." + str(i)] = np.zeros(self.padding_shape, dtype=np.float32)
            inputs["past_value." + str(i)] = np.zeros(self.padding_shape, dtype=np.float32)

        return inputs

def update_vlm_ort_outputs(self, ort_outputs):
"""
Function responsible for updating ONNXRT session outputs.

``Mandatory`` Args:
:ort_outputs (Dict): Numpy outputs of Onnx model from current iteration

Return:
updated_outputs (Dict): Updated past_key_values, logits, pixel_values
"""

present_key_values = []
for i in range(self.n_layer[0]):
if "past_key." + str(i) + "_RetainedState" in ort_outputs:
present_key_values.append(ort_outputs["past_key." + str(i) + "_RetainedState"])
if "past_value." + str(i) + "_RetainedState" in ort_outputs:
present_key_values.append(ort_outputs["past_value." + str(i) + "_RetainedState"])

outputs = {}
outputs["past_key_values"] = present_key_values
outputs["logits"] = ort_outputs["logits"]
        outputs["pixel_values_RetainedState"] = ort_outputs.get("pixel_values_RetainedState")
return outputs

def update_vlm_ort_inputs(self, inputs, ort_outputs):
"""
        Function responsible for updating prefill-stage inputs to create decode-stage inputs for the ONNX model to be run on ONNXRT.

``Mandatory`` Args:
:inputs (Dict): NumPy inputs of Onnx model from previous iteration
:ort_outputs (Dict): Numpy outputs of Onnx model from previous iteration

Return:
:Dict: Updated input_ids, position_ids, pixel_values and past_key_values
"""

updated_inputs = {}
updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1)
updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1
for i in range(self.n_layer[0]):
updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2]
updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1]
if "pixel_values_RetainedState" in ort_outputs.keys():
updated_inputs["pixel_values"] = ort_outputs["pixel_values_RetainedState"]
return updated_inputs
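A minimal sketch of how `InputHandlerVLM` might be instantiated to produce prefill-stage ORT inputs; the model id, image URL, prompt, and context length are illustrative assumptions (the `ApiRunnerVlm` class below is what actually wires this up in the test pipeline), and `apply_chat_template` on the processor assumes a recent transformers release:

```python
import requests
from PIL import Image
from transformers import AutoConfig, AutoProcessor

from QEfficient.utils import get_num_layers_vlm
from QEfficient.utils.generate_inputs import InputHandlerVLM

# Illustrative model, image, and prompt (assumptions, not from the PR).
model_id = "llava-hf/llava-1.5-7b-hf"
config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
conversation = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is shown in this image?"}]}
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

input_handler = InputHandlerVLM(
    batch_size=1,
    config=config,
    image=image,
    conversation=conversation,
    processor=processor,
    prompt=prompt,
    ctx_len=1024,
    n_layer=get_num_layers_vlm(config),
)
ort_inputs = input_handler.prepare_vlm_ort_inputs()  # prefill inputs incl. zeroed past_key_values
```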
131 changes: 129 additions & 2 deletions QEfficient/utils/run_utils.py
@@ -5,15 +5,16 @@
#
# -----------------------------------------------------------------------------

import os

import numpy as np
import onnx
import onnxruntime
import torch

from transformers import TextStreamer
from QEfficient.generation.text_generation_inference import TextGeneration
from QEfficient.utils.generate_inputs import InputHandler
from QEfficient.utils.generate_inputs import InputHandler, InputHandlerVLM
from QEfficient.utils._utils import get_padding_shape_vlm

Check failure on line 17 in QEfficient/utils/run_utils.py (GitHub Actions / lint): Ruff (I001) QEfficient/utils/run_utils.py:8:1: I001 Import block is un-sorted or un-formatted


# TODO: Deprecate this class and encourage the use of `QeffAutoModel...` classes
@@ -243,3 +244,129 @@
print("Prompt:", repr(self.input_handler.prompt))
print("Completion:", repr(predicted_string))
return execinfo.generated_ids

class ApiRunnerVlm:

"""
ApiRunnerVlm class is responsible for running Vision models:
---------

1. HuggingFace ``PyTorch`` model
2. Transformed KV Pytorch Model
3. ``ONNX`` model on ONNXRT
4. ``ONNX`` model on Cloud AI 100
"""

    def __init__(self, batch_size, processor, config, image, conversation, prompt, ctx_len, n_layer):
        """
        Initializes the VLM input handler and stores generation settings.
        """
        self.input_handler_vlm = InputHandlerVLM(
            batch_size=batch_size,
            ctx_len=ctx_len,
            config=config,
            image=image,
            conversation=conversation,
            processor=processor,
            n_layer=n_layer,
            prompt=prompt,
        )
        self.processor = processor
        self.ctx_len = ctx_len
        self.batch_size = batch_size
        self.config = config
        self.gen_len = 20

    def run_vlm_hf_model_on_pytorch(self, model, inputs):
        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
        py_output = self.processor.tokenizer.decode(output[0, inputs["input_ids"].shape[1] :]).strip()
        print("Original HF Model Outputs (Torch CPU):")
        print("Completion:", repr(py_output))
        return

    def run_vlm_kv_model_on_pytorch(self, model, inputs):
        padding_shape = get_padding_shape_vlm(model.config, self.ctx_len, self.batch_size)
        generation_len = self.ctx_len - inputs["input_ids"].shape[1]
        generated_ids = torch.full((self.batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id)
        inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
        inputs["past_key_values"] = []
        for _ in range(model.config.text_config.num_hidden_layers):
            inputs["past_key_values"].append(
                (
                    torch.zeros(padding_shape, dtype=torch.float32),
                    torch.zeros(padding_shape, dtype=torch.float32),
                )
            )
        outputs = model(**inputs)
        inputs["input_ids"] = outputs[0].argmax(2)
        generated_ids[:, 0] = inputs["input_ids"].squeeze(1)
        finished_sequences = inputs["input_ids"] == self.processor.tokenizer.eos_token_id
        inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1
        streamer = TextStreamer(self.processor.tokenizer)
        streamer.put(inputs["input_ids"])
        # Decode loop starts at index 1 so the prefill token stored at index 0 is not overwritten.
        for num_token in range(1, self.gen_len):
            outputs = model(**inputs)
            inputs["input_ids"] = outputs[0].argmax(2)
            inputs["position_ids"] += 1
            streamer.put(inputs["input_ids"])
            generated_ids[:, num_token] = inputs["input_ids"].squeeze(1)
            finished_sequences |= inputs["input_ids"] == self.processor.tokenizer.eos_token_id
            if finished_sequences.all():
                break
        streamer.end()
        return generated_ids[0]

def run_ort_session(self, inputs, session) -> dict:
"""
Function responsible for running onnxrt session with given inputs and passing retained state outputs to be used for next iteration inputs

``Mandatory`` Args:
:inputs (Dict):
:session (onnxruntime.capi.onnxruntime_inference_collection.InferenceSession):

Return:
:Dict: Numpy outputs of Onnx model
"""
output_names = [x.name for x in session.get_outputs()]
session_input_names = [x.name for x in session.get_inputs()]
session_inputs = {}
for inp_name in session_input_names:
if inp_name in inputs.keys():
session_inputs[inp_name] = inputs[inp_name]
outputs_data = session.run(output_names, session_inputs)
ort_outputs = dict(zip(output_names, outputs_data))
return ort_outputs

def run_vlm_kv_model_on_ort(self, model_path):
m = onnx.load(model_path, load_external_data=False)
# NOTE: OrtValue objects should be kept around until the session is run, hence this dict is required
added_initializers = {}
for node in m.graph.node:
if node.op_type == "Constant":
np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t, os.path.dirname(model_path))
if len(np_tensor.shape) == 0 and np_tensor.item() == 2147483647:
added_initializers[node.output[0]] = onnxruntime.OrtValue.ortvalue_from_numpy(
np.array(0, np_tensor.dtype)
)
session_options = onnxruntime.SessionOptions()
for name, value in added_initializers.items():
session_options.add_initializer(name, value)
session = onnxruntime.InferenceSession(model_path, session_options)
generated_ids = []
inputs = self.input_handler_vlm.prepare_vlm_ort_inputs()
        ort_outputs = self.run_ort_session(inputs, session=session)
ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
for _ in range(1, self.gen_len):
generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs)
ort_outputs = self.run_ort_session(inputs, session)
ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
generated_ids = np.concatenate(generated_ids, axis=1)
predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print("Completion:", repr(predicted_string))
return generated_ids
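A rough sketch of how `ApiRunnerVlm` could be driven from a test; the HF reference model, the KV-transformed model, the image/prompt objects, and the exported ONNX path are assumptions standing in for whatever the surrounding test prepares, and the final comparison is only one way such a test might assert equivalence:

```python
from transformers import AutoConfig, AutoProcessor

from QEfficient.utils import get_num_layers_vlm
from QEfficient.utils.run_utils import ApiRunnerVlm

# `model` (HF reference), `qeff_model` (KV-transformed wrapper), `image`,
# `conversation`, `prompt`, and `onnx_model_path` are assumed to be prepared
# by the surrounding test.
model_id = "llava-hf/llava-1.5-7b-hf"  # illustrative checkpoint
config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

runner = ApiRunnerVlm(
    batch_size=1,
    processor=processor,
    config=config,
    image=image,
    conversation=conversation,
    prompt=prompt,
    ctx_len=1024,
    n_layer=get_num_layers_vlm(config),
)

inputs = processor(images=image, text=prompt, return_tensors="pt")
runner.run_vlm_hf_model_on_pytorch(model, inputs)
pytorch_kv_tokens = runner.run_vlm_kv_model_on_pytorch(qeff_model.model, dict(inputs))
ort_tokens = runner.run_vlm_kv_model_on_ort(onnx_model_path)

# One possible equivalence check between the transformed PyTorch run and the ORT run.
gen_len = ort_tokens.shape[1]
assert (pytorch_kv_tokens[:gen_len].numpy() == ort_tokens[0]).all()
```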