Skip to content

Commit ab7e021

Browse files
committed
fix(ONNX): avoids resizing conventionally fixed dimensions
1 parent ab0858c commit ab7e021

File tree

3 files changed

+53
-3
lines changed

3 files changed

+53
-3
lines changed

include/torch-mlir/Dialect/Torch/IR/TorchTypes.h

+4
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ class BaseTensorType : public Type {
8484
/// Enable isa/dyn_cast for BaseTensorType.
8585
static bool classof(Type type);
8686

87+
/// The element-wise comparison of each dimension/size in `that` tensor
88+
ArrayRef<std::optional<bool>>
89+
dimensionComparisonsAgainst(BaseTensorType that) const;
90+
8791
/// Return true if this type has the same sizes and dtype as the other.
8892
bool hasSameSizesAndDtype(BaseTensorType other) const;
8993

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

+24-3
Original file line numberDiff line numberDiff line change
@@ -2717,6 +2717,30 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27172717
"round_prefer_floor") ||
27182718
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
27192719
return failure();
2720+
2721+
Value inputTensor = operands[0];
2722+
Torch::ValueTensorType inputTensor_shape =
2723+
cast<Torch::ValueTensorType>(inputTensor.getType());
2724+
2725+
ArrayRef<std::optional<bool>> dimensionComparisons =
2726+
inputTensor_shape.dimensionComparisonsAgainst(outputTensor_shape);
2727+
2728+
// Comparisons of the dimensions assumed to carry the batch and channel
2729+
auto fixedDimensionComparisons = dimensionComparisons.take_front(2);
2730+
2731+
for (auto eachComparison : fixedDimensionComparisons) {
2732+
if (eachComparison == nullptr) {
2733+
return rewriter.notifyMatchFailure(
2734+
binder.op, "Sizes for batch and channel dimensions must be "
2735+
"statically defined");
2736+
}
2737+
if (eachComparison == false) {
2738+
return rewriter.notifyMatchFailure(
2739+
binder.op,
2740+
"Unexpected intent to resize the batch/channel dimensions");
2741+
}
2742+
};
2743+
27202744
if (antialias != 0) {
27212745
return rewriter.notifyMatchFailure(
27222746
binder.op,
@@ -2749,9 +2773,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27492773
binder.op, "unimplemented: cubic coeff must be -0.75");
27502774
}
27512775

2752-
Value inputTensor = operands[0];
2753-
Torch::ValueTensorType inputTensor_shape =
2754-
cast<Torch::ValueTensorType>(inputTensor.getType());
27552776
ArrayRef<int64_t> inputTensor_dimensions = inputTensor_shape.getSizes();
27562777
unsigned rank = inputTensor_dimensions.size();
27572778

lib/Dialect/Torch/IR/TorchTypes.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,31 @@ static bool isValidTorchDtype(Type dtype) {
217217
return false;
218218
}
219219

220+
ArrayRef<std::optional<bool>>
221+
BaseTensorType::dimensionComparisonsAgainst(BaseTensorType that) const {
222+
auto this_dimensions = /**/ getSizes();
223+
auto that_dimensions = that.getSizes();
224+
225+
auto this_rank = this_dimensions.size();
226+
auto that_rank = that_dimensions.size();
227+
assert((this_rank == that_rank) && "Ranks must match to compare dimensions");
228+
229+
SmallVector<std::optional<bool>> dimensionComparisons;
230+
dimensionComparisons.reserve(this_rank);
231+
232+
auto dimensionPairs = llvm::zip(this_dimensions, that_dimensions);
233+
234+
for (auto [eachPair_lhs, eachPair_rhs] : dimensionPairs) {
235+
if (eachPair_lhs == kUnknownSize || eachPair_rhs == kUnknownSize) {
236+
dimensionComparisons.push_back(nullptr);
237+
} else {
238+
dimensionComparisons.push_back(eachPair_lhs == eachPair_rhs);
239+
}
240+
}
241+
242+
return dimensionComparisons;
243+
}
244+
220245
bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
221246
return getOptionalSizes() == other.getOptionalSizes() &&
222247
getOptionalDtype() == other.getOptionalDtype();

0 commit comments

Comments
 (0)