Add support for Moonshine ONNX export (& seq2seq models with non-legacy cache & Tensor.repeat_interleave) #2162
Merged
Commits (13)
- 3269458 (xenova) Add moonshine ONNX config
- 17604ac (xenova) Remove use_cache_position for whisper exports
- 03c4816 (xenova) Merge branch 'main' into add-moonshine-onnx
- adf8aba (xenova) Merge branch 'main' into add-moonshine-onnx
- 83348a0 (xenova) Patch torch repeat_interleave during export
- e668d89 (xenova) Add support for exporting models with non-legacy caches
- 121ed80 (xenova) Formatting
- fc6913f (xenova) Merge branch 'main' into add-moonshine-onnx
- abbe4c5 (xenova) Re-use model patcher for seq2seq models
- b270476 (xenova) Add moonshine unit tests
- 8373720 (xenova) Formatting
- c69f6b4 (xenova) When tracing, repeats passed as an int will be turned into a tensor o…
- ade05e7 (xenova) Fix failing unit test on 4.45.1 CI. Confirmed it works above 4.46 too.
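Once merged, a Moonshine checkpoint can be exported through Optimum's programmatic entry point. A minimal sketch, assuming an example checkpoint id and the automatic-speech-recognition task (neither is taken from this page):

from optimum.exporters.onnx import main_export

# "UsefulSensors/moonshine-tiny" is an assumed example checkpoint id; substitute the model you want to export.
main_export(
    "UsefulSensors/moonshine-tiny",
    output="moonshine_onnx",
    task="automatic-speech-recognition",
)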
@@ -156,7 +156,51 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step):
    return result


-UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
+# An ONNX-export-compatible version of `tensor.repeat_interleave`.
+# Without this, we get the following error: https://github.com/pytorch/pytorch/issues/145100
+# NOTE: This implementation is only necessary for export with dynamo=False (dynamo=True works correctly)
+# and can be removed once Optimum switches to dynamo-based exports.
+def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None):
+    """
+    Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.
+
+    Args:
+        input_tensor (torch.Tensor): The input tensor.
+        repeats (int or torch.Tensor): The number of repetitions for each element.
+        dim (int, optional): The dimension along which to repeat. Defaults to None.
+
+    Returns:
+        torch.Tensor: The repeated tensor.
+    """
+    if isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0):
+        if dim is None:
+            return input_tensor.flatten().unsqueeze(1).expand(-1, repeats).flatten()
+        repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)
+
+    if dim is None:
+        return onnx_compatible_repeat_interleave(input_tensor.flatten(), repeats, 0)
+
+    if dim != 0:
+        input_tensor = input_tensor.transpose(0, dim)
+
+    # Create expand mask
+    max_repeats = repeats.max()
+    expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
+    mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
+    result = expanded[mask]
+
+    if dim != 0:
+        result = result.transpose(0, dim)
+
+    return result
+
+
+UNSUPPORTED_OPS_PATCHING_SPEC = [
+    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
+    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
+    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
+    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
+]
 CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]
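A quick way to check that the fallback above behaves like the op it replaces is to compare it against torch.repeat_interleave directly. A minimal sketch, assuming the helper is importable from optimum.exporters.onnx.model_patcher (the module this diff appears to modify):

import torch

from optimum.exporters.onnx.model_patcher import onnx_compatible_repeat_interleave

x = torch.arange(6).reshape(2, 3)

# int repeats with an explicit dim
assert torch.equal(onnx_compatible_repeat_interleave(x, 2, dim=1), torch.repeat_interleave(x, 2, dim=1))

# per-element repeats along dim 0
r = torch.tensor([1, 3])
assert torch.equal(onnx_compatible_repeat_interleave(x, r, dim=0), torch.repeat_interleave(x, r, dim=0))

# int repeats with dim=None (flattens the input, like torch.repeat_interleave does)
assert torch.equal(onnx_compatible_repeat_interleave(x, 2), torch.repeat_interleave(x, 2))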
@@ -239,7 +283,7 @@ def patched_forward(*args, **kwargs):
            # contains the output names of the model. In the case of Timm classification models, the output
            # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
            # match the outputs in order.
-            filterd_outputs = {}
+            filtered_outputs = {}

[Review comment on the line above] nice catch ! I'm embarrassed by the amount of times I've modified this file without seeing this x)

            if isinstance(outputs, dict):
                for name, value in outputs.items():
                    onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
@@ -248,10 +292,10 @@ def patched_forward(*args, **kwargs):
                        or (allow_past_in_outputs and name.startswith("past_key_values"))
                        or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
                    ):
-                        filterd_outputs[name] = value
+                        filtered_outputs[name] = value
            elif isinstance(outputs, (list, tuple)):
                outputs_list = list(config.outputs.keys())
-                filterd_outputs = dict(zip(outputs_list, outputs))
+                filtered_outputs = dict(zip(outputs_list, outputs))
            else:
                if len(config.outputs) > 1:
                    num_outputs = len(config.outputs)
@@ -261,15 +305,15 @@ def patched_forward(*args, **kwargs):
                    )
                else:
                    name = list(config.outputs.keys())[0]
-                    filterd_outputs[name] = outputs
+                    filtered_outputs[name] = outputs
                name = list(config.outputs.keys())[0]
-                filterd_outputs[name] = outputs
+                filtered_outputs[name] = outputs

            if is_transformers_version(">=", "4.48"):
-                if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
-                    filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
+                if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
+                    filtered_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()

-            return filterd_outputs
+            return filtered_outputs

        self.patched_forward = patched_forward
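For context on the is_transformers_version(">=", "4.48") branch above: recent transformers versions return past_key_values as a DynamicCache / EncoderDecoderCache object, while the ONNX export path works with the legacy tuple-of-tuples of plain tensors, so the patcher converts back with the public to_legacy_cache() API. A small illustration of that layout (the shapes are made up for the example):

import torch
from transformers import DynamicCache

cache = DynamicCache()
# (batch, num_heads, seq_len, head_dim) -- arbitrary example shapes
key = torch.zeros(1, 8, 4, 64)
value = torch.zeros(1, 8, 4, 64)
cache.update(key, value, layer_idx=0)

legacy = cache.to_legacy_cache()
# Legacy layout: one (key, value) pair of plain tensors per layer
assert isinstance(legacy, tuple) and torch.equal(legacy[0][0], key)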
@@ -325,15 +369,18 @@ def __init__(
        if model.config.model_type == "pix2struct" and allow_past_in_outputs:
            model.config.text_config.use_cache = True

-        @functools.wraps(self.orig_forward)
+        # Re-use the patched forward method from the parent class
+        self.super_patched_forward = self.patched_forward
+
+        @functools.wraps(self.super_patched_forward)
        def patched_forward(*args, **kwargs):
-            signature = inspect.signature(self.orig_forward)
+            signature = inspect.signature(self.super_patched_forward)
            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

-            outputs = self.orig_forward(*args, **kwargs)
+            outputs = self.super_patched_forward(*args, **kwargs)

            # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
-            filterd_outputs = {}
+            filtered_outputs = {}
            for name, value in outputs.items():
                onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                if (
@@ -346,17 +393,17 @@ def patched_forward(*args, **kwargs):
                            # Who cares about the encoder outputs in the decoder?
                            continue
                        else:
-                            filterd_outputs[name] = value
+                            filtered_outputs[name] = value
                    else:
                        if self.real_config._behavior == "monolith" or (
                            self.real_config._behavior == "decoder"
                            and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
                        ):
-                            filterd_outputs[name] = value
+                            filtered_outputs[name] = value
                        elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
                            # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
-                            filterd_outputs[name] = tuple([v[:2] for v in value])
-            return filterd_outputs
+                            filtered_outputs[name] = tuple([v[:2] for v in value])
+            return filtered_outputs

        self.patched_forward = patched_forward
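The hunk above is the "Re-use model patcher for seq2seq models" commit: rather than wrapping self.orig_forward a second time (and duplicating the base class's output handling), the seq2seq patcher now wraps the forward that the parent patcher already produced. A stripped-down sketch of that chaining pattern, with hypothetical class and output names:

import functools

class BasePatcher:
    def __init__(self, forward):
        self.orig_forward = forward

        @functools.wraps(self.orig_forward)
        def patched_forward(*args, **kwargs):
            outputs = self.orig_forward(*args, **kwargs)
            outputs["base_filtered"] = True  # stands in for the base class's output filtering
            return outputs

        self.patched_forward = patched_forward

class Seq2SeqPatcher(BasePatcher):
    def __init__(self, forward):
        super().__init__(forward)
        # Keep a handle to the parent's already-patched forward and wrap that instead of the raw forward
        self.super_patched_forward = self.patched_forward

        @functools.wraps(self.super_patched_forward)
        def patched_forward(*args, **kwargs):
            outputs = self.super_patched_forward(*args, **kwargs)
            outputs["seq2seq_filtered"] = True  # stands in for the seq2seq-specific filtering
            return outputs

        self.patched_forward = patched_forward

patcher = Seq2SeqPatcher(lambda: {"logits": 0})
print(patcher.patched_forward())  # {'logits': 0, 'base_filtered': True, 'seq2seq_filtered': True}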
[Review comment] GREAT ! Thanks for omitting this !