Commit fed3a56

Minor Changes

Signed-off-by: Mohit Soni <[email protected]>

1 parent d12fe8c · commit fed3a56

File tree

1 file changed: +13 −16 lines

QEfficient/transformers/models/mistral3/modeling_mistral3.py

Lines changed: 13 additions & 16 deletions
@@ -16,7 +16,7 @@
 
 from QEfficient.utils import constants
 from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
-
+from QEfficient.utils.logging_utils import logger
 
 def custom_cumsum(tensor):
     dim = 0
@@ -238,25 +238,24 @@ def get_specializations(
         kv_offload: bool = False,
         **compiler_options,
     ):
-        height = compiler_options.pop("height", None)
-        width = compiler_options.pop("width", None)
-        if height is None:
-            height = self.config.vision_config.image_size
-        if width is None:
-            width = self.config.vision_config.image_size
+
+        if img_size is None and hasattr(self.config.vision_config, "image_size"):
+            img_size = getattr(self.config.vision_config, "image_size")
+        elif img_size is None:
+            img_size = 1540  # FIXME based on Mistral3 Image size
+            logger.warning("Setting img_size to be 1540, as it was neither passed nor found in vision_config")
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
         patch_size = self.config.vision_config.patch_size
         kernel_size = self.config.spatial_merge_size
-        vision_size = ((height // patch_size) * (width // patch_size)) * (batch_size) // (kernel_size * kernel_size)
+        vision_size = ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size)
 
         vision = [
             {
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "height": height,
-                "width": width,
+                "image_size": img_size,
                 "vision_size": vision_size,
             }
         ]
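
Note on the arithmetic above: with the 1540 fallback, the vision_size expression reduces to a few integer divisions. A standalone sketch, assuming patch_size=14 and spatial_merge_size=2 (illustrative values, not read from this diff):

# Standalone sketch of the vision_size computation above.
# patch_size and kernel_size are assumed example values, not from this diff.
img_size = 1540    # fallback used when vision_config lacks image_size
patch_size = 14    # assumed vision_config.patch_size
kernel_size = 2    # assumed spatial_merge_size
batch_size = 1

# Patches per side, then total patches, then merge kernel_size x kernel_size
# patches into one vision token.
patches_per_side = img_size // patch_size  # 1540 // 14 = 110
vision_size = (patches_per_side * patches_per_side) * batch_size // (kernel_size * kernel_size)
assert vision_size == 3025  # 12100 patches, merged 4-to-1

Using a single img_size for both axes is what lets the former (height // patch_size) and (width // patch_size) factors collapse into one squared term.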
@@ -265,16 +264,14 @@ def get_specializations(
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "height": height,
-                "width": width,
+                "image_size": img_size,
                 "vision_size": vision_size,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
-                "height": height,
-                "width": width,
+                "image_size": img_size,
                 "vision_size": vision_size,
             },
         ]
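
For reference, after this change each specialization carries one image_size entry where the height/width pair used to be. An illustrative value for the lang list, assuming batch_size=1, prefill_seq_len=128, ctx_len=4096, and the vision_size worked out above (all assumed example numbers):

# Example of what the lang specializations above could evaluate to;
# ctx_len=4096 and vision_size=3025 are assumed example values.
lang = [
    {"batch_size": 1, "seq_len": 128, "ctx_len": 4096, "image_size": 1540, "vision_size": 3025},  # prefill
    {"batch_size": 1, "seq_len": "1", "ctx_len": 4096, "image_size": 1540, "vision_size": 3025},  # decode
]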
@@ -296,7 +293,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         num_layers = self.config.text_config.num_hidden_layers
 
         vision_dynamic_axes = {
-            "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
+            "pixel_values": {0: "batch_size", 2: "image_size", 3: "image_size"},
         }
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
@@ -341,5 +338,5 @@ def get_inputs_info(self):
         return [
             IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
             IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
-            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")),
+            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")),
         ]
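
Given the square shape declared in IOInfo, dummy inputs for all three tensors can be derived from one side length. A minimal sketch in plain torch, with assumed example sizes:

import torch

batch_size, seq_len, img_size = 1, 128, 1540  # assumed example values
# Shapes mirror the IOInfo entries above.
input_ids = torch.zeros(batch_size, seq_len, dtype=torch.int64)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
pixel_values = torch.zeros(batch_size, 3, img_size, img_size, dtype=torch.float32)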
