Skip to content

Commit 574f4fe

Browse files
committed
fix(ONNX): avoids resizing non scalable dimensions
1 parent 17029e6 commit 574f4fe

File tree

2 files changed

+64
-7
lines changed

2 files changed

+64
-7
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

+61-4
Original file line numberDiff line numberDiff line change
@@ -2749,8 +2749,69 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27492749
binder.op, "unimplemented: cubic coeff must be -0.75");
27502750
}
27512751

2752+
Value inputTensor = operands[0];
2753+
Torch::ValueTensorType inputTensor_type =
2754+
cast<Torch::ValueTensorType>(inputTensor.getType());
2755+
ArrayRef<int64_t> inputTensor_sizes = inputTensor_type.getSizes();
2756+
ArrayRef<int64_t> outputTensor_sizes = outputTensor_type.getSizes();
2757+
2758+
int64_t const batchDimension = 0;
2759+
int64_t const channelDimension = 1;
2760+
int64_t nonScalableDimensions[] = {
2761+
batchDimension,
2762+
channelDimension,
2763+
};
2764+
2765+
auto errorMessageForScaling = [](int64_t givenDimension) {
2766+
switch (givenDimension) {
2767+
case batchDimension:
2768+
return "Unexpected intent to scale the batch dimension";
2769+
case channelDimension:
2770+
return "Unexpected intent to scale the channel dimension";
2771+
default:
2772+
return "Scalable dimension treated as non-scalable";
2773+
}
2774+
};
2775+
2776+
auto unknownSize = Torch::kUnknownSize;
2777+
2778+
// Compile-time check for dimensions of static size
2779+
for (auto eachDimension : nonScalableDimensions) {
2780+
auto eachInputSize = inputTensor_sizes[eachDimension];
2781+
auto eachOutputSize = outputTensor_sizes[eachDimension];
2782+
2783+
if (eachInputSize == unknownSize || eachOutputSize == unknownSize) {
2784+
continue;
2785+
} else if (eachInputSize == eachOutputSize) {
2786+
continue;
2787+
}
2788+
2789+
return rewriter.notifyMatchFailure(
2790+
binder.op, errorMessageForScaling(eachDimension));
2791+
}
2792+
27522793
auto binderLocation = binder.getLoc();
27532794

2795+
// Run-time check for dimensions of dynamic size
2796+
for (auto eachDimension : nonScalableDimensions) {
2797+
auto eachDimensionAsValue = rewriter.create<Torch::ConstantIntOp>(
2798+
binderLocation, rewriter.getI64IntegerAttr(eachDimension));
2799+
2800+
Value eachInputSizeAsValue = rewriter.create<Torch::AtenSizeIntOp>(
2801+
binderLocation, inputTensor, eachDimensionAsValue);
2802+
2803+
int64_t eachOutputSize = outputTensor_sizes[eachDimension];
2804+
Value eachOutputSizeAsValue = rewriter.create<Torch::ConstantIntOp>(
2805+
binderLocation, rewriter.getI64IntegerAttr(eachOutputSize));
2806+
2807+
Value eachSizeComparison = rewriter.create<Torch::AtenEqIntOp>(
2808+
binderLocation, eachInputSizeAsValue, eachOutputSizeAsValue);
2809+
2810+
rewriter.create<Torch::RuntimeAssertOp>(
2811+
binderLocation, eachSizeComparison,
2812+
rewriter.getStringAttr(errorMessageForScaling(eachDimension)));
2813+
};
2814+
27542815
Value cstFalse =
27552816
rewriter.create<Torch::ConstantBoolOp>(binderLocation, false);
27562817
Value cstTrue =
@@ -2770,10 +2831,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27702831
rewriter.create<Torch::ConstantStrOp>(binderLocation, modeStr);
27712832
}
27722833

2773-
Value inputTensor = operands[0];
2774-
Torch::ValueTensorType inputTensor_type =
2775-
cast<Torch::ValueTensorType>(inputTensor.getType());
2776-
ArrayRef<int64_t> inputTensor_sizes = inputTensor_type.getSizes();
27772834
unsigned inputTensor_rank = inputTensor_sizes.size();
27782835

27792836
// supported modes:

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
22562256
// CHECK-LABEL: func.func @test_resize_sizes_nearest
22572257
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22582258
%none = torch.constant.none
2259-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2259+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22602260
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22612261
return %0 : !torch.vtensor<[?,?,?,?],f32>
22622262
}
@@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
22672267
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22682268
%none = torch.constant.none
22692269
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
2270-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2270+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %[[STR]], %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22712271
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
22722272
torch.onnx.coordinate_transformation_mode = "half_pixel",
22732273
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
@@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
22802280
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
22812281
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22822282
%none = torch.constant.none
2283-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2283+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22842284
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22852285
return %0 : !torch.vtensor<[?,?,?,?],f32>
22862286
}

0 commit comments

Comments
 (0)