
Commit 75b51c5

Revert "Minor Changes"
This reverts commit fed3a56.
1 parent fed3a56 commit 75b51c5

File tree

1 file changed: +16 lines, -13 lines

QEfficient/transformers/models/mistral3/modeling_mistral3.py

Lines changed: 16 additions & 13 deletions
@@ -16,7 +16,7 @@
 
 from QEfficient.utils import constants
 from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
-from QEfficient.utils.logging_utils import logger
+
 
 def custom_cumsum(tensor):
     dim = 0
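For orientation, custom_cumsum continues past this context window; its body is not part of this commit. Purely as an illustration of what a from-scratch cumulative sum along dim 0 can look like (a common rewrite when torch.cumsum is inconvenient for export; the actual QEfficient implementation may differ), consider:

import torch

def custom_cumsum_sketch(tensor):
    # Illustrative only: not the body of custom_cumsum from this file.
    # Accumulates along dim 0 without calling torch.cumsum.
    dim = 0
    out = torch.zeros_like(tensor)
    running = torch.zeros_like(tensor.select(dim, 0))
    for i in range(tensor.shape[dim]):
        running = running + tensor.select(dim, i)
        out.select(dim, i).copy_(running)
    return out

For example, custom_cumsum_sketch(torch.ones(3, 2)) matches torch.cumsum(torch.ones(3, 2), dim=0).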
@@ -238,24 +238,25 @@ def get_specializations(
         kv_offload: bool = False,
         **compiler_options,
     ):
-
-        if img_size is None and hasattr(self.config.vision_config, "image_size"):
-            img_size = getattr(self.config.vision_config, "image_size")
-        elif img_size is None:
-            img_size = 1540  # FIXME based on Mistral3 Image size
-            logger.warning("Setting img_size to be 1540, as it was neither passed nor found in vision_config")
+        height = compiler_options.pop("height", None)
+        width = compiler_options.pop("width", None)
+        if height is None:
+            height = self.config.vision_config.image_size
+        if width is None:
+            width = self.config.vision_config.image_size
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
         patch_size = self.config.vision_config.patch_size
         kernel_size = self.config.spatial_merge_size
-        vision_size = ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size)
+        vision_size = ((height // patch_size) * (width // patch_size)) * (batch_size) // (kernel_size * kernel_size)
 
         vision = [
             {
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "image_size": img_size,
+                "height": height,
+                "width": width,
                 "vision_size": vision_size,
             }
         ]
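The restored block resolves height and width from compiler_options and falls back to the square vision_config.image_size for either dimension, so vision_size now mixes the two. A worked sketch of the arithmetic, with illustrative values (patch_size=14 and kernel_size=2 are assumptions for this example; 1540 is the square fallback the deleted branch used):

patch_size = 14        # assumed value of self.config.vision_config.patch_size
kernel_size = 2        # assumed value of self.config.spatial_merge_size
batch_size = 1
height = width = 1540  # square fallback from vision_config.image_size

vision_size = ((height // patch_size) * (width // patch_size)) * batch_size // (kernel_size * kernel_size)
print(vision_size)     # (110 * 110) // 4 = 3025

With separate dimensions, a non-square 1540x770 input gives (110 * 55) // 4 = 1512 vision tokens instead of being forced to a square.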
@@ -264,14 +265,16 @@ def get_specializations(
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "image_size": img_size,
+                "height": height,
+                "width": width,
                 "vision_size": vision_size,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
-                "image_size": img_size,
+                "height": height,
+                "width": width,
                 "vision_size": vision_size,
             },
         ]
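After the revert, both language specializations, prefill with seq_len=prefill_seq_len and decode with seq_len="1", carry the per-dimension keys. Using the illustrative numbers from the sketch above plus an assumed ctx_len of 4096, the list would come out roughly as:

lang = [
    {"batch_size": 1, "seq_len": 128, "ctx_len": 4096,
     "height": 1540, "width": 770, "vision_size": 1512},
    {"batch_size": 1, "seq_len": "1", "ctx_len": 4096,
     "height": 1540, "width": 770, "vision_size": 1512},
]

Only seq_len differs between the two entries; the image geometry is shared, so the compiled prefill and decode graphs agree on the vision embedding length.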
@@ -293,7 +296,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         num_layers = self.config.text_config.num_hidden_layers
 
         vision_dynamic_axes = {
-            "pixel_values": {0: "batch_size", 2: "image_size", 3: "image_size"},
+            "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
         }
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
@@ -338,5 +341,5 @@ def get_inputs_info(self):
         return [
             IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
             IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
-            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")),
+            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")),
         ]
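The pixel_values spec now permits height != width, so any dummy-input generation downstream should follow suit. A small sketch (the function name and concrete sizes are illustrative, not from QEfficient):

import torch

def make_dummy_inputs(batch_size=1, seq_len=128, height=1540, width=770):
    # Shapes and dtypes mirror the IOInfo declarations above.
    return {
        "input_ids": torch.zeros(batch_size, seq_len, dtype=torch.int64),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.int64),
        "pixel_values": torch.rand(batch_size, 3, height, width, dtype=torch.float32),
    }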
