Skip to content

Commit bdedce9

Browse files
Fix TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS (#1501)
* Fix TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS
* remove quantizer set task
* add test with transformers v4.46
* temporary fix
* fix
* style
* remove 4.46 tests
* remove video-text-to-text task
* only change task if has remote code
* add comment
* Update optimum/intel/openvino/quantization.py

Co-authored-by: Nikita Savelyev <[email protected]>

---------

Co-authored-by: Nikita Savelyev <[email protected]>
1 parent e8c230b commit bdedce9

File tree

3 files changed

+24
-45
lines changed

3 files changed

+24
-45
lines changed

optimum/exporters/openvino/__main__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,16 @@ class StoreAttr(object):
394394
if library_name == "open_clip":
395395
model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
396396
else:
397+
# remote code models like phi3_v, internvl2, minicpmv, nanollava, maira2 should be loaded using AutoModelForCausalLM and not AutoModelForImageTextToText
398+
# TODO: use config.auto_map to load remote code models instead (for other models we can directly use config.architectures)
399+
task_model_loading = task
400+
if library_name == "transformers":
401+
has_remote_code = hasattr(config, "auto_map")
402+
if has_remote_code and trust_remote_code and task == "image-text-to-text":
403+
task_model_loading = "text-generation"
404+
397405
model = TasksManager.get_model_from_task(
398-
task,
406+
task_model_loading,
399407
model_name_or_path,
400408
subfolder=subfolder,
401409
revision=revision,

optimum/exporters/openvino/model_configs.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,6 @@ def init_model_configs():
202202
"AutoModelForImageTextToText",
203203
)
204204

205-
TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
206-
"image-text-to-text"
207-
] = TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
208-
209-
TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["video-text-to-text"] = "AutoModelForVision2Seq"
210-
211205
if is_diffusers_available() and "fill" not in TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS:
212206
TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["fill"] = "FluxFillPipeline"
213207
TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["fill"] = {"flux": "FluxFillPipeline"}
@@ -1698,9 +1692,7 @@ class LlavaNextVideoConfigBehavior(str, enum.Enum):
16981692
TEXT_EMBEDDINGS = "text_embeddings"
16991693

17001694

1701-
@register_in_tasks_manager(
1702-
"llava_next_video", *["image-text-to-text", "video-text-to-text"], library_name="transformers"
1703-
)
1695+
@register_in_tasks_manager("llava_next_video", *["image-text-to-text"], library_name="transformers")
17041696
class LlavaNextVideoOpenVINOConfig(LlavaOpenVINOConfig):
17051697
MIN_TRANSFORMERS_VERSION = "4.42.0"
17061698
SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaNextVideoConfigBehavior]
@@ -3301,11 +3293,7 @@ class Qwen2VLConfigBehavior(str, enum.Enum):
33013293
TEXT_EMBEDDINGS = "text_embeddings"
33023294

33033295

3304-
@register_in_tasks_manager(
3305-
"qwen2_vl",
3306-
*["image-text-to-text", "video-text-to-text"],
3307-
library_name="transformers",
3308-
)
3296+
@register_in_tasks_manager("qwen2_vl", *["image-text-to-text"], library_name="transformers")
33093297
class Qwen2VLOpenVINOConfig(BaseVLMOpenVINOConfig):
33103298
SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
33113299
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
@@ -3436,11 +3424,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
34363424
return {}
34373425

34383426

3439-
@register_in_tasks_manager(
3440-
"qwen2_5_vl",
3441-
*["image-text-to-text", "video-text-to-text"],
3442-
library_name="transformers",
3443-
)
3427+
@register_in_tasks_manager("qwen2_5_vl", *["image-text-to-text"], library_name="transformers")
34443428
class Qwen2_5_VLOpenVINOConfig(Qwen2VLOpenVINOConfig):
34453429
MIN_TRANSFORMERS_VERSION = "4.49.0"
34463430

@@ -3784,7 +3768,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
37843768
return super().generate(input_name, framework, int_dtype, float_dtype)
37853769

37863770

3787-
@register_in_tasks_manager("idefics3", *["image-text-to-text", "video-text-to-text"], library_name="transformers")
3771+
@register_in_tasks_manager("idefics3", *["image-text-to-text"], library_name="transformers")
37883772
class Idefics3OpenVINOConfig(BaseVLMOpenVINOConfig):
37893773
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyVisionPositionIdsInputGenerator)
37903774
MIN_TRANSFORMERS_VERSION = "4.46.0"
@@ -3843,7 +3827,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, VLMConfigBehavior])
38433827
return text_embedding
38443828

38453829

3846-
@register_in_tasks_manager("smolvlm", *["image-text-to-text", "video-text-to-text"], library_name="transformers")
3830+
@register_in_tasks_manager("smolvlm", *["image-text-to-text"], library_name="transformers")
38473831
class SmolVLMOpenVINOConfig(Idefics3OpenVINOConfig):
38483832
MIN_TRANSFORMERS_VERSION = "4.50.0"
38493833

optimum/intel/openvino/quantization.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,9 @@
4040
from transformers.pytorch_utils import Conv1D
4141
from transformers.utils import is_accelerate_available
4242

43-
from optimum.exporters.tasks import TasksManager
4443
from optimum.quantization_base import OptimumQuantizer
4544
from optimum.utils.logging import warn_once
4645

47-
from ..utils.constant import _TASK_ALIASES
4846
from ..utils.import_utils import (
4947
DATASETS_IMPORT_ERROR,
5048
_nncf_version,
@@ -1142,15 +1140,20 @@ def __init__(self, model: OVModel, task: Optional[str] = None, seed: int = 42, *
11421140
Args:
11431141
model (`OVModel`):
11441142
The [OVModel](https://huggingface.co/docs/optimum-intel/en/openvino/reference) to quantize.
1145-
task (`str`, defaults to None):
1146-
The task defining the model topology used for the ONNX export.
11471143
seed (`int`, defaults to 42):
11481144
The random seed to use when shuffling the calibration dataset.
11491145
"""
11501146
super().__init__()
11511147
self.model = model
1152-
self.task = task
11531148
self.dataset_builder = OVCalibrationDatasetBuilder(model, seed)
1149+
self._task = task
1150+
if self._task is not None:
1151+
logger.warning("The `task` argument is ignored and will be removed in optimum-intel v1.27")
1152+
1153+
@property
def task(self) -> Optional[str]:
    """Deprecated accessor for the `task` value given to the constructor.

    Logs a deprecation warning on every access, then returns the stored
    value unchanged (`None` when no task was passed).

    Returns:
        `Optional[str]`: The task string supplied at construction time, or `None`.
    """
    logger.warning("The `task` attribute is deprecated and will be removed in v1.27.")
    return self._task
11541157

11551158
@classmethod
11561159
def from_pretrained(cls, model: OVModel, **kwargs):
@@ -1196,7 +1199,7 @@ def quantize(
11961199
>>> from optimum.intel import OVQuantizer, OVModelForCausalLM
11971200
>>> from transformers import AutoModelForCausalLM
11981201
>>> model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b")
1199-
>>> quantizer = OVQuantizer.from_pretrained(model, task="text-generation")
1202+
>>> quantizer = OVQuantizer.from_pretrained(model)
12001203
>>> ov_config = OVConfig(quantization_config=OVWeightQuantizationConfig())
12011204
>>> quantizer.quantize(ov_config=ov_config, save_directory="./quantized_model")
12021205
>>> optimized_model = OVModelForCausalLM.from_pretrained("./quantized_model")
@@ -1208,7 +1211,7 @@ def quantize(
12081211
>>> model = OVModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", export=True)
12091212
>>> # or
12101213
>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
1211-
>>> quantizer = OVQuantizer.from_pretrained(model, task="text-classification")
1214+
>>> quantizer = OVQuantizer.from_pretrained(model)
12121215
>>> ov_config = OVConfig(quantization_config=OVQuantizationConfig())
12131216
>>> quantizer.quantize(calibration_dataset=dataset, ov_config=ov_config, save_directory="./quantized_model")
12141217
>>> optimized_model = OVModelForSequenceClassification.from_pretrained("./quantized_model")
@@ -1454,22 +1457,6 @@ def _save_pretrained(model: openvino.Model, output_path: str):
14541457
compress_quantize_weights_transformation(model)
14551458
openvino.save_model(model, output_path, compress_to_fp16=False)
14561459

1457-
def _set_task(self):
1458-
if self.task is None:
1459-
self.task = TasksManager.infer_task_from_model(self.model.config._name_or_path)
1460-
if self.task is None:
1461-
raise ValueError(
1462-
"The task defining the model topology could not be extracted and needs to be specified for the ONNX export."
1463-
)
1464-
1465-
self.task = _TASK_ALIASES.get(self.task, self.task)
1466-
1467-
if self.task == "text2text-generation":
1468-
raise ValueError("Seq2Seq models are currently not supported for post-training static quantization.")
1469-
1470-
if self.task == "image-to-text":
1471-
raise ValueError("Image2Text models are currently not supported for post-training static quantization.")
1472-
14731460
def get_calibration_dataset(
14741461
self,
14751462
dataset_name: str,

0 commit comments

Comments
 (0)