
Add support for Moonshine ONNX export (& seq2seq models with non-legacy cache & Tensor.repeat_interleave) #2162

Merged · 13 commits · Feb 17, 2025
32 changes: 27 additions & 5 deletions optimum/exporters/onnx/model_configs.py
@@ -1782,6 +1782,33 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return {"input_features": {0: "batch_size", 1: "sequence_classification"}}


class MoonshineOnnxConfig(AudioToTextOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

# torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::triu' to ONNX opset version 11 is not supported.
# Support for this operator was added in version 14, try exporting with this version.
DEFAULT_ONNX_OPSET = 14

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {}

if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_values"] = {0: "batch_size", 1: "num_samples"}

if self._behavior is not ConfigBehavior.ENCODER:
if self.use_past_in_inputs:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
self.add_past_key_values(common_inputs, direction="inputs")
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

if self._behavior is ConfigBehavior.DECODER:
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

return common_inputs


class WhisperOnnxConfig(AudioToTextOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Whisper now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

@@ -1802,11 +1829,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis.

if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
if is_transformers_version(">=", "4.43.0"):
# since https://github.com/huggingface/transformers/pull/31166
common_inputs["cache_position"] = {0: "decoder_sequence_length"}

if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
return common_inputs
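For context, a minimal sketch of how the new config path is exercised end to end, using the tiny test checkpoint registered further down in this PR (`main_export` is the programmatic counterpart of `optimum-cli export onnx`):

```python
# A minimal sketch, assuming the tiny test checkpoint referenced in
# tests/exporters/exporters_utils.py below is available on the Hub.
from optimum.exporters.onnx import main_export

main_export(
    "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration",
    output="moonshine_onnx",  # encoder/decoder ONNX files are written here
    task="automatic-speech-recognition-with-past",
)
```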
81 changes: 64 additions & 17 deletions optimum/exporters/onnx/model_patcher.py
@@ -156,7 +156,51 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step):
return result


UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
# An ONNX-export-compatible version of `tensor.repeat_interleave`.
# Without this, we get the following error: https://github.com/pytorch/pytorch/issues/145100
# NOTE: This implementation is only necessary for export with dynamo=False (dynamo=True works correctly)
# and can be removed once Optimum switches to dynamo-based exports.
def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None):
"""
Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.

Args:
input_tensor (torch.Tensor): The input tensor.
repeats (int or torch.Tensor): The number of repetitions for each element.
dim (int, optional): The dimension along which to repeat. Defaults to None.

Returns:
torch.Tensor: The repeated tensor.
"""
if isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0):
if dim is None:
return input_tensor.flatten().unsqueeze(1).expand(-1, repeats).flatten()
repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)

if dim is None:
return onnx_compatible_repeat_interleave(input_tensor.flatten(), repeats, 0)

if dim != 0:
input_tensor = input_tensor.transpose(0, dim)

# Create expand mask
max_repeats = repeats.max()
expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
result = expanded[mask]

if dim != 0:
result = result.transpose(0, dim)

return result
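
A quick equivalence check for the patch above, covering the scalar, tensor, flattened (`dim=None`), and non-zero-`dim` cases. This is a sketch rather than part of the PR's test suite, and the import path assumes the function lands in `optimum/exporters/onnx/model_patcher.py` as in this diff:

```python
import torch

from optimum.exporters.onnx.model_patcher import onnx_compatible_repeat_interleave

x = torch.arange(6).reshape(2, 3)
for repeats, dim in [(2, None), (3, 0), (2, 1), (torch.tensor([1, 2]), 0)]:
    expected = torch.repeat_interleave(x, repeats, dim=dim)
    actual = onnx_compatible_repeat_interleave(x, repeats, dim=dim)
    assert torch.equal(expected, actual), (repeats, dim)
```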


UNSUPPORTED_OPS_PATCHING_SPEC = [
PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
Comment on lines +201 to +202 (Member): GREAT! Thanks for omitting this!

]
CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]
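
For readers new to the patching mechanism: each `PatchingSpec` names a target object, an attribute, the replacement op, and the original op. Roughly, the patcher consumes them as sketched below (field names `o`/`name`/`custom_op`/`orig_op` per the `PatchingSpec` dataclass in `optimum/exporters/onnx/base.py`; the real logic lives in `ModelPatcher.__enter__`/`__exit__`):

```python
# Applied when the patcher is entered and reverted on exit, so the
# replacements are only live while torch.onnx.export traces the model.
def apply_specs(specs):
    for spec in specs:
        setattr(spec.o, spec.name, spec.custom_op)

def revert_specs(specs):
    for spec in specs:
        setattr(spec.o, spec.name, spec.orig_op)
```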


@@ -239,7 +283,7 @@ def patched_forward(*args, **kwargs):
# contains the output names of the model. In the case of Timm classification models, the output
# is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
# match the outputs in order.
filterd_outputs = {}
filtered_outputs = {}
Comment (Member): nice catch! I'm embarrassed by the number of times I've modified this file without seeing this x)

if isinstance(outputs, dict):
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
Expand All @@ -248,10 +292,10 @@ def patched_forward(*args, **kwargs):
or (allow_past_in_outputs and name.startswith("past_key_values"))
or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
):
filterd_outputs[name] = value
filtered_outputs[name] = value
elif isinstance(outputs, (list, tuple)):
outputs_list = list(config.outputs.keys())
filterd_outputs = dict(zip(outputs_list, outputs))
filtered_outputs = dict(zip(outputs_list, outputs))
else:
if len(config.outputs) > 1:
num_outputs = len(config.outputs)
@@ -261,15 +305,15 @@
)
else:
name = list(config.outputs.keys())[0]
filterd_outputs[name] = outputs
filtered_outputs[name] = outputs
name = list(config.outputs.keys())[0]
filterd_outputs[name] = outputs
filtered_outputs[name] = outputs

if is_transformers_version(">=", "4.48"):
if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
filtered_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()

return filterd_outputs
return filtered_outputs

self.patched_forward = patched_forward
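
The `to_legacy_cache()` call above is what keeps the exported graph's outputs flat: transformers >= 4.48 returns cache objects (`DynamicCache`/`EncoderDecoderCache`) from generation models, while ONNX output binding expects per-layer key/value tensors. A small illustration of the conversion (shapes are arbitrary):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = value = torch.zeros(1, 8, 4, 64)  # (batch, num_heads, seq_len, head_dim)
cache.update(key, value, layer_idx=0)

legacy = cache.to_legacy_cache()  # tuple of per-layer (key, value) tuples
assert isinstance(legacy, tuple) and torch.equal(legacy[0][0], key)
```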

@@ -325,15 +369,18 @@ def __init__(
if model.config.model_type == "pix2struct" and allow_past_in_outputs:
model.config.text_config.use_cache = True

@functools.wraps(self.orig_forward)
# Re-use the patched forward method from the parent class
self.super_patched_forward = self.patched_forward

@functools.wraps(self.super_patched_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(self.orig_forward)
signature = inspect.signature(self.super_patched_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

outputs = self.orig_forward(*args, **kwargs)
outputs = self.super_patched_forward(*args, **kwargs)

# Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
filterd_outputs = {}
filtered_outputs = {}
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
if (
@@ -346,17 +393,17 @@ def patched_forward(*args, **kwargs):
# Who cares about the encoder outputs in the decoder?
continue
else:
filterd_outputs[name] = value
filtered_outputs[name] = value
else:
if self.real_config._behavior == "monolith" or (
self.real_config._behavior == "decoder"
and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
):
filterd_outputs[name] = value
filtered_outputs[name] = value
elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
# The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
filterd_outputs[name] = tuple([v[:2] for v in value])
return filterd_outputs
filtered_outputs[name] = tuple([v[:2] for v in value])
return filtered_outputs

self.patched_forward = patched_forward
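
The `v[:2]` slice relies on the legacy per-layer layout of encoder-decoder caches, `(self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)`: the cross-attention key/values depend only on the encoder output, so they stay constant across decoding steps and can be dropped from the with-past decoder's outputs. Schematically (toy values, not real tensors):

```python
# Hypothetical per-layer tuple in the legacy layout; only the self-attention
# half changes between decoding steps, so only it is kept.
layer_past = ("self_k", "self_v", "cross_k", "cross_v")
kept = layer_past[:2]  # ("self_k", "self_v")
```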

7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
@@ -903,6 +903,13 @@ class TasksManager:
"token-classification",
onnx="ModernBertOnnxConfig",
),
"moonshine": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"automatic-speech-recognition",
"automatic-speech-recognition-with-past",
onnx="MoonshineOnnxConfig",
),
"mpnet": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
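With the mapping registered, task resolution should pick Moonshine up like any other architecture. A quick sanity check, assuming `TasksManager.get_supported_tasks_for_model_type` keeps its current signature:

```python
from optimum.exporters.tasks import TasksManager

# Expected to map feature-extraction(-with-past) and
# automatic-speech-recognition(-with-past) to MoonshineOnnxConfig.
print(TasksManager.get_supported_tasks_for_model_type("moonshine", exporter="onnx"))
```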
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -130,6 +130,7 @@
"mobilenet-v1": "hf-internal-testing/tiny-random-MobileNetV1Model",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM",
"moonshine": "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
"mt5": "lewtun/tiny-random-mt5",
@@ -271,6 +272,7 @@
"mobilenet_v2": "google/mobilenet_v2_0.35_96",
"mobilevit": "apple/mobilevit-small",
"modernbert": "answerdotai/ModernBERT-base",
"moonshine": "UsefulSensors/moonshine-tiny",
"mpt": "mosaicml/mpt-7b",
"mt5": "google/mt5-small",
"musicgen": "facebook/musicgen-small",