Skip to content

Commit 7aec80b

Browse files
committed
fix(ONNX): avoids resizing conventionally fixed dimensions
1 parent 76368bd commit 7aec80b

File tree

4 files changed

+66
-15
lines changed

4 files changed

+66
-15
lines changed

include/torch-mlir/Dialect/Torch/IR/TorchTypes.h

+4
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ class BaseTensorType : public Type {
8484
/// Enable isa/dyn_cast for BaseTensorType.
8585
static bool classof(Type type);
8686

87+
/// Element-wise comparison of this tensor's sizes against those of `that`
/// tensor; an entry is std::nullopt when either dimension's size is unknown.
88+
std::vector<std::optional<bool>>
89+
shapeComparisonAgainst(BaseTensorType that) const;
90+
8791
/// Return true if this type has the same sizes and dtype as the other.
8892
bool hasSameSizesAndDtype(BaseTensorType other) const;
8993

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

+26-3
Original file line numberDiff line numberDiff line change
@@ -2717,6 +2717,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27172717
"round_prefer_floor") ||
27182718
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
27192719
return failure();
2720+
2721+
Value inputTensor = operands[0];
2722+
Torch::ValueTensorType inputTensor_blueprint =
2723+
cast<Torch::ValueTensorType>(inputTensor.getType());
2724+
2725+
std::vector<std::optional<bool>> shapeComparison =
2726+
inputTensor_blueprint.shapeComparisonAgainst(
2727+
outputTensor_blueprint);
2728+
2729+
// Comparisons of the dimensions assumed to carry the batch and channel sizes
2730+
auto shapeComparisonForFixedDimensions =
2731+
ArrayRef(shapeComparison).take_front(2);
2732+
2733+
for (auto eachDimensionComparison : shapeComparisonForFixedDimensions) {
2734+
if (eachDimensionComparison == std::nullopt) {
2735+
return rewriter.notifyMatchFailure(
2736+
binder.op, "Sizes for batch and channel dimensions must be "
2737+
"statically defined");
2738+
}
2739+
if (eachDimensionComparison == false) {
2740+
return rewriter.notifyMatchFailure(
2741+
binder.op,
2742+
"Unexpected intent to resize the batch/channel dimensions");
2743+
}
2744+
};
2745+
27202746
if (antialias != 0) {
27212747
return rewriter.notifyMatchFailure(
27222748
binder.op,
@@ -2749,9 +2775,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27492775
binder.op, "unimplemented: cubic coeff must be -0.75");
27502776
}
27512777

2752-
Value inputTensor = operands[0];
2753-
Torch::ValueTensorType inputTensor_blueprint =
2754-
cast<Torch::ValueTensorType>(inputTensor.getType());
27552778
ArrayRef<int64_t> inputTensor_dimensions =
27562779
inputTensor_blueprint.getSizes();
27572780
unsigned rank = inputTensor_dimensions.size();

lib/Dialect/Torch/IR/TorchTypes.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,30 @@ static bool isValidTorchDtype(Type dtype) {
217217
return false;
218218
}
219219

220+
std::vector<std::optional<bool>>
221+
BaseTensorType::shapeComparisonAgainst(BaseTensorType that) const {
222+
auto this_dimensions = /**/ getSizes();
223+
auto that_dimensions = that.getSizes();
224+
225+
auto this_rank = this_dimensions.size();
226+
auto that_rank = that_dimensions.size();
227+
228+
assert((this_rank == that_rank) && "Ranks must match to compare dimensions");
229+
230+
std::vector<std::optional<bool>> runningComparison = {};
231+
auto dimensionPairs = llvm::zip(this_dimensions, that_dimensions);
232+
233+
for (auto [eachLHDimension, eachRHDimension] : dimensionPairs) {
234+
if (eachLHDimension == kUnknownSize || eachRHDimension == kUnknownSize) {
235+
runningComparison.push_back(std::nullopt);
236+
} else {
237+
runningComparison.push_back(eachLHDimension == eachRHDimension);
238+
}
239+
}
240+
241+
return runningComparison;
242+
}
243+
220244
bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
221245
return getOptionalSizes() == other.getOptionalSizes() &&
222246
getOptionalDtype() == other.getOptionalDtype();

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

+12-12
Original file line numberDiff line numberDiff line change
@@ -2254,35 +2254,35 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
22542254
// -----
22552255

22562256
// CHECK-LABEL: func.func @test_resize_sizes_nearest
2257-
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2257+
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22582258
%none = torch.constant.none
2259-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2260-
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
2261-
return %0 : !torch.vtensor<[?,?,?,?],f32>
2259+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
2260+
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
2261+
return %0 : !torch.vtensor<[1,1,?,?],f32>
22622262
}
22632263

22642264
// -----
22652265

22662266
// CHECK-LABEL: func.func @test_resize_sizes_nearest
2267-
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2267+
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22682268
%none = torch.constant.none
22692269
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
2270-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2270+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
22712271
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
22722272
torch.onnx.coordinate_transformation_mode = "half_pixel",
2273-
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
2274-
return %0 : !torch.vtensor<[?,?,?,?],f32>
2273+
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
2274+
return %0 : !torch.vtensor<[1,1,?,?],f32>
22752275
}
22762276

22772277
// -----
22782278

22792279
// CHECK-LABEL: func.func @test_resize_sizes_linear
2280-
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
2280+
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],
22812281
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22822282
%none = torch.constant.none
2283-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2284-
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
2285-
return %0 : !torch.vtensor<[?,?,?,?],f32>
2283+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
2284+
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
2285+
return %0 : !torch.vtensor<[1,1,?,?],f32>
22862286
}
22872287

22882288
// -----

0 commit comments

Comments
 (0)