Add support for Moonshine ONNX export (& seq2seq models with non-legacy cache & Tensor.repeat_interleave) #2162
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Edit: Fixed ✅ (current failing tests unrelated)
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
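The idea behind this patch can be sketched as a minimal, self-contained example: temporarily replacing `torch.Tensor.__len__` so that `len(tensor)` resolves through `tensor.shape[0]` during tracing, which avoids the TracerWarning. The context-manager form below is illustrative only; it is not Optimum's actual `PatchingSpec` machinery.

```python
from contextlib import contextmanager

import torch


@contextmanager
def patch_tensor_len():
    """Illustrative sketch: patch Tensor.__len__ to use shape[0] while tracing.

    This mimics what a PatchingSpec-style patch does; the real implementation
    in the PR applies and reverts the patch around the export call.
    """
    original_len = torch.Tensor.__len__
    torch.Tensor.__len__ = lambda x: x.shape[0]
    try:
        yield
    finally:
        # Always restore the original method, even if export raises.
        torch.Tensor.__len__ = original_len


with patch_tensor_len():
    t = torch.zeros(3, 4)
    print(len(t))  # resolved via t.shape[0]
```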
GREAT! Thanks for omitting this!
@@ -239,7 +283,7 @@ def patched_forward(*args, **kwargs):
     # contains the output names of the model. In the case of Timm classification models, the output
     # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
     # match the outputs in order.
-    filterd_outputs = {}
+    filtered_outputs = {}
Nice catch! I'm embarrassed by the number of times I've modified this file without seeing this x)
LGTM
What does this PR do?
This PR does the following:
- Adds support for Moonshine ONNX export.
- Adds support for seq2seq models with a non-legacy cache.
- Adds a patch for models using torch.Tensor.repeat_interleave, which are unable to export due to a bug in PyTorch: torch.onnx.export (dynamo=False) fails with an uninformative error when exporting apply_rotary_pos_emb/repeat_interleave (pytorch/pytorch#145100). Note that this bug most likely won't be fixed, as the PyTorch team transitions to the new dynamo-based exporter.

Fixes # (issue)
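The usual workaround for export-unfriendly `repeat_interleave` calls (with a scalar repeat count) is to rewrite them using `unsqueeze`/`expand`/`reshape`, which the legacy tracer handles cleanly. The helper below is a hedged sketch of that rewrite, not the exact patch this PR applies.

```python
import torch


def repeat_interleave_alt(x: torch.Tensor, repeats: int, dim: int) -> torch.Tensor:
    """Export-friendly equivalent of torch.repeat_interleave(x, repeats, dim=dim)
    for a scalar `repeats` (illustrative sketch, not the PR's actual patch)."""
    # Insert a singleton dimension right after `dim`, then broadcast it
    # to `repeats` copies and fold it back into `dim`.
    x = x.unsqueeze(dim + 1)
    expand_shape = list(x.shape)
    expand_shape[dim + 1] = repeats
    x = x.expand(*expand_shape)
    out_shape = list(x.shape)
    del out_shape[dim + 1]
    out_shape[dim] *= repeats
    return x.reshape(out_shape)


x = torch.arange(6).reshape(2, 3)
print(repeat_interleave_alt(x, 2, 0))  # each row repeated twice along dim 0
```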
Before submitting
Who can review?