 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+from typing import List

 import numpy as np
 import torch

     get_num_layers_from_config,
     get_padding_shape_from_config,
     padding_check_and_fix,
-    get_padding_shape_vlm,
 )
@@ -206,27 +206,108 @@ def update_ort_outputs(self, ort_outputs):


 class InputHandlerVLM:
-    def __init__(self, batch_size, config, image, conversation, processor, prompt, ctx_len, n_layer):
+    def __init__(
+        self, batch_size, config, image, conversation, processor, prompt, prompt_len, ctx_len, max_gen_len, n_layer
+    ):
         self.ctx_len = ctx_len
+        self.prompt_len = prompt_len
+        self.max_gen_len = max_gen_len
         self.config = config
         self.image = image
         self.prompt = prompt
         self.batch_size = batch_size
-        self.padding_shape = get_padding_shape_vlm(config, ctx_len, batch_size)
         self.n_layer = n_layer
         self.processor = processor
         self.conversation = conversation

+    def prepare_pytorch_inputs(self):
+        """
+        Creates the prefill-stage tensor inputs for the PyTorch model.
+
+        Return:
+            :Dict: input_ids, position_ids, past_key_values
+        """
+        inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt")
+        if hasattr(self.config, "text_config"):
+            txt_cfg = self.config.text_config
+        else:
+            txt_cfg = self.config.llm_config
+
+        num_hidden_layers = txt_cfg.num_hidden_layers
+        num_key_value_heads = txt_cfg.num_key_value_heads
+        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        if hasattr(txt_cfg, "cross_attention_layers"):
+            cross_attention_layers = txt_cfg.cross_attention_layers
+
+        vis_cfg = self.config.vision_config
+        num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1
+        image_tokens_len = vis_cfg.max_num_tiles * num_patches
+
+        inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
+        inputs["past_key_values"] = []
+        for i in range(num_hidden_layers):
+            # Cross-attention layers get KV caches sized to the image tokens (specific to mllama as of now)
+            if hasattr(txt_cfg, "cross_attention_layers") and i in cross_attention_layers:
+                idx = cross_attention_layers.index(i)
+                assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
+                inputs["past_key_values"].append(
+                    (
+                        torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim),
+                        torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim),
+                    )
+                )
+            else:
+                inputs["past_key_values"].append(
+                    (
+                        torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
+                        torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
+                    )
+                )
+
+        return inputs
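
A note on the `assert idx == ((i - 3) // 5)` guard above: in mllama-style configs the cross-attention layers are interleaved at a fixed stride of 5 starting at layer 3, so a layer's position in `cross_attention_layers` is recoverable from its layer index. A minimal sketch, with the layer list assumed from that layout (e.g. a 40-layer text config):

```python
# Assumed mllama layout: cross-attention at layers 3, 8, 13, ...
cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]
for i in cross_attention_layers:
    idx = cross_attention_layers.index(i)
    assert idx == (i - 3) // 5  # list position follows the stride-5, offset-3 pattern
```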
+
     def prepare_vlm_ort_inputs(self):
+        if hasattr(self.config, "text_config"):
+            txt_cfg = self.config.text_config
+        else:
+            txt_cfg = self.config.llm_config
+        num_hidden_layers = txt_cfg.num_hidden_layers
+        num_key_value_heads = txt_cfg.num_key_value_heads
+        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        if hasattr(txt_cfg, "cross_attention_layers"):
+            cross_attention_layers = txt_cfg.cross_attention_layers
+        vis_cfg = self.config.vision_config
+        num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1
+        image_tokens_len = vis_cfg.max_num_tiles * num_patches
+
         inputs = self.processor(images=self.image, text=self.prompt, return_tensors="np")
         if "attention_mask" in inputs.keys():
-            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
+            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
         inputs["past_key_values"] = []
-        for i in range(self.n_layer[0]):
-            inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
-            inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)

-        return inputs
+        # Split processor outputs: image-related tensors feed the vision graph, the rest the language graph
+        vision_inputs = {
+            k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
+        }
+
+        for i in range(num_hidden_layers):
+            if hasattr(txt_cfg, "cross_attention_layers") and i in cross_attention_layers:
+                idx = cross_attention_layers.index(i)
+                assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
+                inputs["past_key." + str(i)] = np.zeros(
+                    (self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32
+                )
+                inputs["past_value." + str(i)] = np.zeros(
+                    (self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32
+                )
+            else:
+                inputs["past_key." + str(i)] = np.zeros(
+                    (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
+                )
+                inputs["past_value." + str(i)] = np.zeros(
+                    (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
+                )
+        lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
+        return vision_inputs, lang_inputs
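
The `- 1` added to the `cumsum` call fixes an off-by-one: the cumulative sum of an all-ones attention mask is 1-based, while the model expects 0-based position ids. A quick worked example:

```python
import numpy as np

attention_mask = np.ones((1, 5), dtype=np.int64)
print(attention_mask.cumsum(1))      # [[1 2 3 4 5]] -- old behaviour, off by one
print(attention_mask.cumsum(1) - 1)  # [[0 1 2 3 4]] -- 0-based position ids
```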

     def update_vlm_ort_outputs(self, ort_outputs):
         """
@@ -238,7 +319,6 @@ def update_vlm_ort_outputs(self, ort_outputs):
         Return:
             updated_outputs (Dict): Updated past_key_values, logits, pixel_values
         """
-
         present_key_values = []
         for i in range(self.n_layer[0]):
             if "past_key." + str(i) + "_RetainedState" in ort_outputs:
@@ -252,6 +332,9 @@ def update_vlm_ort_outputs(self, ort_outputs):
         outputs["pixel_values_RetainedState"] = (
             ort_outputs["pixel_values_RetainedState"] if "pixel_values_RetainedState" in ort_outputs else None
         )
+        outputs["image_features_RetainedState"] = (
+            ort_outputs["image_features_RetainedState"] if "image_features_RetainedState" in ort_outputs else None
+        )
         return outputs

     def update_vlm_ort_inputs(self, inputs, ort_outputs):
@@ -265,7 +348,6 @@ def update_vlm_ort_inputs(self, inputs, ort_outputs):
         Return:
             :Dict: Updated input_ids, position_ids, pixel_values and past_key_values
         """
-
         updated_inputs = {}
         updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1)
         updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1
@@ -274,4 +356,96 @@ def update_vlm_ort_inputs(self, inputs, ort_outputs):
             updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1]
         if "pixel_values_RetainedState" in ort_outputs.keys():
             updated_inputs["pixel_values"] = ort_outputs["pixel_values_RetainedState"]
+        if "image_features_RetainedState" in ort_outputs.keys():
+            updated_inputs["image_features"] = ort_outputs["image_features_RetainedState"]
+
+        if "cross_attention_mask" in inputs.keys():
+            bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape
+            updated_inputs["cross_attention_mask"] = torch.ones(
+                (bs, 1, num_images, img_tiles), dtype=torch.int64
+            ).numpy()
+
+        for k, v in inputs.items():
+            if k not in updated_inputs.keys():
+                updated_inputs[k] = v
         return updated_inputs
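
Taken together, the two update helpers can drive a greedy decode loop against an ONNX Runtime session. The sketch below is illustrative only: the model path, the pre-constructed `handler` (an `InputHandlerVLM`), and the assumption that the exported graph emits `logits` and `past_key.<i>_RetainedState` outputs follow this file's naming conventions rather than a checked-in example; the vision graph is omitted for brevity.

```python
import onnxruntime as ort

session = ort.InferenceSession("language_decoder.onnx")  # hypothetical export path
output_names = [out.name for out in session.get_outputs()]

vision_inputs, lang_inputs = handler.prepare_vlm_ort_inputs()
lang_inputs.pop("past_key_values", None)  # placeholder list, not a real graph input

generated_ids = []
inputs = lang_inputs
for _ in range(handler.max_gen_len):
    raw = dict(zip(output_names, session.run(None, inputs)))
    outputs = handler.update_vlm_ort_outputs(raw)            # repack retained-state KV caches
    inputs = handler.update_vlm_ort_inputs(inputs, outputs)  # next token + advanced position ids
    generated_ids.append(inputs["input_ids"])
```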
+
+
+class InputHandlerInternVL(InputHandlerVLM):
+    def __init__(self, batch_size, config, image, processor, prompt, prompt_len, ctx_len, max_gen_len, n_layer):
+        self.ctx_len = ctx_len
+        self.prompt_len = prompt_len
+        self.max_gen_len = max_gen_len
+        self.config = config
+        self.image = image
+        self.prompt = prompt
+        self.batch_size = batch_size
+        self.n_layer = n_layer
+        self.processor = processor
+
+    def prepare_pytorch_inputs(self):
+        question = "<image>\n" + self.prompt
+        # Load and tile the image; max_num caps the number of tiles
+        pixel_values = self.processor.load_image(self.image, max_num=12)
+        # Chat-template information for prompt preprocessing
+        messages: List[List[str]] = []
+        roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
+        prompt = self.processor(pixel_values, question, messages, roles)
+        inputs = self.processor.tokenizer(prompt, return_tensors="pt")
+        inputs["pixel_values"] = pixel_values.clone()
+
+        if hasattr(self.config, "text_config"):
+            txt_cfg = self.config.text_config
+        else:
+            txt_cfg = self.config.llm_config
+
+        num_hidden_layers = txt_cfg.num_hidden_layers
+        num_key_value_heads = txt_cfg.num_key_value_heads
+        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+
+        inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
+        inputs["past_key_values"] = []
+        for i in range(num_hidden_layers):
+            inputs["past_key_values"].append(
+                (
+                    torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
+                    torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
+                )
+            )
+
+        return inputs
+
+    def prepare_vlm_ort_inputs(self):
+        if hasattr(self.config, "text_config"):
+            txt_cfg = self.config.text_config
+        else:
+            txt_cfg = self.config.llm_config
+        num_hidden_layers = txt_cfg.num_hidden_layers
+        num_key_value_heads = txt_cfg.num_key_value_heads
+        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+
+        question = "<image>\n" + self.prompt
+        pixel_values = self.processor.load_image(self.image, max_num=12)
+        # Chat-template information for prompt preprocessing
+        messages: List[List[str]] = []
+        roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
+        prompt = self.processor(pixel_values, question, messages, roles)
+        inputs = self.processor.tokenizer(prompt, return_tensors="np")
+        inputs["pixel_values"] = pixel_values.numpy()
+
+        if "attention_mask" in inputs.keys():
+            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
+        inputs["past_key_values"] = []
+
+        vision_inputs = {
+            k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
+        }
+
+        for i in range(num_hidden_layers):
+            inputs["past_key." + str(i)] = np.zeros(
+                (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
+            )
+            inputs["past_value." + str(i)] = np.zeros(
+                (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
+            )
+        lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
+        return vision_inputs, lang_inputs
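
For completeness, a hedged sketch of how the new subclass might be driven. Every constructor argument below is a placeholder; `internvl_processor` stands in for an object exposing `load_image`, a `tokenizer`, and the chat-templating `__call__` that the methods above assume:

```python
# Hypothetical wiring of the InternVL handler; all values are placeholders.
handler = InputHandlerInternVL(
    batch_size=1,
    config=config,                 # InternVL config with a text_config/llm_config
    image=image,                   # image accepted by processor.load_image
    processor=internvl_processor,  # wraps chat templating + tokenizer
    prompt="Describe this image.",
    prompt_len=128,
    ctx_len=4096,
    max_gen_len=256,
    n_layer=n_layer,
)
vision_inputs, lang_inputs = handler.prepare_vlm_ort_inputs()
# vision_inputs -> vision-encoder ONNX graph, lang_inputs -> language-decoder graph
```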