@@ -16,7 +16,7 @@
 
 from QEfficient.utils import constants
 from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
-from QEfficient.utils.logging_utils import logger
+
 
 def custom_cumsum(tensor):
     dim = 0
@@ -238,24 +238,25 @@ def get_specializations(
         kv_offload: bool = False,
         **compiler_options,
     ):
-
-        if img_size is None and hasattr(self.config.vision_config, "image_size"):
-            img_size = getattr(self.config.vision_config, "image_size")
-        elif img_size is None:
-            img_size = 1540  # FIXME based on Mistral3 Image size
-            logger.warning("Setting img_size to be 1540, as it was neither passed nor found in vision_config")
+        height = compiler_options.pop("height", None)
+        width = compiler_options.pop("width", None)
+        if height is None:
+            height = self.config.vision_config.image_size
+        if width is None:
+            width = self.config.vision_config.image_size
 
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
         patch_size = self.config.vision_config.patch_size
         kernel_size = self.config.spatial_merge_size
-        vision_size = ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size)
+        vision_size = ((height // patch_size) * (width // patch_size)) * (batch_size) // (kernel_size * kernel_size)
 
         vision = [
             {
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "image_size": img_size,
+                "height": height,
+                "width": width,
                 "vision_size": vision_size,
             }
         ]
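For reference, a quick sanity check of the vision_size arithmetic above, using assumed Mistral3-like values (patch_size=14 and spatial_merge_size=2 are assumptions, not read from this diff; 1540 comes from the removed FIXME default):

# Sanity check of the vision_size formula, with assumed config values.
batch_size = 1
patch_size = 14   # assumed Mistral3/Pixtral patch size
kernel_size = 2   # assumed spatial_merge_size
height = width = 1540

vision_size = ((height // patch_size) * (width // patch_size)) * batch_size // (kernel_size * kernel_size)
print(vision_size)  # (110 * 110) // 4 == 3025 patch embeddings after spatial merge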
@@ -264,14 +265,16 @@ def get_specializations(
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "image_size": img_size,
+                "height": height,
+                "width": width,
                 "vision_size": vision_size,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
-                "image_size": img_size,
+                "height": height,
+                "width": width,
                 "vision_size": vision_size,
             },
         ]
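Because height and width are now popped from **compiler_options, callers can override the vision_config default per compilation, including non-square shapes. A minimal sketch, assuming QEFFAutoModelForImageTextToText.compile() forwards unrecognized keyword arguments into **compiler_options (the model ID and exact kwargs here are illustrative):

# Illustrative only: pass height/width through compiler options so
# get_specializations() picks them up instead of vision_config.image_size.
from QEfficient import QEFFAutoModelForImageTextToText

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503"  # illustrative checkpoint
)
model.compile(
    prefill_seq_len=128,
    ctx_len=4096,
    height=1540,  # overrides vision_config.image_size
    width=770,    # non-square inputs are now expressible
)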
@@ -293,7 +296,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         num_layers = self.config.text_config.num_hidden_layers
 
         vision_dynamic_axes = {
-            "pixel_values": {0: "batch_size", 2: "image_size", 3: "image_size"},
+            "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
         }
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
@@ -338,5 +341,5 @@ def get_inputs_info(self):
         return [
             IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
             IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
-            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")),
+            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")),
         ]
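Dummy inputs matching the updated spec (shapes here are arbitrary examples):

import torch

# Arbitrary example shapes; the point is that the height and width of
# pixel_values no longer have to match under the updated IOInfo spec.
pixel_values = torch.randn(1, 3, 1540, 770, dtype=torch.float32)
input_ids = torch.zeros(1, 128, dtype=torch.int64)
attention_mask = torch.ones(1, 128, dtype=torch.int64)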