12 | 12 | import onnxruntime
13 | 13 | import torch
14 | 14 |
| 15 | +from transformers import TextStreamer |
15 | 16 | from QEfficient.generation.text_generation_inference import TextGeneration
16 | | -from QEfficient.utils.generate_inputs import InputHandler |
| 17 | +from QEfficient.utils.generate_inputs import InputHandler, InputHandlerVLM |
| 18 | +from QEfficient.utils._utils import get_padding_shape_vlm |
17 | 19 |
18 | 20 |
19 | 21 | # TODO: Deprecate this class and encourage the use of `QeffAutoModel...` classes
@@ -243,3 +245,125 @@ def run_kv_model_on_cloud_ai_100(self, qpc_path, device_group=None):
243 | 245 | print("Prompt:", repr(self.input_handler.prompt))
244 | 246 | print("Completion:", repr(predicted_string))
245 | 247 | return execinfo.generated_ids
| 248 | + |
| 249 | + |
| 250 | +class ApiRunnerVlm: |
| 251 | + """ |
| 252 | + ApiRunnerVlm class is responsible for running Vision models: |
| 253 | + --------- |
| 254 | + |
| 255 | + 1. HuggingFace ``PyTorch`` model |
| 256 | + 2. Transformed KV Pytorch Model |
| 257 | + 3. ``ONNX`` model on ONNXRT |
| 258 | + 4. ``ONNX`` model on Cloud AI 100 |
| 259 | + """ |
| 260 | + |
| 261 | + def __init__(self, batch_size, processor, config, image, conversation, prompt, ctx_len, n_layer): |
| 262 | + """Initialise the VLM input handler and generation settings for the given config, processor, image, conversation, and prompt.""" |
| 263 | + self.input_handler_vlm = InputHandlerVLM( |
| 264 | + batch_size=batch_size, |
| 265 | + ctx_len=ctx_len, |
| 266 | + config=config, |
| 267 | + image=image, |
| 268 | + conversation=conversation, |
| 269 | + processor=processor, |
| 270 | + n_layer=n_layer, |
| 271 | + prompt=prompt, |
| 272 | + ) |
| 273 | + self.processor = processor |
| 274 | + self.ctx_len = ctx_len |
| 275 | + self.batch_size = batch_size |
| 276 | + self.config = config |
| 277 | + self.gen_len = 20  # tokens produced by the KV and ORT decode loops below |
| 278 | + |
| 279 | + def run_vlm_hf_model_on_pytorch(self, model, inputs): |
| 280 | + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) |
| 281 | + py_output = self.processor.tokenizer.decode(output[0, inputs["input_ids"].shape[1] :]).strip() |
| 282 | + print("Original HF Model Outputs (Torch CPU):") |
| 283 | + print("Prompt:", repr(self.input_handler_vlm.prompt)) |
| 284 | + print("Completion:", repr(py_output)) |
| 285 | + return py_output |
| 286 | + |
| 287 | + def run_vlm_kv_model_on_pytorch(self, model, inputs): |
| 288 | + padding_shape = get_padding_shape_vlm(model.config, self.ctx_len, self.batch_size) |
| 289 | + generation_len = self.ctx_len - inputs["input_ids"].shape[1] |
| 290 | + generated_ids = torch.full((self.batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id) |
| 291 | + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) |
| 292 | + inputs["past_key_values"] = [] |
| 293 | + for _ in range(model.config.text_config.num_hidden_layers): |
| 294 | + inputs["past_key_values"].append( |
| 295 | + ( |
| 296 | + torch.zeros(padding_shape, dtype=torch.float32), |
| 297 | + torch.zeros(padding_shape, dtype=torch.float32), |
| 298 | + ) |
| 299 | + ) |
| 300 | + outputs = model(**inputs) |
| 301 | + inputs["input_ids"] = outputs[0].argmax(2) |
| 302 | + generated_ids[:, 0] = inputs["input_ids"].squeeze(1) |
| 303 | + finished_sequences = inputs["input_ids"] == self.processor.tokenizer.eos_token_id |
| 304 | + inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 |
| 305 | + streamer = TextStreamer(self.processor.tokenizer) |
| 306 | + streamer.put(inputs["input_ids"]) |
| 307 | + for num_token in range(1, self.gen_len):  # index 0 already holds the prefill token |
| 308 | + outputs = model(**inputs) |
| 309 | + inputs["input_ids"] = outputs[0].argmax(2) |
| 310 | + inputs["position_ids"] += 1 |
| 311 | + streamer.put(inputs["input_ids"]) |
| 312 | + generated_ids[:, num_token] = inputs["input_ids"].squeeze(1) |
| 313 | + finished_sequences |= inputs["input_ids"] == self.processor.tokenizer.eos_token_id |
| 314 | + if finished_sequences.all(): |
| 315 | + break |
| 316 | + # generated_texts = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) |
| 317 | + streamer.end() |
| 318 | + return generated_ids[0] |
| 319 | + |
| 320 | + def run_ort_session(self, inputs, session) -> dict: |
| 321 | + """ |
| 322 | + Run the ONNX Runtime session on the given inputs and return its outputs; retained-state (KV cache) outputs are carried over into the next iteration's inputs. |
| 323 | + |
| 324 | + ``Mandatory`` Args: |
| 325 | + :inputs (Dict): Numpy inputs for the current iteration. |
| 326 | + :session (onnxruntime.capi.onnxruntime_inference_collection.InferenceSession): Session to execute. |
| 327 | + |
| 328 | + Return: |
| 329 | + :Dict: Numpy outputs of the ONNX model, keyed by output name. |
| 330 | + """ |
| 331 | + output_names = [x.name for x in session.get_outputs()] |
| 332 | + session_input_names = [x.name for x in session.get_inputs()] |
| 333 | + session_inputs = {} |
| 334 | + for inp_name in session_input_names: |
| 335 | + if inp_name in inputs: |
| 336 | + session_inputs[inp_name] = inputs[inp_name] |
| 337 | + outputs_data = session.run(output_names, session_inputs) |
| 338 | + ort_outputs = dict(zip(output_names, outputs_data)) |
| 339 | + return ort_outputs |
| 340 | + |
| 341 | + def run_vlm_kv_model_on_ort(self, model_path): |
| 342 | + m = onnx.load(model_path, load_external_data=False) |
| 343 | + # NOTE: OrtValue objects should be kept around until the session is run, hence this dict is required |
| 344 | + added_initializers = {} |
| 345 | + for node in m.graph.node: |
| 346 | + if node.op_type == "Constant": |
| 347 | + np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t, os.path.dirname(model_path)) |
| 348 | + if len(np_tensor.shape) == 0 and np_tensor.item() == 2147483647:  # INT32 max: invalid-index placeholder |
| 349 | + added_initializers[node.output[0]] = onnxruntime.OrtValue.ortvalue_from_numpy( |
| 350 | + np.array(0, np_tensor.dtype) |
| 351 | + ) |
| 352 | + session_options = onnxruntime.SessionOptions() |
| 353 | + for name, value in added_initializers.items(): |
| 354 | + session_options.add_initializer(name, value) |
| 355 | + session = onnxruntime.InferenceSession(model_path, session_options) |
| 356 | + generated_ids = [] |
| 357 | + inputs = self.input_handler_vlm.prepare_vlm_ort_inputs() |
| 358 | + ort_outputs = self.run_ort_session(inputs, session=session) |
| 359 | + ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs) |
| 360 | + for _ in range(1, self.gen_len): |
| 361 | + generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1)) |
| 362 | + inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs) |
| 363 | + ort_outputs = self.run_ort_session(inputs, session) |
| 364 | + ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs) |
| 365 | + generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1)) |
| 366 | + generated_ids = np.concatenate(generated_ids, axis=1) |
| 367 | + predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) |
| 368 | + print("Completion:", repr(predicted_string)) |
| 369 | + return generated_ids |
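
A minimal usage sketch of the new runner (all of `model`, `qeff_model`, `processor`, `image`, `conversation`, `prompt`, `inputs`, and `onnx_path` below are stand-ins, not part of this diff; the module path is assumed to mirror where `ApiRunner` lives):

```python
# Hypothetical wiring of ApiRunnerVlm; every input below is a stand-in.
from QEfficient.utils.run_utils import ApiRunnerVlm  # module path assumed

runner = ApiRunnerVlm(
    batch_size=1,
    processor=processor,        # e.g. AutoProcessor.from_pretrained(...)
    config=model.config,
    image=image,                # image referenced by the conversation
    conversation=conversation,  # chat-template style message list
    prompt=prompt,
    ctx_len=1024,
    n_layer=model.config.text_config.num_hidden_layers,
)

# inputs: processor(...) output for the prompt+image, prepared upstream.
runner.run_vlm_hf_model_on_pytorch(model, inputs)                # HF reference (Torch CPU)
pt_ids = runner.run_vlm_kv_model_on_pytorch(qeff_model, inputs)  # transformed KV model
ort_ids = runner.run_vlm_kv_model_on_ort(onnx_path)              # same model through ONNXRT
```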
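On the constant-override loop in `run_vlm_kv_model_on_ort` (an inference from the code, not stated in the diff): scalar constants equal to `INT32_MAX` (2147483647) appear to serve as invalid-index placeholders in the exported graph; ONNX Runtime would treat them as out-of-range gather/scatter indices, so each one is shadowed with 0 via `SessionOptions.add_initializer` instead of rewriting the model file. The mechanism in isolation, with a placeholder constant name:

```python
import numpy as np
import onnxruntime

# Shadow a graph constant by its output name with a caller-owned OrtValue;
# "const_name" stands for whichever Constant output held INT32_MAX.
zero = onnxruntime.OrtValue.ortvalue_from_numpy(np.array(0, dtype=np.int32))
opts = onnxruntime.SessionOptions()
opts.add_initializer("const_name", zero)  # the OrtValue must outlive session runs
session = onnxruntime.InferenceSession("model.onnx", opts)
```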