Add support for Moonshine ONNX export (& seq2seq models with non-legacy cache & `Tensor.repeat_interleave`) (#2162)

* Add moonshine ONNX config

* Remove use_cache_position for whisper exports

* Patch torch repeat_interleave during export

* Add support for exporting models with non-legacy caches

* Formatting

* Re-use model patcher for seq2seq models

* Add moonshine unit tests

* Formatting

* When tracing, repeats passed as an int will be turned into a tensor of rank 0.

* Fix failing unit test on 4.45.1 CI. Confirmed it works above 4.46 too.
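
The last two bullets are easiest to see in isolation: under `torch.jit.trace` (which `torch.onnx.export` uses when `dynamo=False`), the `repeats` value reaching `repeat_interleave` may be a rank-0 tensor rather than a Python int, so a replacement op has to accept both forms. A minimal eager-mode illustration (plain PyTorch, nothing Optimum-specific):

import torch

x = torch.tensor([1, 2, 3])

# A Python int and a rank-0 tensor must behave identically as `repeats`;
# the patch in model_patcher.py checks `repeats.dim() == 0` for this reason.
out_int = x.repeat_interleave(2)
out_rank0 = x.repeat_interleave(torch.tensor(2))
assert torch.equal(out_int, out_rank0)  # both give tensor([1, 1, 2, 2, 3, 3])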
xenova authored Feb 17, 2025
1 parent 27dae50 commit 414afab
Showing 4 changed files with 103 additions and 20 deletions.
33 changes: 30 additions & 3 deletions optimum/exporters/onnx/model_configs.py
@@ -1782,6 +1782,33 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         return {"input_features": {0: "batch_size", 1: "sequence_classification"}}


+class MoonshineOnnxConfig(AudioToTextOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
+
+    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::triu' to ONNX opset version 11 is not supported.
+    # Support for this operator was added in version 14, try exporting with this version.
+    DEFAULT_ONNX_OPSET = 14
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {}
+
+        if self._behavior is not ConfigBehavior.DECODER:
+            common_inputs["input_values"] = {0: "batch_size", 1: "num_samples"}
+
+        if self._behavior is not ConfigBehavior.ENCODER:
+            if self.use_past_in_inputs:
+                common_inputs["decoder_input_ids"] = {0: "batch_size"}
+                self.add_past_key_values(common_inputs, direction="inputs")
+            else:
+                common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+
+        if self._behavior is ConfigBehavior.DECODER:
+            common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
+
+        return common_inputs
+
+
 class WhisperOnnxConfig(AudioToTextOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # Whisper now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

@@ -1802,9 +1829,9 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         if self._behavior is not ConfigBehavior.DECODER:
             common_inputs["input_features"] = {0: "batch_size"}  # Remove unnecessary dynamic axis.

-        if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
-            if is_transformers_version(">=", "4.43.0"):
-                # since https://github.com/huggingface/transformers/pull/31166
+        if is_transformers_version(">=", "4.43.0") and is_transformers_version("<", "4.46.0"):
+            # since https://github.com/huggingface/transformers/pull/31166
+            if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
                 common_inputs["cache_position"] = {0: "decoder_sequence_length"}

         if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
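With this config registered (see the `tasks.py` mapping below), a Moonshine checkpoint can be exported through the standard Optimum entry point. A sketch, assuming `main_export` keeps its usual signature; the output directory name is illustrative:

from optimum.exporters.onnx import main_export

# Exports encoder and decoder ONNX models; the "-with-past" task variant
# also produces the KV-cache decoder exercised by the new config.
main_export(
    "UsefulSensors/moonshine-tiny",
    output="moonshine_onnx",
    task="automatic-speech-recognition-with-past",
)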
81 changes: 64 additions & 17 deletions optimum/exporters/onnx/model_patcher.py
@@ -156,7 +156,51 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step):
     return result


-UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
+# An ONNX-export-compatible version of `tensor.repeat_interleave`.
+# Without this, we get the following error: https://github.com/pytorch/pytorch/issues/145100
+# NOTE: This implementation is only necessary for export with dynamo=False (dynamo=True works correctly),
+# and can be removed once Optimum switches to dynamo-based exports.
+def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None):
+    """
+    Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.
+    Args:
+        input_tensor (torch.Tensor): The input tensor.
+        repeats (int or torch.Tensor): The number of repetitions for each element.
+        dim (int, optional): The dimension along which to repeat. Defaults to None.
+    Returns:
+        torch.Tensor: The repeated tensor.
+    """
+    if isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0):
+        if dim is None:
+            return input_tensor.flatten().unsqueeze(1).expand(-1, repeats).flatten()
+        repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)
+
+    if dim is None:
+        return onnx_compatible_repeat_interleave(input_tensor.flatten(), repeats, 0)
+
+    if dim != 0:
+        input_tensor = input_tensor.transpose(0, dim)
+
+    # Create expand mask
+    max_repeats = repeats.max()
+    expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
+    mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
+    result = expanded[mask]
+
+    if dim != 0:
+        result = result.transpose(0, dim)
+
+    return result
+
+
+UNSUPPORTED_OPS_PATCHING_SPEC = [
+    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
+    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
+    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
+    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
+]
+CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]


@@ -239,7 +283,7 @@ def patched_forward(*args, **kwargs):
             # contains the output names of the model. In the case of Timm classification models, the output
             # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
             # match the outputs in order.
-            filterd_outputs = {}
+            filtered_outputs = {}
             if isinstance(outputs, dict):
                 for name, value in outputs.items():
                     onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
@@ -248,10 +292,10 @@ def patched_forward(*args, **kwargs):
                         or (allow_past_in_outputs and name.startswith("past_key_values"))
                         or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
                     ):
-                        filterd_outputs[name] = value
+                        filtered_outputs[name] = value
             elif isinstance(outputs, (list, tuple)):
                 outputs_list = list(config.outputs.keys())
-                filterd_outputs = dict(zip(outputs_list, outputs))
+                filtered_outputs = dict(zip(outputs_list, outputs))
             else:
                 if len(config.outputs) > 1:
                     num_outputs = len(config.outputs)
@@ -261,15 +305,15 @@ def patched_forward(*args, **kwargs):
                     )
                 else:
                     name = list(config.outputs.keys())[0]
-                    filterd_outputs[name] = outputs
+                    filtered_outputs[name] = outputs
                 name = list(config.outputs.keys())[0]
-                filterd_outputs[name] = outputs
+                filtered_outputs[name] = outputs

             if is_transformers_version(">=", "4.48"):
-                if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
-                    filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
+                if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
+                    filtered_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()

-            return filterd_outputs
+            return filtered_outputs

         self.patched_forward = patched_forward

@@ -325,15 +369,15 @@ def __init__(
         if model.config.model_type == "pix2struct" and allow_past_in_outputs:
             model.config.text_config.use_cache = True

-        @functools.wraps(self.orig_forward)
+        # Re-use the patched forward method from the parent class
+        self.super_patched_forward = self.patched_forward
+
+        @functools.wraps(self.super_patched_forward)
         def patched_forward(*args, **kwargs):
-            signature = inspect.signature(self.orig_forward)
+            signature = inspect.signature(self.super_patched_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

-            outputs = self.orig_forward(*args, **kwargs)
+            outputs = self.super_patched_forward(*args, **kwargs)

             # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
-            filterd_outputs = {}
+            filtered_outputs = {}
             for name, value in outputs.items():
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                 if (
@@ -346,17 +393,17 @@ def patched_forward(*args, **kwargs):
                             # Who cares about the encoder outputs in the decoder?
                             continue
                         else:
-                            filterd_outputs[name] = value
+                            filtered_outputs[name] = value
                     else:
                         if self.real_config._behavior == "monolith" or (
                             self.real_config._behavior == "decoder"
                             and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
                         ):
-                            filterd_outputs[name] = value
+                            filtered_outputs[name] = value
                         elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
                             # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
-                            filterd_outputs[name] = tuple([v[:2] for v in value])
-            return filterd_outputs
+                            filtered_outputs[name] = tuple([v[:2] for v in value])
+            return filtered_outputs

         self.patched_forward = patched_forward
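The tensor-`repeats` path of `onnx_compatible_repeat_interleave` can be sanity-checked against eager PyTorch. A small check along these lines (dim=0 with per-element repeat counts; not part of the committed test suite):

import torch

from optimum.exporters.onnx.model_patcher import onnx_compatible_repeat_interleave

x = torch.tensor([[1, 2], [3, 4]])
repeats = torch.tensor([1, 2])  # repeat row 0 once, row 1 twice

expected = x.repeat_interleave(repeats, dim=0)  # tensor([[1, 2], [3, 4], [3, 4]])
actual = onnx_compatible_repeat_interleave(x, repeats, dim=0)
assert torch.equal(actual, expected)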
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
@@ -903,6 +903,13 @@ class TasksManager:
             "token-classification",
             onnx="ModernBertOnnxConfig",
         ),
+        "moonshine": supported_tasks_mapping(
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "automatic-speech-recognition",
+            "automatic-speech-recognition-with-past",
+            onnx="MoonshineOnnxConfig",
+        ),
         "mpnet": supported_tasks_mapping(
             "feature-extraction",
             "fill-mask",
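Once registered, the mapping is discoverable through `TasksManager` like any other architecture. A sketch of the lookup (API names as they exist in this version of Optimum; keyword arguments may vary across releases):

from optimum.exporters.tasks import TasksManager

# Lists the tasks the new entry enables for Moonshine ONNX export.
tasks = TasksManager.get_supported_tasks_for_model_type("moonshine", exporter="onnx")
print(sorted(tasks))
# ['automatic-speech-recognition', 'automatic-speech-recognition-with-past',
#  'feature-extraction', 'feature-extraction-with-past']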
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -130,6 +130,7 @@
     "mobilenet-v1": "hf-internal-testing/tiny-random-MobileNetV1Model",
     "mobilevit": "hf-internal-testing/tiny-random-mobilevit",
     "modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM",
+    "moonshine": "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration",
     "mpnet": "hf-internal-testing/tiny-random-MPNetModel",
     "mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
     "mt5": "lewtun/tiny-random-mt5",
@@ -271,6 +272,7 @@
     "mobilenet_v2": "google/mobilenet_v2_0.35_96",
     "mobilevit": "apple/mobilevit-small",
     "modernbert": "answerdotai/ModernBERT-base",
+    "moonshine": "UsefulSensors/moonshine-tiny",
     "mpt": "mosaicml/mpt-7b",
     "mt5": "google/mt5-small",
     "musicgen": "facebook/musicgen-small",
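These tiny and full-size model IDs feed the exporter test matrix, so the new Moonshine entries are picked up automatically. A hypothetical way to run just the Moonshine cases locally (path and filter are illustrative):

import pytest

# Select exporter tests whose id mentions moonshine.
pytest.main(["tests/exporters/onnx", "-k", "moonshine"])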
