Skip to content

Commit cb371ea

Browse files
committed
fix(ONNX): avoids resizing non scalable dimensions
1 parent 9b766e2 commit cb371ea

File tree

2 files changed

+61
-7
lines changed

2 files changed

+61
-7
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2753,8 +2753,66 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27532753
binder.op, "unimplemented: cubic coeff must be -0.75");
27542754
}
27552755

2756+
Value inputTensor = operands[0];
2757+
Torch::ValueTensorType typeOfInputTensor =
2758+
cast<Torch::ValueTensorType>(inputTensor.getType());
2759+
2760+
ArrayRef<int64_t> sizesOfInputTensor = typeOfInputTensor.getSizes();
2761+
ArrayRef<int64_t> sizesOfOutputTensor = typeOfOutputTensor.getSizes();
2762+
2763+
int64_t const dimensionAssumedToBeBatch = 0;
2764+
int64_t const dimensionAssumedToBeChannel = 1;
2765+
int64_t nonScalableDimensions[] = {
2766+
dimensionAssumedToBeBatch,
2767+
dimensionAssumedToBeChannel,
2768+
};
2769+
2770+
auto unknownSize = Torch::kUnknownSize;
2771+
2772+
// Compile-time check for dimensions of static size
2773+
for (auto eachDimension : nonScalableDimensions) {
2774+
auto eachSizeOfInputTensor = sizesOfInputTensor[eachDimension];
2775+
auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDimension];
2776+
2777+
if (eachSizeOfInputTensor == unknownSize ||
2778+
eachSizeOfOutputTensor == unknownSize) {
2779+
continue;
2780+
} else if (eachSizeOfInputTensor == eachSizeOfOutputTensor) {
2781+
continue;
2782+
}
2783+
2784+
auto scalingIntentErrorMessage =
2785+
"unsupported: non-trivial intent to scale dimension: " +
2786+
std::to_string(eachDimension);
2787+
2788+
return rewriter.notifyMatchFailure(binder.op,
2789+
scalingIntentErrorMessage);
2790+
};
2791+
27562792
auto opLocation = binder.getLoc();
27572793

2794+
// Run-time check for dimensions of dynamic size
2795+
for (auto eachDimension : nonScalableDimensions) {
2796+
auto eachDimensionAsValue = rewriter.create<Torch::ConstantIntOp>(
2797+
opLocation, rewriter.getI64IntegerAttr(eachDimension));
2798+
2799+
Value eachSizeOfInputAsValue = rewriter.create<Torch::AtenSizeIntOp>(
2800+
opLocation, inputTensor, eachDimensionAsValue);
2801+
2802+
int64_t eachSizeOfOutput = sizesOfOutputTensor[eachDimension];
2803+
Value eachSizeOfOutputAsValue = rewriter.create<Torch::ConstantIntOp>(
2804+
opLocation, rewriter.getI64IntegerAttr(eachSizeOfOutput));
2805+
2806+
Value eachSizeComparison = rewriter.create<Torch::AtenEqIntOp>(
2807+
opLocation, eachSizeOfInputAsValue, eachSizeOfOutputAsValue);
2808+
2809+
rewriter.create<Torch::RuntimeAssertOp>(
2810+
opLocation, eachSizeComparison,
2811+
rewriter.getStringAttr(
2812+
"unsupported: non-trivial scaling of dimension " +
2813+
std::to_string(eachDimension)));
2814+
};
2815+
27582816
Value cstFalse =
27592817
rewriter.create<Torch::ConstantBoolOp>(opLocation, false);
27602818
Value cstTrue =
@@ -2774,10 +2832,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27742832
rewriter.create<Torch::ConstantStrOp>(opLocation, modeStr);
27752833
}
27762834

2777-
Value inputTensor = operands[0];
2778-
Torch::ValueTensorType typeOfInputTensor =
2779-
cast<Torch::ValueTensorType>(inputTensor.getType());
2780-
ArrayRef<int64_t> sizesOfInputTensor = typeOfInputTensor.getSizes();
27812835
unsigned rankOfInputTensor = sizesOfInputTensor.size();
27822836

27832837
// supported modes:

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
22562256
// CHECK-LABEL: func.func @test_resize_sizes_nearest
22572257
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22582258
%none = torch.constant.none
2259-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2259+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22602260
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22612261
return %0 : !torch.vtensor<[?,?,?,?],f32>
22622262
}
@@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
22672267
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22682268
%none = torch.constant.none
22692269
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
2270-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2270+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %[[STR]], %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22712271
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
22722272
torch.onnx.coordinate_transformation_mode = "half_pixel",
22732273
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
@@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
22802280
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
22812281
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22822282
%none = torch.constant.none
2283-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2283+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22842284
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22852285
return %0 : !torch.vtensor<[?,?,?,?],f32>
22862286
}

0 commit comments

Comments
 (0)