fix(ONNX): avoids resizing unsupported dimensions #3945
Conversation
Force-pushed from 6baa8d5 to ab7e021
Force-pushed from ab7e021 to 7aec80b
I think the main structural question is about the need for adding the BaseTensorType method. If it were useful elsewhere (I have some doubts, since we would need to know too much about the two tensor shapes prior to using it, namely that they are present and that they have the same rank), I would consider keeping it; however, the code is simplified here by not using it, and I suspect the same would be true in other circumstances where it might be used.
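For context, a minimal sketch of the inline alternative being described, assuming the pattern has `inputType` and `resultType` in scope and uses the `binder` object common to the ONNX-to-Torch patterns (the variable names are illustrative, not the PR's actual code):

```
// Illustrative sketch only: check both tensor types inline instead of adding a
// helper method to BaseTensorType. Requires that both types carry sizes and
// that the ranks match before the rest of the rewrite proceeds.
auto inputTensorType = dyn_cast<Torch::BaseTensorType>(inputType);
auto outputTensorType = dyn_cast<Torch::BaseTensorType>(resultType);
if (!inputTensorType || !outputTensorType || !inputTensorType.hasSizes() ||
    !outputTensorType.hasSizes() ||
    inputTensorType.getSizes().size() != outputTensorType.getSizes().size())
  return rewriter.notifyMatchFailure(
      binder.op, "expected input and output tensors with sizes of equal rank");
```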
Force-pushed from 7aec80b to a20ee29
Force-pushed from a20ee29 to 574f4fe
Sorry for the misdirect earlier: we need to perform the runtime asserts on the scales or sizes operand values rather than on the output sizes, since we will not have access to the correct output sizes ahead of time.
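For illustration, a hedged sketch of what emitting those runtime asserts on the scale values could look like, assuming `nonSpatialScales` already holds the `!torch.float` scale values extracted for the batch and channel dimensions (that name is hypothetical; the actual PR code may extract and check the values differently):

```
// Illustrative sketch only: assert at runtime that the non-spatial (batch and
// channel) scale factors are exactly 1.0, i.e. those dimensions are not resized.
Value scaleIdentity = rewriter.create<Torch::ConstantFloatOp>(
    loc, rewriter.getF64FloatAttr(1.0));
for (Value scale : nonSpatialScales) { // hypothetical list of !torch.float values
  Value isUnity = rewriter.create<Torch::AtenEqFloatOp>(
      loc, rewriter.getType<Torch::BoolType>(), scale, scaleIdentity);
  rewriter.create<Torch::RuntimeAssertOp>(
      loc, isUnity,
      rewriter.getStringAttr(
          "onnx.Resize: support not present for resizing batch or channel dims"));
}
```

The same shape of check would apply when the sizes operand is provided instead, comparing each non-spatial output size against the corresponding input size.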
Force-pushed from 574f4fe to cb371ea
I think the renaming of
Force-pushed from 05ea165 to cb20894
Force-pushed from 8a78147 to 54cf76e
Okay, @zjgarvey, I think we're in business! Got green on the CI a few hours ago. Just wrapped up self-review. I'm guessing now we'll:
I kept the commits atomic and ordered such that it's easier to propagate changes to the head of the branch in case an earlier commit needs to be inserted/tweaked/excised. Let me know what you think!
Nice! At least for now, please exclude any commits which involve style changes not directly related to the fix content, e.g. widespread enforcement of naming preferences like the
Force-pushed from 54cf76e to 052404a
…in onnx.resize - avoids SSA before match failures
- cast to `ValueTensorType` was overly specific for the methods used
- intellisense is able to infer `unsigned` aspect from `.size()`
…size - emphasizes parallel to `inputTensorType`
- easier to read - allows for cleaner diffs if they ever change
Force-pushed from a824f38 to a230084
Hey, @zjgarvey, I've rebased. To make it easier for you to re-review, I could segment this PR into a stack of 3 PRs. It'll be really easy; the commits are already primed for it! Would you like me to do that for you?
No, this change looks manageable and self-contained. I'll review a bit more carefully today.
Looks good!
Value scaleIdentity = rewriter.create<Torch::ConstantFloatOp>(
    loc, rewriter.getF64FloatAttr(1.0));
FYI this appears to have caused some test failures downstream in the IREE project on iree-org/iree#19976. I did not bisect to this specific change or line of code, but this looked most relevant. These are the logs: https://github.com/iree-org/iree/actions/runs/13292751088/job/37117771168?pr=19976#step:8:50
_ IREE compile and run: test_resize_downsample_scales_cubic_align_corners::model.mlir::model.mlir::cpu_llvm_sync _
[gw2] linux -- Python 3.11.11 /home/runner/work/iree/iree/venv/bin/python
Error invoking iree-compile
Error code: 1
Stderr diagnostics:
<unknown>:0: error: failed to legalize operation 'torch.constant.float'
<unknown>:0: note: see current operation: %6 = "torch.constant.float"() <{value = 1.000000e+00 : f64}> : () -> !torch.float
Stdout diagnostics:
Test case source:
https://github.com/iree-org/iree-test-suites/blob/main/onnx_ops/onnx/node/generated/test_resize_downsample_scales_cubic_align_corners
Input program:
```
module {
func.func @test_resize_downsample_scales_cubic_align_corners(%arg0: !torch.vtensor<[1,1,4,4],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0 = torch.operator "onnx.Resize"(%arg0, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "align_corners", torch.onnx.mode = "cubic"} : (!torch.vtensor<[1,1,4,4],f32>, !torch.none, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,3,3],f32>
return %0 : !torch.vtensor<[1,1,3,3],f32>
}
}
```
Compiled with:
cd /home/runner/work/iree/iree/iree-test-suites/onnx_ops/onnx/node/generated/test_resize_downsample_scales_cubic_align_corners && iree-compile model.mlir --iree-hal-target-backends=llvm-cpu --iree-input-demote-f64-to-f32=false -o model_cpu_llvm_sync.vmfb
By default, IREE demotes f64 to f32, as 64 bits of precision is rarely needed in ML models and many hardware targets either do not support f64 at all or support it with significant performance penalties. The tests there do override that default by setting `--iree-input-demote-f64-to-f32=false`, though.
Is f64 needed here, or would f32 work? I see lots of uses of f64 in this file 🤔
More context: some of the tests in the ONNX test suite require f64, which is why we run the tests without f64 to f32 demotion: iree-org/iree#18111.
We don't need f64; this is a small bug with the changes. Will post a quick fix in a minute.
If I remember correctly, when writing this, using f32 for `scaleIdentity` caused a test case or two within torch-mlir to fail.
@zjgarvey Any insights here?
Wait, F64 is the correct attr type for constant float ops. I'll take a look at the test failures.
Looks like a simple issue. AtenEqFloatOp doesn't have a lowering, but it should be easy to add. I'll post a PR shortly.
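For reference, a rough sketch of the kind of scalar lowering being described, not the actual follow-up patch: converting `torch.aten.eq.float` into an `arith.cmpf` once the type converter has turned its `!torch.float` operands into `f64` values (the placement, e.g. in the TorchToArith conversion, and the predicate choice are assumptions here):

```
// Illustrative sketch only; the real lowering may differ.
class ConvertAtenEqFloatOp
    : public OpConversionPattern<Torch::AtenEqFloatOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(Torch::AtenEqFloatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // After type conversion the operands are plain f64 values, so a float
    // compare produces the i1 that !torch.bool lowers to.
    rewriter.replaceOpWithNewOp<arith::CmpFOp>(
        op, arith::CmpFPredicate::UEQ, adaptor.getA(), adaptor.getB());
    return success();
  }
};
```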
I re-ran the IREE tests with #4022. The failing tests go back to passing with that change.
Nice, thanks!
Addresses an issue introduced by #3945 in an external test suite.
Prevents #3453