-
Notifications
You must be signed in to change notification settings - Fork 612
[TOSA] Add transposed conv support #4360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[TOSA] Add transposed conv support #4360
Conversation
sahas3
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change looks good to me. Some minor comments and clarifying questions.
Lower aten.conv_transpose2d into tosa.transpose_conv2d. Refresh FX importer TOSA xfails to drop the transpose-conv cases that now pass, and document the weight layout mapping. Change-Id: I709579e40a1ccaf9b9188392c7c78fcb653109ce
f6e53b6 to
6db8cd6
Compare
sahas3
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG once you fix the issue about failure check after IR rewrites.
For future reference, it'll be good if you preserve the commits since it's easier to review what's changed since last time review was provided. We can squash commits during merge. Thanks!
| return rewriter.notifyMatchFailure( | ||
| op, "Unimplemented: grouped transposed convolution not supported by " | ||
| "TOSA"); | ||
| if (dilation[0] != 1 || dilation[1] != 1) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Unimplemented: dilated transposed convolution not supported by " | ||
| "TOSA"); | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these notify failures need to happen before any IR rewrites take place, otherwise the pattern rewriter ends up in a recursive loop. For example, on line 2410 we'd have already introduced tosa.transpose and then we'll bail out from here for a failure case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed for transpose and depthwise paths.
Lazily create the NHWC input transpose so we emit it only once the failure guards in the transposed and depthwise convolution rewrite succeed. Change-Id: Ia362deda898794397107f6da3c44cd89f219f58f
Lower aten.conv_transpose2d into tosa.transpose_conv2d. Refresh FX importer TOSA xfails to drop the transpose-conv cases that now pass, and document the weight layout mapping.