Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ONNX): avoids resizing unsupported dimensions #3945

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c0e47ff
refactor(ONNX): chains mutually-exclusive guard and operand usage in …
bjacobgordon Jan 28, 2025
bf60cfa
refactor(ONNX): narrows operand count comparison in onnx.resize
bjacobgordon Jan 28, 2025
82f6445
refactor(ONNX): extracts `numberOfOperands` within onnx.resize
bjacobgordon Jan 28, 2025
50c2529
refactor(ONNX): enforces min assignment-usage distance for value list…
bjacobgordon Jan 14, 2025
e4f4425
refactor(ONNX): removes redundant nulling assignment in onnx.resize
bjacobgordon Jan 14, 2025
bd23b93
refactor(ONNX): enforces min assignment-usage distance for `noneVal` …
bjacobgordon Jan 8, 2025
87f9f54
refactor(ONNX): extracts `loc` within onnx.resize
bjacobgordon Jan 9, 2025
874827d
refactor(ONNX): moves `rank` closer to first usage in onnx.resize
bjacobgordon Jan 10, 2025
1d77eb9
refactor(ONNX): forces cast of operand in onnx.resize
bjacobgordon Jan 8, 2025
e41fa62
refactor(ONNX): loosens downcast in onnx.resize
bjacobgordon Jan 15, 2025
140b628
refactor(ONNX): extracts `inputTensor` within onnx.resize
bjacobgordon Jan 8, 2025
fff08f2
refactor(ONNX): extracts `inputTensorType` from rank derivation in on…
bjacobgordon Jan 8, 2025
1f7cdf0
refactor(ONNX): extracts `sizesOfInputTensor` from rank derivation in…
bjacobgordon Jan 8, 2025
e835f1a
refactor(ONNX): uses `auto` annotation for `rank` in onnx.resize
bjacobgordon Jan 27, 2025
a858e45
refactor(ONNX): renames `rank` to `rankOfInputTensor` in onnx.resize
bjacobgordon Jan 10, 2025
3f41467
refactor(ONNX): renames `resultType` to `outputTensorType` in onnx.re…
bjacobgordon Jan 8, 2025
948a53e
refactor(ONNX): renames `sizesValueList` to `supportedSizes` in onnx.…
bjacobgordon Jan 15, 2025
fd20a79
refactor(ONNX): renames `scalesValueList` to `supportedScaleFactors` …
bjacobgordon Jan 15, 2025
b897a34
refactor(ONNX): renames `scaleOperand` to `proposedScaleFactors` in o…
bjacobgordon Jan 14, 2025
01e2274
refactor(ONNX): renames `sizeOperand` to `proposedSizes` in onnx.resize
bjacobgordon Jan 14, 2025
266a820
refactor(ONNX): prefers multiline attributes in onnx.resize tests
bjacobgordon Jan 22, 2025
c9f2197
refactor(ONNX): distills checks in lit tests for onnx.resize
bjacobgordon Jan 22, 2025
6dc3fdf
fix(ONNX): differentiates names of lit tests for onnx.resize
bjacobgordon Jan 22, 2025
a230084
fix(ONNX): avoids resizing unsupported dimensions
bjacobgordon Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 119 additions & 38 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2700,12 +2700,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
});
patterns.onOp(
"Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Torch::ValueTensorType outputTensorType;
llvm::SmallVector<Value> operands;
std::string mode, nearest_mode, coordTfMode;
int64_t antialias, exclude_outside;
float extrapolation_value, cubic_coeff_a;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
return rewriter.notifyMatchFailure(
Expand All @@ -2720,7 +2719,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
}

if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
binder.tensorResultType(outputTensorType) ||
binder.customOpNameStringAttr(mode, "mode", "nearest") ||
binder.customOpNameStringAttr(
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
Expand All @@ -2732,6 +2731,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();

int64_t const /* */ batchDim = 0;
int64_t const /**/ channelDim = 1;

SmallVector<int64_t> nonResizableDims{
batchDim,
channelDim,
};

Value inputTensor = operands[0];
auto inputTensorType =
cast<Torch::BaseTensorType>(inputTensor.getType());
auto sizesOfInputTensor = inputTensorType.getSizes();
auto sizesOfOutputTensor = outputTensorType.getSizes();

auto unknownSize = Torch::kUnknownSize;

// Compile-time check for dimensions of static size
for (auto &eachDim : nonResizableDims) {
auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim];
auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim];

if (eachSizeOfInputTensor == unknownSize ||
eachSizeOfOutputTensor == unknownSize)
continue;
if (eachSizeOfInputTensor == eachSizeOfOutputTensor)
continue;

auto resizingIntentErrorMessage =
"unsupported: non-trivial intent to resize dimension: " +
std::to_string(eachDim);

return rewriter.notifyMatchFailure(binder.op,
resizingIntentErrorMessage);
};

if (antialias != 0) {
return rewriter.notifyMatchFailure(
binder.op,
Expand Down Expand Up @@ -2764,35 +2799,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, "unimplemented: cubic coeff must be -0.75");
}

unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
.getSizes()
.size();
auto loc = binder.getLoc();

Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value modeStrValue;

Value scalesValueList = noneVal;
Value sizesValueList = noneVal;
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;
if (mode == "cubic") {
std::string modeStr = "cubic";
if (coordTfMode != "half_pixel")
modeStr = modeStr + "_" + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}

auto rankOfInputTensor = sizesOfInputTensor.size();

// supported modes:
// bilinear (half_pixel), bilinear with align_corners,
// bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
// (asymmetric), nearest with align_corners, nearest_half_pixel,
// nearest_pytorch_half_pixel
if (mode == "linear") {
std::string modeStr;
switch (rank) {
switch (rankOfInputTensor) {
case 3:
modeStr = "linear";
break;
Expand All @@ -2809,8 +2840,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
// mode is apparently half_pixel, NOT pytorch_half_pixel
if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
modeStr = (modeStr + "_") + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}
if (mode == "nearest") {
std::string modeStr = "nearest";
Expand All @@ -2820,33 +2850,84 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
modeStr = (modeStr + "_") + coordTfMode;
if (nearest_mode != "floor" && nearest_mode != "")
modeStr = modeStr + "," + nearest_mode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}

int64_t assumedForemostSpatialDim = 2;
auto numberOfOperands = operands.size();

if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList =
createScalarSublist(binder.getLoc(), scaleOperand,
assumedForemostSpatialDim, rewriter);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList =
createScalarSublist(binder.getLoc(), sizeOperand,
assumedForemostSpatialDim, rewriter);
}
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
isa<Torch::NoneType>(sizesValueList.getType())) {
Type boolType = rewriter.getType<Torch::BoolType>();

int64_t assumedForemostSpatialDim = 1 + nonResizableDims.back();

Value supportedScaleFactors;
Value supportedSizes;

Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);

if (numberOfOperands == 3) {
Value proposedScaleFactors = operands[2];

Value scaleIdentity = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(1.0));
Comment on lines +2870 to +2871
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI this appears to have caused some test failures downstream in the IREE project on iree-org/iree#19976. I did not bisect to this specific change or line of code, but this looked most relevant. These are the logs: https://github.com/iree-org/iree/actions/runs/13292751088/job/37117771168?pr=19976#step:8:50

 _ IREE compile and run: test_resize_downsample_scales_cubic_align_corners::model.mlir::model.mlir::cpu_llvm_sync _
[gw2] linux -- Python 3.11.11 /home/runner/work/iree/iree/venv/bin/python
Error invoking iree-compile
Error code: 1
Stderr diagnostics:
<unknown>:0: error: failed to legalize operation 'torch.constant.float'
<unknown>:0: note: see current operation: %6 = "torch.constant.float"() <{value = 1.000000e+00 : f64}> : () -> !torch.float


Stdout diagnostics:


Test case source:
  https://github.com/iree-org/iree-test-suites/blob/main/onnx_ops/onnx/node/generated/test_resize_downsample_scales_cubic_align_corners

Input program:
```
module {
  func.func @test_resize_downsample_scales_cubic_align_corners(%arg0: !torch.vtensor<[1,1,4,4],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Resize"(%arg0, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "align_corners", torch.onnx.mode = "cubic"} : (!torch.vtensor<[1,1,4,4],f32>, !torch.none, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,3,3],f32> 
    return %0 : !torch.vtensor<[1,1,3,3],f32>
  }
}

```

Compiled with:
  cd /home/runner/work/iree/iree/iree-test-suites/onnx_ops/onnx/node/generated/test_resize_downsample_scales_cubic_align_corners && iree-compile model.mlir --iree-hal-target-backends=llvm-cpu --iree-input-demote-f64-to-f32=false -o model_cpu_llvm_sync.vmfb

By default, IREE demotes f64 to f32 as 64 bits of precision is rarely needed in ML models and many hardware targets either do not support f64 at all or support it with significant performance penalties. The tests there do override that default by setting --iree-input-demote-f64-to-f32=false though.

Is f64 needed here, or would f32 work? I see lots of uses of f64 in this file 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More context: some of the tests in the ONNX test suite require f64, which is why we run the tests without f64 to f32 demotion: iree-org/iree#18111.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need f64; this is a small bug with the changes. Will post a quick fix in a minute.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly when writing this, using f32 for scaleIdentity caused a test case or two within torch mlir to fail.

@zjgarvey Any insights here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, F64 is the correct attr type for constant float ops. I'll take a look at the test failures.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a simple issue. AtenEqFloatOp doesn't have a lowering, but it should be easy to add. I'll post a PR shortly.

Copy link
Collaborator

@zjgarvey zjgarvey Feb 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I re-ran the iree tests with #4022

The failing tests go back to passing with that change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks!


// run-time scale factor check for dynamic sizes
for (auto &eachDim : nonResizableDims) {
Value eachProposedScaleFactor = extractTorchScalar(
loc, eachDim, proposedScaleFactors, rewriter);

Value eachScaleFactorIsIdentity =
rewriter.create<Torch::AtenEqFloatOp>(
loc, boolType, eachProposedScaleFactor, scaleIdentity);

auto errorMessageForEachDim =
"Unsupported: non-trivial scale factor for dimension " +
std::to_string(eachDim);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachScaleFactorIsIdentity,
rewriter.getStringAttr(errorMessageForEachDim));
};

supportedScaleFactors = createScalarSublist(
loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter);
supportedSizes = noneVal;
} else if (numberOfOperands == 4) {
Value proposedSizes = operands[3];

// run-time target size check for dynamic sizes
for (auto &eachDimAsInt : nonResizableDims) {
Value eachDimAsValue =
rewriter.create<Torch::ConstantIntOp>(loc, eachDimAsInt);

Value eachSizeOfInputTensor = rewriter.create<Torch::AtenSizeIntOp>(
loc, inputTensor, eachDimAsValue);

Value eachProposedSize =
extractTorchScalar(loc, eachDimAsInt, proposedSizes, rewriter);

Value eachProposedSizeIsTrivial =
rewriter.create<Torch::AtenEqIntOp>(
loc, boolType, eachProposedSize, eachSizeOfInputTensor);

auto errorMessageForEachDim =
"Unsupported: non-trivial resizing of dimension " +
std::to_string(eachDimAsInt);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachProposedSizeIsTrivial,
rewriter.getStringAttr(errorMessageForEachDim));
};

supportedScaleFactors = noneVal;
supportedSizes = createScalarSublist(
loc, proposedSizes, assumedForemostSpatialDim, rewriter);
} else
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
}

rewriter
.replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
binder.op, resultType, operands[0], sizesValueList,
scalesValueList, modeStrValue,
binder.op, outputTensorType, inputTensor, supportedSizes,
supportedScaleFactors, modeStrValue,
/* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
/* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
/*Torch_BoolType:$antialias*/ cstFalse);
Expand Down
31 changes: 22 additions & 9 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2256,21 +2256,30 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[MODE_STR:.*]] = torch.constant.str "nearest"
// CHECK: torch.aten.__interpolate.size_list_scale_list
// CHECK-SAME: %[[MODE_STR]]
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
torch.onnx.coordinate_transformation_mode = "asymmetric",
torch.onnx.cubic_coeff_a = -7.500000e-01 : f32,
torch.onnx.mode = "nearest",
torch.onnx.nearest_mode = "floor"
} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-LABEL: func.func @test_resize_sizes_nearest_half_pixel
func.func @test_resize_sizes_nearest_half_pixel(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[MODE_STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
// CHECK: torch.aten.__interpolate.size_list_scale_list
// CHECK-SAME: %[[MODE_STR]]
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
torch.onnx.coordinate_transformation_mode = "half_pixel",
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
torch.onnx.mode = "nearest"
} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}

Expand All @@ -2280,8 +2289,12 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[MODE_STR:.*]] = torch.constant.str "bilinear"
// CHECK: torch.aten.__interpolate.size_list_scale_list
// CHECK-SAME: %[[MODE_STR]]
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
torch.onnx.mode = "linear"
} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}

Expand Down
Loading