@@ -2683,7 +2683,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2683
2683
};
2684
2684
} // namespace
2685
2685
2686
- static Value nearestInterpolate (OpBuilder &b, Location loc,
2686
+ static Value NearestInterpolate (OpBuilder &b, Location loc,
2687
2687
SmallVector<Value> outputSizes, Value input,
2688
2688
SmallVector<Value> inputSizes,
2689
2689
SmallVector<Value> scaleValues,
@@ -2771,12 +2771,12 @@ static Value nearestInterpolate(OpBuilder &b, Location loc,
2771
2771
return retVal;
2772
2772
}
2773
2773
2774
- static SmallVector< Value> coordinateTransform (
2775
- OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc ,
2776
- SmallVector<Value> outputSizes, Value input , SmallVector<Value> inputSizes ,
2777
- SmallVector<Value> scaleValues, std::string coordStr, bool alignCornersBool ,
2778
- SmallVector<Value> indices, bool clip) {
2779
-
2774
+ static Value BilinearInterpolate (OpBuilder &b,
2775
+ Aten__InterpolateSizeListScaleListOp op ,
2776
+ Location loc , SmallVector<Value> outputSizes ,
2777
+ Value input, SmallVector<Value> inputSizes ,
2778
+ SmallVector<Value> scaleValues,
2779
+ std::string coordStr) {
2780
2780
unsigned dimOffset = 2 ;
2781
2781
auto inputType = cast<RankedTensorType>(input.getType ());
2782
2782
auto inputRank = inputType.getRank ();
@@ -2785,7 +2785,15 @@ static SmallVector<Value> coordinateTransform(
2785
2785
Value cstHalf = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.5 ));
2786
2786
Value zero = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
2787
2787
2788
- SmallVector<Value> proj;
2788
+ bool alignCornersBool;
2789
+ matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
2790
+
2791
+ SmallVector<Value> indices;
2792
+ for (unsigned i = 0 ; i < inputRank; i++) {
2793
+ indices.push_back (b.create <linalg::IndexOp>(loc, i));
2794
+ }
2795
+
2796
+ SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
2789
2797
for (unsigned i = 0 ; i < inputRank - dimOffset; i++) {
2790
2798
// length_original
2791
2799
Value inputFP =
@@ -2848,50 +2856,13 @@ static SmallVector<Value> coordinateTransform(
2848
2856
outputSizeFP, cstOneFloat);
2849
2857
preClip = b.create <arith::SelectOp>(loc, cmp, zero, preClip);
2850
2858
}
2851
- if (clip) {
2852
- // preClip is the fp position inside the input image to extract from.
2853
- // clip to [0,inf)
2854
- Value max = b.create <arith::MaximumFOp>(loc, preClip, zero);
2855
- Value inputSubOne = b.create <arith::SubFOp>(loc, inputFP, cstOneFloat);
2856
- // clip to [0,length_original - 1].
2857
- // proj is properly within the input image.
2858
- proj.push_back (b.create <arith::MinimumFOp>(loc, max, inputSubOne));
2859
- } else {
2860
- proj.push_back (preClip);
2861
- }
2862
- }
2863
- return proj;
2864
- }
2865
-
2866
- static Value bilinearInterpolate (OpBuilder &b,
2867
- Aten__InterpolateSizeListScaleListOp op,
2868
- Location loc, SmallVector<Value> outputSizes,
2869
- Value input, SmallVector<Value> inputSizes,
2870
- SmallVector<Value> scaleValues,
2871
- std::string coordStr) {
2872
- unsigned dimOffset = 2 ;
2873
- auto inputType = cast<RankedTensorType>(input.getType ());
2874
- auto inputRank = inputType.getRank ();
2875
-
2876
- Value cstOneFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (1.0 ));
2877
-
2878
- bool alignCornersBool;
2879
- matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
2880
-
2881
- SmallVector<Value> indices;
2882
- for (unsigned i = 0 ; i < inputRank; i++) {
2883
- indices.push_back (b.create <linalg::IndexOp>(loc, i));
2884
- }
2885
-
2886
- SmallVector<Value> proj, high, low, highFP, lowFP;
2887
- proj = coordinateTransform (b, op, loc, outputSizes, input, inputSizes,
2888
- scaleValues, coordStr, alignCornersBool, indices,
2889
- true );
2890
- for (unsigned i = 0 ; i < inputRank - dimOffset; i++) {
2891
- // length_original
2892
- Value inputFP =
2893
- b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[i]);
2859
+ // preClip is the fp position inside the input image to extract from.
2860
+ // clip to [0,inf)
2861
+ Value max = b.create <arith::MaximumFOp>(loc, preClip, zero);
2894
2862
Value inputSubOne = b.create <arith::SubFOp>(loc, inputFP, cstOneFloat);
2863
+ // clip to [0,length_original - 1].
2864
+ // proj is properly within the input image.
2865
+ proj.push_back (b.create <arith::MinimumFOp>(loc, max, inputSubOne));
2895
2866
2896
2867
// for bilinear interpolation, we look for the nearest indices below and
2897
2868
// above proj
@@ -2955,176 +2926,6 @@ static Value bilinearInterpolate(OpBuilder &b,
2955
2926
return b.create <arith::AddFOp>(loc, left, right);
2956
2927
}
2957
2928
2958
- static Value bicubicInterpolate (OpBuilder &b,
2959
- Aten__InterpolateSizeListScaleListOp op,
2960
- Location loc, SmallVector<Value> outputSizes,
2961
- Value input, SmallVector<Value> inputSizes,
2962
- SmallVector<Value> scaleValues,
2963
- std::string coordStr) {
2964
- unsigned dimOffset = 2 ;
2965
- auto inputType = cast<RankedTensorType>(input.getType ());
2966
- auto inputRank = inputType.getRank ();
2967
-
2968
- Value inputFPH =
2969
- b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[0 ]);
2970
- Value inputFPW =
2971
- b.create <arith::SIToFPOp>(loc, b.getF32Type (), inputSizes[1 ]);
2972
-
2973
- Value a = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (-0.75 ));
2974
- Value zero = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
2975
- Value cstOneFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (1.0 ));
2976
- Value cstTwoFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (2.0 ));
2977
- Value cstThreeFloat =
2978
- b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (3.0 ));
2979
- Value cstFourFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (4.0 ));
2980
- Value cstFiveFloat = b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (5.0 ));
2981
- Value cstEightFloat =
2982
- b.create <arith::ConstantOp>(loc, b.getF32FloatAttr (8.0 ));
2983
-
2984
- // (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1)
2985
- auto WeightLessThanEqualOne = [&](Value xDistance) -> Value {
2986
- Value xDistanceSquared = b.create <arith::MulFOp>(loc, xDistance, xDistance);
2987
- Value xDistanceCubed =
2988
- b.create <arith::MulFOp>(loc, xDistanceSquared, xDistance);
2989
-
2990
- Value lessEqualOne = b.create <arith::AddFOp>(loc, a, cstTwoFloat);
2991
- lessEqualOne = b.create <arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
2992
- Value aPlusThree = b.create <arith::AddFOp>(loc, a, cstThreeFloat);
2993
- aPlusThree = b.create <arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
2994
- lessEqualOne = b.create <arith::SubFOp>(loc, lessEqualOne, aPlusThree);
2995
- lessEqualOne = b.create <arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
2996
-
2997
- return lessEqualOne;
2998
- };
2999
-
3000
- // a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2)
3001
- auto WeightLessThanTwo = [&](Value xDistance) -> Value {
3002
- Value xDistanceSquared = b.create <arith::MulFOp>(loc, xDistance, xDistance);
3003
- Value xDistanceCubed =
3004
- b.create <arith::MulFOp>(loc, xDistanceSquared, xDistance);
3005
- // a|x|^3
3006
- Value lessThanTwo = b.create <arith::MulFOp>(loc, xDistanceCubed, a);
3007
-
3008
- Value fiveA = b.create <arith::MulFOp>(loc, xDistanceSquared, a);
3009
- fiveA = b.create <arith::MulFOp>(loc, fiveA, cstFiveFloat);
3010
- // a|x|^3 - 5a|x|^2
3011
- lessThanTwo = b.create <arith::SubFOp>(loc, lessThanTwo, fiveA);
3012
-
3013
- Value eightA = b.create <arith::MulFOp>(loc, a, xDistance);
3014
- eightA = b.create <arith::MulFOp>(loc, eightA, cstEightFloat);
3015
- // a|x|^3 - 5a|x|^2 + 8a|x|
3016
- lessThanTwo = b.create <arith::AddFOp>(loc, eightA, lessThanTwo);
3017
-
3018
- Value fourA = b.create <arith::MulFOp>(loc, a, cstFourFloat);
3019
- // a|x|^3 - 5a|x|^2 + 8a|x| - 4a
3020
- lessThanTwo = b.create <arith::SubFOp>(loc, lessThanTwo, fourA);
3021
- return lessThanTwo;
3022
- };
3023
-
3024
- bool alignCornersBool;
3025
- matchPattern (op.getAlignCorners (), m_TorchConstantBool (&alignCornersBool));
3026
-
3027
- SmallVector<Value> indices;
3028
- for (unsigned i = 0 ; i < inputRank; i++) {
3029
- indices.push_back (b.create <linalg::IndexOp>(loc, i));
3030
- }
3031
-
3032
- SmallVector<Value> proj;
3033
-
3034
- proj = coordinateTransform (b, op, loc, outputSizes, input, inputSizes,
3035
- scaleValues, coordStr, alignCornersBool, indices,
3036
- false );
3037
-
3038
- // get the nearest neighbors of proj
3039
- Value x1 = b.create <math::CeilOp>(loc, proj[1 ]);
3040
- Value x_1 = b.create <arith::SubFOp>(loc, x1, cstOneFloat);
3041
- Value x_2 = b.create <arith::SubFOp>(loc, x_1, cstOneFloat);
3042
- Value x2 = b.create <arith::AddFOp>(loc, x1, cstOneFloat);
3043
-
3044
- Value y1 = b.create <math::CeilOp>(loc, proj[0 ]);
3045
- Value y_1 = b.create <arith::SubFOp>(loc, y1 , cstOneFloat);
3046
- Value y_2 = b.create <arith::SubFOp>(loc, y_1, cstOneFloat);
3047
- Value y2 = b.create <arith::AddFOp>(loc, y1 , cstOneFloat);
3048
-
3049
- // calculate the distance of nearest neighbors x and y to proj
3050
- Value y2Distance = b.create <arith::SubFOp>(loc, proj[0 ], y2);
3051
- y2Distance = b.create <math::AbsFOp>(loc, y2Distance);
3052
- Value y1Distance = b.create <arith::SubFOp>(loc, proj[0 ], y1 );
3053
- y1Distance = b.create <math::AbsFOp>(loc, y1Distance);
3054
- Value y_1Distance = b.create <arith::SubFOp>(loc, proj[0 ], y_1);
3055
- y_1Distance = b.create <math::AbsFOp>(loc, y_1Distance);
3056
- Value y_2Distance = b.create <arith::SubFOp>(loc, proj[0 ], y_2);
3057
- y_2Distance = b.create <math::AbsFOp>(loc, y_2Distance);
3058
-
3059
- Value x2Distance = b.create <arith::SubFOp>(loc, proj[1 ], x2);
3060
- x2Distance = b.create <math::AbsFOp>(loc, x2Distance);
3061
- Value x1Distance = b.create <arith::SubFOp>(loc, proj[1 ], x1);
3062
- x1Distance = b.create <math::AbsFOp>(loc, x1Distance);
3063
- Value x_1Distance = b.create <arith::SubFOp>(loc, proj[1 ], x_1);
3064
- x_1Distance = b.create <math::AbsFOp>(loc, x_1Distance);
3065
- Value x_2Distance = b.create <arith::SubFOp>(loc, proj[1 ], x_2);
3066
- x_2Distance = b.create <math::AbsFOp>(loc, x_2Distance);
3067
-
3068
- SmallVector<Value> y{y_2, y_1, y1 , y2};
3069
- SmallVector<Value> x{x_2, x_1, x1, x2};
3070
-
3071
- SmallVector<Value> wys{
3072
- WeightLessThanTwo (y_2Distance), WeightLessThanEqualOne (y_1Distance),
3073
- WeightLessThanEqualOne (y1Distance), WeightLessThanTwo (y2Distance)};
3074
- SmallVector<Value> wxs{
3075
- WeightLessThanTwo (x_2Distance), WeightLessThanEqualOne (x_1Distance),
3076
- WeightLessThanEqualOne (x1Distance), WeightLessThanTwo (x2Distance)};
3077
-
3078
- // clip the nearest neighbors points to inside the original image
3079
- for (int k = 0 ; k < 4 ; k++) {
3080
- Value yClipped = b.create <arith::MaximumFOp>(loc, y[k], zero);
3081
- Value inputHSubOne = b.create <arith::SubFOp>(loc, inputFPH, cstOneFloat);
3082
- yClipped = b.create <arith::MinimumFOp>(loc, yClipped, inputHSubOne);
3083
- Value yInt = b.create <arith::FPToSIOp>(loc, b.getI64Type (), yClipped);
3084
- y[k] = b.create <arith::IndexCastOp>(loc, b.getIndexType (), yInt);
3085
-
3086
- Value xClipped = b.create <arith::MaximumFOp>(loc, x[k], zero);
3087
- Value inputWSubOne = b.create <arith::SubFOp>(loc, inputFPW, cstOneFloat);
3088
- xClipped = b.create <arith::MinimumFOp>(loc, xClipped, inputWSubOne);
3089
- Value xInt = b.create <arith::FPToSIOp>(loc, b.getI64Type (), xClipped);
3090
- x[k] = b.create <arith::IndexCastOp>(loc, b.getIndexType (), xInt);
3091
- }
3092
- // 1. Compute x_original and y_original (proj)
3093
- // 2. Compute nearest x and y neighbors
3094
- // 3. Compute Wx Wy
3095
- // 4. Extract inputs at nearest neighbors (inputExtracts)
3096
- // 5. Compute weighted sum (yield this)
3097
-
3098
- // 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original
3099
- // 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original
3100
- // Sum_x is over 4 nearest x neighbors (similar for Sum_y)
3101
- // f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y]
3102
- // * W(y_original - y)
3103
- Value fxy = zero;
3104
-
3105
- for (int j = 0 ; j < 4 ; j++) {
3106
- Value wy = wys[j];
3107
- Value xInterpy = zero;
3108
-
3109
- indices[dimOffset] = y[j];
3110
-
3111
- for (int i = 0 ; i < 4 ; i++) {
3112
- Value wx = wxs[i];
3113
-
3114
- indices[dimOffset + 1 ] = x[i];
3115
-
3116
- Value p = b.create <tensor::ExtractOp>(loc, input, indices);
3117
-
3118
- Value wxp = b.create <arith::MulFOp>(loc, wx, p);
3119
- xInterpy = b.create <arith::AddFOp>(loc, xInterpy, wxp);
3120
- }
3121
- Value wyXInterpy = b.create <arith::MulFOp>(loc, wy, xInterpy);
3122
- fxy = b.create <arith::AddFOp>(loc, fxy, wyXInterpy);
3123
- }
3124
-
3125
- return fxy;
3126
- }
3127
-
3128
2929
namespace {
3129
2930
class ConvertInterpolateOp
3130
2931
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
@@ -3140,8 +2941,7 @@ class ConvertInterpolateOp
3140
2941
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
3141
2942
// op with the non-standard mode="bilinear_asymmetric".
3142
2943
matchPattern (op.getMode (), m_TorchConstantStr (mode));
3143
- if (mode.substr (0 , 8 ) != " bilinear" && mode.substr (0 , 7 ) != " nearest" &&
3144
- mode.substr (0 , 5 ) != " cubic" ) {
2944
+ if (mode.substr (0 , 8 ) != " bilinear" && mode.substr (0 , 7 ) != " nearest" ) {
3145
2945
return failure ();
3146
2946
}
3147
2947
@@ -3223,18 +3023,13 @@ class ConvertInterpolateOp
3223
3023
(mode.find (" ," ) == std::string::npos)
3224
3024
? " "
3225
3025
: mode.substr (mode.find (" ," ) + 1 );
3226
- retVal = nearestInterpolate (
3026
+ retVal = NearestInterpolate (
3227
3027
b, loc, outputSizeIntValues, input, inputSizes,
3228
3028
ScaleFactorFloatValues, coordTfMode, nearestMode);
3229
3029
} else if (mode.substr (0 , 8 ) == " bilinear" ) {
3230
- retVal = bilinearInterpolate (
3030
+ retVal = BilinearInterpolate (
3231
3031
b, op, loc, outputSizeIntValues, input, inputSizes,
3232
3032
ScaleFactorFloatValues, mode.substr (8 ));
3233
- } else if (mode.substr (0 , 5 ) == " cubic" ) {
3234
-
3235
- retVal = bicubicInterpolate (
3236
- b, op, loc, outputSizeIntValues, input, inputSizes,
3237
- ScaleFactorFloatValues, mode.substr (5 ));
3238
3033
}
3239
3034
b.create <linalg::YieldOp>(loc, retVal);
3240
3035
})
0 commit comments