Add support for Moonshine ONNX export (& seq2seq models with non-legacy cache & Tensor.repeat_interleave) #2162

Merged
xenova merged 13 commits into main from add-moonshine-onnx on Feb 17, 2025

Conversation

@xenova (Contributor) commented on Jan 18, 2025

What does this PR do?

This PR does the following:

  • Adds support for exporting Moonshine models to ONNX.
  • Adds support for seq2seq models that use the non-legacy cache format.
  • Adds support for the Tensor.repeat_interleave op during export.

Fixes # (issue)
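
A minimal sketch of how an export enabled by this PR could be invoked through optimum's programmatic ONNX export API; the checkpoint id and output directory below are illustrative assumptions, not values taken from this PR.

# Sketch only: exporting a Moonshine checkpoint to ONNX via main_export.
# The model id and output path are illustrative assumptions.
from optimum.exporters.onnx import main_export

main_export(
    "UsefulSensors/moonshine-tiny",  # hypothetical Moonshine checkpoint id
    output="moonshine_onnx",         # directory that receives the exported ONNX files
)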

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@xenova changed the title from "Add support for Moonshine ONNX export (& models with non-legacy cache & Tensor.repeat_interleave)" to "Add support for Moonshine ONNX export (& seq2seq models with non-legacy cache & Tensor.repeat_interleave)" on Feb 11, 2025
@xenova marked this pull request as ready for review on February 11, 2025 21:14
@xenova (Contributor, Author) commented on Feb 12, 2025

Seeing some failures for SAM exports regarding the repeat_interleave op. Looking into it now.

Edit: Fixed ✅ (current failing tests unrelated)
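
For context (an illustration, not part of the original comment): torch.Tensor.repeat_interleave repeats each slice along a dimension a given number of times, and the ONNX exporter needs a mapping for this op; it is the operation behind the SAM failures mentioned above.

import torch

x = torch.tensor([[1, 2],
                  [3, 4]])
print(x.repeat_interleave(2, dim=0))
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])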

Comment on lines +201 to +202
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
Member commented:

GREAT ! Thanks for omitting this !
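
As background for the patch quoted above, a standalone sketch (helper names are illustrative, not code from this PR) of what such a PatchingSpec amounts to: temporarily reroute torch.Tensor.__len__ through .shape[0], as the TracerWarning recommends, and restore the original afterwards.

import torch

original_len = torch.Tensor.__len__

def patched_len(self):
    # Route len(t) through the tensor's shape, as the TracerWarning recommends.
    return self.shape[0]

torch.Tensor.__len__ = patched_len
try:
    t = torch.randn(5, 3)
    assert len(t) == 5                  # eager behaviour is unchanged by the patch
finally:
    torch.Tensor.__len__ = original_len  # always restore the original method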

@@ -239,7 +283,7 @@ def patched_forward(*args, **kwargs):
# contains the output names of the model. In the case of Timm classification models, the output
# is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
# match the outputs in order.
- filterd_outputs = {}
+ filtered_outputs = {}
Member commented:

nice catch ! I'm embarrassed by the amount of times I've modified this file without seeing this x)

@IlyasMoutawwakil (Member) left a comment:

LGTM

xenova merged commit 414afab into main on Feb 17, 2025
35 of 40 checks passed
xenova deleted the add-moonshine-onnx branch on February 17, 2025 19:46