Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable ONNX export of CLIP models with sdpa #2066

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 2 additions & 22 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@
)
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
CLIPModelPatcher,
FalconModelPatcher,
MgpstrModelPatcher,
MistralModelPatcher,
Expand Down Expand Up @@ -1109,6 +1108,7 @@ class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):

class CLIPVisionModelOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -1122,16 +1122,10 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

return common_outputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class CLIPOnnxConfig(TextAndVisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -1150,13 +1144,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
"image_embeds": {0: "image_batch_size"},
}

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
@property
Expand Down Expand Up @@ -1202,13 +1189,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

return common_outputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
@property
Expand Down
13 changes: 0 additions & 13 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,16 +1278,3 @@ def __init__(
self._update_causal_mask_original = self._model.model._update_causal_mask
else:
self._update_causal_mask_original = self._model._update_causal_mask


class CLIPModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.43"):
self.original_sdpa_forward = CLIPSdpaAttention.forward
CLIPSdpaAttention.forward = CLIPAttention.forward

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.43"):
CLIPSdpaAttention.forward = self.original_sdpa_forward
Loading