Skip to content

Commit e7389de

Browse files
Revert "OnnxToTorch bicubic interpolation (#3802)"
This reverts commit 889a836.
1 parent 612ccc3 commit e7389de

File tree

3 files changed

+31
-292
lines changed

3 files changed

+31
-292
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -2922,7 +2922,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29222922
llvm::SmallVector<Value> operands;
29232923
std::string mode, nearest_mode, coordTfMode;
29242924
int64_t antialias, exclude_outside;
2925-
float extrapolation_value, cubic_coeff_a;
2925+
float extrapolation_value;
29262926
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
29272927

29282928
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
@@ -2947,8 +2947,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29472947
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
29482948
0.0) ||
29492949
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
2950-
"round_prefer_floor") ||
2951-
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
2950+
"round_prefer_floor"))
29522951
return failure();
29532952
if (antialias != 0) {
29542953
return rewriter.notifyMatchFailure(
@@ -2977,11 +2976,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29772976
"except asymmetric and half_pixel");
29782977
}
29792978

2980-
if (mode == "cubic" && cubic_coeff_a != -0.75) {
2981-
return rewriter.notifyMatchFailure(
2982-
binder.op, "unimplemented: cubic coeff must be -0.75");
2983-
}
2984-
29852979
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
29862980
.getSizes()
29872981
.size();
@@ -2997,11 +2991,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29972991
Value alignCorners =
29982992
coordTfMode == "align_corners" ? cstTrue : cstFalse;
29992993
if (mode == "cubic") {
3000-
std::string modeStr = "cubic";
3001-
if (coordTfMode != "half_pixel")
3002-
modeStr = modeStr + "_" + coordTfMode;
3003-
modeStrValue =
3004-
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
2994+
return rewriter.notifyMatchFailure(binder.op,
2995+
"unimplemented: bicubic mode");
30052996
}
30062997
// supported modes:
30072998
// bilinear (half_pixel), bilinear with align_corners,

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 25 additions & 230 deletions
Original file line number | Diff line number | Diff line change
@@ -2683,7 +2683,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
26832683
};
26842684
} // namespace
26852685

2686-
static Value nearestInterpolate(OpBuilder &b, Location loc,
2686+
static Value NearestInterpolate(OpBuilder &b, Location loc,
26872687
SmallVector<Value> outputSizes, Value input,
26882688
SmallVector<Value> inputSizes,
26892689
SmallVector<Value> scaleValues,
@@ -2771,12 +2771,12 @@ static Value nearestInterpolate(OpBuilder &b, Location loc,
27712771
return retVal;
27722772
}
27732773

2774-
static SmallVector<Value> coordinateTransform(
2775-
OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc,
2776-
SmallVector<Value> outputSizes, Value input, SmallVector<Value> inputSizes,
2777-
SmallVector<Value> scaleValues, std::string coordStr, bool alignCornersBool,
2778-
SmallVector<Value> indices, bool clip) {
2779-
2774+
static Value BilinearInterpolate(OpBuilder &b,
2775+
Aten__InterpolateSizeListScaleListOp op,
2776+
Location loc, SmallVector<Value> outputSizes,
2777+
Value input, SmallVector<Value> inputSizes,
2778+
SmallVector<Value> scaleValues,
2779+
std::string coordStr) {
27802780
unsigned dimOffset = 2;
27812781
auto inputType = cast<RankedTensorType>(input.getType());
27822782
auto inputRank = inputType.getRank();
@@ -2785,7 +2785,15 @@ static SmallVector<Value> coordinateTransform(
27852785
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
27862786
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
27872787

2788-
SmallVector<Value> proj;
2788+
bool alignCornersBool;
2789+
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2790+
2791+
SmallVector<Value> indices;
2792+
for (unsigned i = 0; i < inputRank; i++) {
2793+
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2794+
}
2795+
2796+
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
27892797
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
27902798
// length_original
27912799
Value inputFP =
@@ -2848,50 +2856,13 @@ static SmallVector<Value> coordinateTransform(
28482856
outputSizeFP, cstOneFloat);
28492857
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
28502858
}
2851-
if (clip) {
2852-
// preClip is the fp position inside the input image to extract from.
2853-
// clip to [0,inf)
2854-
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
2855-
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
2856-
// clip to [0,length_original - 1].
2857-
// proj is properly within the input image.
2858-
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
2859-
} else {
2860-
proj.push_back(preClip);
2861-
}
2862-
}
2863-
return proj;
2864-
}
2865-
2866-
static Value bilinearInterpolate(OpBuilder &b,
2867-
Aten__InterpolateSizeListScaleListOp op,
2868-
Location loc, SmallVector<Value> outputSizes,
2869-
Value input, SmallVector<Value> inputSizes,
2870-
SmallVector<Value> scaleValues,
2871-
std::string coordStr) {
2872-
unsigned dimOffset = 2;
2873-
auto inputType = cast<RankedTensorType>(input.getType());
2874-
auto inputRank = inputType.getRank();
2875-
2876-
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2877-
2878-
bool alignCornersBool;
2879-
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2880-
2881-
SmallVector<Value> indices;
2882-
for (unsigned i = 0; i < inputRank; i++) {
2883-
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2884-
}
2885-
2886-
SmallVector<Value> proj, high, low, highFP, lowFP;
2887-
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
2888-
scaleValues, coordStr, alignCornersBool, indices,
2889-
true);
2890-
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
2891-
// length_original
2892-
Value inputFP =
2893-
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
2859+
// preClip is the fp position inside the input image to extract from.
2860+
// clip to [0,inf)
2861+
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
28942862
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
2863+
// clip to [0,length_original - 1].
2864+
// proj is properly within the input image.
2865+
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
28952866

28962867
// for bilinear interpolation, we look for the nearest indices below and
28972868
// above proj
@@ -2955,176 +2926,6 @@ static Value bilinearInterpolate(OpBuilder &b,
29552926
return b.create<arith::AddFOp>(loc, left, right);
29562927
}
29572928

2958-
static Value bicubicInterpolate(OpBuilder &b,
2959-
Aten__InterpolateSizeListScaleListOp op,
2960-
Location loc, SmallVector<Value> outputSizes,
2961-
Value input, SmallVector<Value> inputSizes,
2962-
SmallVector<Value> scaleValues,
2963-
std::string coordStr) {
2964-
unsigned dimOffset = 2;
2965-
auto inputType = cast<RankedTensorType>(input.getType());
2966-
auto inputRank = inputType.getRank();
2967-
2968-
Value inputFPH =
2969-
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[0]);
2970-
Value inputFPW =
2971-
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[1]);
2972-
2973-
Value a = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-0.75));
2974-
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
2975-
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2976-
Value cstTwoFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
2977-
Value cstThreeFloat =
2978-
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(3.0));
2979-
Value cstFourFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(4.0));
2980-
Value cstFiveFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(5.0));
2981-
Value cstEightFloat =
2982-
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(8.0));
2983-
2984-
// (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1)
2985-
auto WeightLessThanEqualOne = [&](Value xDistance) -> Value {
2986-
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
2987-
Value xDistanceCubed =
2988-
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
2989-
2990-
Value lessEqualOne = b.create<arith::AddFOp>(loc, a, cstTwoFloat);
2991-
lessEqualOne = b.create<arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
2992-
Value aPlusThree = b.create<arith::AddFOp>(loc, a, cstThreeFloat);
2993-
aPlusThree = b.create<arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
2994-
lessEqualOne = b.create<arith::SubFOp>(loc, lessEqualOne, aPlusThree);
2995-
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
2996-
2997-
return lessEqualOne;
2998-
};
2999-
3000-
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2)
3001-
auto WeightLessThanTwo = [&](Value xDistance) -> Value {
3002-
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
3003-
Value xDistanceCubed =
3004-
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
3005-
// a|x|^3
3006-
Value lessThanTwo = b.create<arith::MulFOp>(loc, xDistanceCubed, a);
3007-
3008-
Value fiveA = b.create<arith::MulFOp>(loc, xDistanceSquared, a);
3009-
fiveA = b.create<arith::MulFOp>(loc, fiveA, cstFiveFloat);
3010-
// a|x|^3 - 5a|x|^2
3011-
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fiveA);
3012-
3013-
Value eightA = b.create<arith::MulFOp>(loc, a, xDistance);
3014-
eightA = b.create<arith::MulFOp>(loc, eightA, cstEightFloat);
3015-
// a|x|^3 - 5a|x|^2 + 8a|x|
3016-
lessThanTwo = b.create<arith::AddFOp>(loc, eightA, lessThanTwo);
3017-
3018-
Value fourA = b.create<arith::MulFOp>(loc, a, cstFourFloat);
3019-
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a
3020-
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fourA);
3021-
return lessThanTwo;
3022-
};
3023-
3024-
bool alignCornersBool;
3025-
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
3026-
3027-
SmallVector<Value> indices;
3028-
for (unsigned i = 0; i < inputRank; i++) {
3029-
indices.push_back(b.create<linalg::IndexOp>(loc, i));
3030-
}
3031-
3032-
SmallVector<Value> proj;
3033-
3034-
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
3035-
scaleValues, coordStr, alignCornersBool, indices,
3036-
false);
3037-
3038-
// get the nearest neighbors of proj
3039-
Value x1 = b.create<math::CeilOp>(loc, proj[1]);
3040-
Value x_1 = b.create<arith::SubFOp>(loc, x1, cstOneFloat);
3041-
Value x_2 = b.create<arith::SubFOp>(loc, x_1, cstOneFloat);
3042-
Value x2 = b.create<arith::AddFOp>(loc, x1, cstOneFloat);
3043-
3044-
Value y1 = b.create<math::CeilOp>(loc, proj[0]);
3045-
Value y_1 = b.create<arith::SubFOp>(loc, y1, cstOneFloat);
3046-
Value y_2 = b.create<arith::SubFOp>(loc, y_1, cstOneFloat);
3047-
Value y2 = b.create<arith::AddFOp>(loc, y1, cstOneFloat);
3048-
3049-
// calculate the distance of nearest neighbors x and y to proj
3050-
Value y2Distance = b.create<arith::SubFOp>(loc, proj[0], y2);
3051-
y2Distance = b.create<math::AbsFOp>(loc, y2Distance);
3052-
Value y1Distance = b.create<arith::SubFOp>(loc, proj[0], y1);
3053-
y1Distance = b.create<math::AbsFOp>(loc, y1Distance);
3054-
Value y_1Distance = b.create<arith::SubFOp>(loc, proj[0], y_1);
3055-
y_1Distance = b.create<math::AbsFOp>(loc, y_1Distance);
3056-
Value y_2Distance = b.create<arith::SubFOp>(loc, proj[0], y_2);
3057-
y_2Distance = b.create<math::AbsFOp>(loc, y_2Distance);
3058-
3059-
Value x2Distance = b.create<arith::SubFOp>(loc, proj[1], x2);
3060-
x2Distance = b.create<math::AbsFOp>(loc, x2Distance);
3061-
Value x1Distance = b.create<arith::SubFOp>(loc, proj[1], x1);
3062-
x1Distance = b.create<math::AbsFOp>(loc, x1Distance);
3063-
Value x_1Distance = b.create<arith::SubFOp>(loc, proj[1], x_1);
3064-
x_1Distance = b.create<math::AbsFOp>(loc, x_1Distance);
3065-
Value x_2Distance = b.create<arith::SubFOp>(loc, proj[1], x_2);
3066-
x_2Distance = b.create<math::AbsFOp>(loc, x_2Distance);
3067-
3068-
SmallVector<Value> y{y_2, y_1, y1, y2};
3069-
SmallVector<Value> x{x_2, x_1, x1, x2};
3070-
3071-
SmallVector<Value> wys{
3072-
WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance),
3073-
WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)};
3074-
SmallVector<Value> wxs{
3075-
WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance),
3076-
WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)};
3077-
3078-
// clip the nearest neighbors points to inside the original image
3079-
for (int k = 0; k < 4; k++) {
3080-
Value yClipped = b.create<arith::MaximumFOp>(loc, y[k], zero);
3081-
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputFPH, cstOneFloat);
3082-
yClipped = b.create<arith::MinimumFOp>(loc, yClipped, inputHSubOne);
3083-
Value yInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yClipped);
3084-
y[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yInt);
3085-
3086-
Value xClipped = b.create<arith::MaximumFOp>(loc, x[k], zero);
3087-
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputFPW, cstOneFloat);
3088-
xClipped = b.create<arith::MinimumFOp>(loc, xClipped, inputWSubOne);
3089-
Value xInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xClipped);
3090-
x[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xInt);
3091-
}
3092-
// 1. Compute x_original and y_original (proj)
3093-
// 2. Compute nearest x and y neighbors
3094-
// 3. Compute Wx Wy
3095-
// 4. Extract inputs at nearest neighbors (inputExtracts)
3096-
// 5. Compute weighted sum (yield this)
3097-
3098-
// 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original
3099-
// 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original
3100-
// Sum_x is over 4 nearest x neighbors (similar for Sum_y)
3101-
// f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y]
3102-
// * W(y_original - y)
3103-
Value fxy = zero;
3104-
3105-
for (int j = 0; j < 4; j++) {
3106-
Value wy = wys[j];
3107-
Value xInterpy = zero;
3108-
3109-
indices[dimOffset] = y[j];
3110-
3111-
for (int i = 0; i < 4; i++) {
3112-
Value wx = wxs[i];
3113-
3114-
indices[dimOffset + 1] = x[i];
3115-
3116-
Value p = b.create<tensor::ExtractOp>(loc, input, indices);
3117-
3118-
Value wxp = b.create<arith::MulFOp>(loc, wx, p);
3119-
xInterpy = b.create<arith::AddFOp>(loc, xInterpy, wxp);
3120-
}
3121-
Value wyXInterpy = b.create<arith::MulFOp>(loc, wy, xInterpy);
3122-
fxy = b.create<arith::AddFOp>(loc, fxy, wyXInterpy);
3123-
}
3124-
3125-
return fxy;
3126-
}
3127-
31282929
namespace {
31292930
class ConvertInterpolateOp
31302931
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
@@ -3140,8 +2941,7 @@ class ConvertInterpolateOp
31402941
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
31412942
// op with the non-standard mode="bilinear_asymmetric".
31422943
matchPattern(op.getMode(), m_TorchConstantStr(mode));
3143-
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" &&
3144-
mode.substr(0, 5) != "cubic") {
2944+
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
31452945
return failure();
31462946
}
31472947

@@ -3223,18 +3023,13 @@ class ConvertInterpolateOp
32233023
(mode.find(",") == std::string::npos)
32243024
? ""
32253025
: mode.substr(mode.find(",") + 1);
3226-
retVal = nearestInterpolate(
3026+
retVal = NearestInterpolate(
32273027
b, loc, outputSizeIntValues, input, inputSizes,
32283028
ScaleFactorFloatValues, coordTfMode, nearestMode);
32293029
} else if (mode.substr(0, 8) == "bilinear") {
3230-
retVal = bilinearInterpolate(
3030+
retVal = BilinearInterpolate(
32313031
b, op, loc, outputSizeIntValues, input, inputSizes,
32323032
ScaleFactorFloatValues, mode.substr(8));
3233-
} else if (mode.substr(0, 5) == "cubic") {
3234-
3235-
retVal = bicubicInterpolate(
3236-
b, op, loc, outputSizeIntValues, input, inputSizes,
3237-
ScaleFactorFloatValues, mode.substr(5));
32383033
}
32393034
b.create<linalg::YieldOp>(loc, retVal);
32403035
})

0 commit comments

Comments
 (0)