 
 from QEfficient.utils import constants
 from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
-
+from QEfficient.utils.logging_utils import logger
 
 def custom_cumsum(tensor):
     dim = 0
@@ -238,25 +238,24 @@ def get_specializations(
         kv_offload: bool = False,
         **compiler_options,
     ):
-        height = compiler_options.pop("height", None)
-        width = compiler_options.pop("width", None)
-        if height is None:
-            height = self.config.vision_config.image_size
-        if width is None:
-            width = self.config.vision_config.image_size
+
+        if img_size is None and hasattr(self.config.vision_config, "image_size"):
+            img_size = getattr(self.config.vision_config, "image_size")
+        elif img_size is None:
+            img_size = 1540  # FIXME: based on Mistral3 image size
+            logger.warning("Setting img_size to be 1540, as it was neither passed nor found in vision_config")
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
         patch_size = self.config.vision_config.patch_size
         kernel_size = self.config.spatial_merge_size
-        vision_size = ((height // patch_size) * (width // patch_size)) * (batch_size) // (kernel_size * kernel_size)
+        vision_size = ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size)
 
         vision = [
             {
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "height": height,
-                "width": width,
+                "image_size": img_size,
                 "vision_size": vision_size,
             }
         ]
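
Note on the arithmetic in this hunk: `vision_size` is the number of vision tokens after patchifying the (now square) image and merging patches spatially. A quick sanity check in plain Python, using the 1540 fallback and assumed values for `patch_size` and `spatial_merge_size` (both are read from the model config at runtime, so 14 and 2 below are illustrative, not the actual Mistral3 values):

    # Sanity check of the vision_size formula (illustrative config values).
    img_size, patch_size, kernel_size, batch_size = 1540, 14, 2, 1

    patches_per_side = img_size // patch_size            # 1540 // 14 = 110
    total_patches = patches_per_side * patches_per_side  # 110 * 110 = 12100
    vision_size = total_patches * batch_size // (kernel_size * kernel_size)
    print(vision_size)                                   # 12100 // 4 = 3025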
@@ -265,16 +264,14 @@ def get_specializations(
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "height": height,
-                "width": width,
+                "image_size": img_size,
                 "vision_size": vision_size,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
-                "height": height,
-                "width": width,
+                "image_size": img_size,
                 "vision_size": vision_size,
             },
         ]
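
The two `lang` specializations above correspond to prefill and decode: the first compiles the graph for the full prompt length, the second pins `seq_len` to `"1"` for token-by-token generation. With the default `prefill_seq_len = 128` and the illustrative numbers from the previous sketch (the real `ctx_len` default is `constants.INTERN_CTX_LEN`, whose value is not visible in this diff), the list would come out roughly as:

    # Rough shape of the lang specializations; CTX_LEN is a placeholder,
    # the actual default is constants.INTERN_CTX_LEN.
    CTX_LEN = 4096  # placeholder value
    lang = [
        {"batch_size": 1, "seq_len": 128, "ctx_len": CTX_LEN,
         "image_size": 1540, "vision_size": 3025},  # prefill
        {"batch_size": 1, "seq_len": "1", "ctx_len": CTX_LEN,
         "image_size": 1540, "vision_size": 3025},  # decode
    ]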
@@ -296,7 +293,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         num_layers = self.config.text_config.num_hidden_layers
 
         vision_dynamic_axes = {
-            "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
+            "pixel_values": {0: "batch_size", 2: "image_size", 3: "image_size"},
         }
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
@@ -341,5 +338,5 @@ def get_inputs_info(self):
         return [
             IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
             IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
-            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")),
+            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")),
         ]
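
A side effect of collapsing `"height"`/`"width"` into a single `"image_size"` symbol, both here and in the dynamic axes above, is that the compiled model now assumes square `pixel_values`. A hypothetical guard (not part of this PR) showing the constraint this imposes at the input boundary:

    # Hypothetical: with one shared "image_size" symbol, both spatial
    # dims of pixel_values must be equal.
    def check_square(pixel_values):
        b, c, h, w = pixel_values.shape
        assert h == w, f"expected square images, got {h}x{w}"
        return pixel_values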