Commit 98e4eb2

[TOSA] Add lowering for aten.expm1 (#3949)

* Add Torch to TOSA legalization for aten.expm1
* Update xfail_sets with new test results
* Add new LIT tests

Change-Id: I834d0c7416341f884612053aebf9fcc90bcb3b53
Signed-off-by: Justin Ngo <[email protected]>

Parent: a45356e

File tree: 3 files changed (+79 −4)

* lib/Conversion/TorchToTosa/TorchToTosa.cpp
* projects/pt1/e2e_testing/xfail_sets.py
* test/Conversion/TorchToTosa/basic.mlir

lib/Conversion/TorchToTosa/TorchToTosa.cpp (+42)
@@ -8212,6 +8212,47 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.expm1
+template <>
+LogicalResult ConvertAtenOp<AtenExpm1Op>::matchAndRewrite(
+    AtenExpm1Op op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // expm1 formula:
+  // yi = exp(x) - 1
+  // Note: This lowering might not provide as great precision as aten.expm1
+  // since TOSA doesn't have a built-in expm1 op.
+  auto self = adaptor.getSelf();
+
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+  auto resultElemTy = resultType.getElementType();
+
+  if (!isa<mlir::FloatType>(resultElemTy))
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point datatype result types are supported");
+
+  // If input is not a float type then cast it to result element type
+  auto selfElemTy = selfType.getElementType();
+  if (!isa<mlir::FloatType>(selfElemTy))
+    self = tosa::promoteType(rewriter, self, resultType);
+
+  auto one =
+      tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();
+
+  auto expOp = rewriter.create<tosa::ExpOp>(op->getLoc(), resultType, self);
+
+  auto result = rewriter.create<tosa::SubOp>(op->getLoc(), resultType,
+                                             expOp.getResult(), one);
+
+  rewriter.replaceOp(op, {result.getResult()});
+
+  return success();
+}
+
 // Legalization for aten.tan
 template <>
 LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
@@ -8805,6 +8846,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenLogitOp);
   INSERT_ATENOP_PATTERN(AtenLog1pOp);
   INSERT_ATENOP_PATTERN(AtenLog10Op);
+  INSERT_ATENOP_PATTERN(AtenExpm1Op);
   INSERT_ATENOP_PATTERN(AtenTanOp);
   INSERT_ATENOP_PATTERN(AtenUnfoldOp);
 #undef INSERT_ATENOP_PATTERN
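For context on the precision note in the code comment above, here is a standalone PyTorch sketch (not part of the commit) showing why computing exp(x) - 1 loses accuracy near zero, which is exactly the caveat this lowering documents:

import torch

# Near zero, exp(x) rounds to a float extremely close to 1.0, so the
# subtraction cancels most of the significant digits. A dedicated expm1
# kernel computes the quantity directly and avoids the cancellation.
x = torch.tensor([1e-7], dtype=torch.float32)

reference = torch.expm1(x)       # accurate: approximately 1.0e-07
decomposed = torch.exp(x) - 1.0  # what the exp-then-sub lowering computes

print(reference.item())   # ~1.0000000e-07
print(decomposed.item())  # ~1.1920929e-07 in float32, off by roughly 19%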

projects/pt1/e2e_testing/xfail_sets.py (+4 −4)
@@ -1709,8 +1709,12 @@
     "Unfold_Module_Rank_Zero_basic",
     "Unfold_Module_basic",
     "ElementwiseErfIntModule_basic",
+    "ElementwiseExpm1IntModule_basic",
+    "ElementwiseExpm1Module_basic",
     "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseSigmoidIntModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseTanIntModule_basic",
     "ElementwiseTanModule_basic",
     "ElementwiseUnaryIntModule_basic",
@@ -3668,16 +3672,12 @@
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
-    "ElementwiseExpm1IntModule_basic",
-    "ElementwiseExpm1Module_basic",
     "ElementwiseMulTensorComplexDiffModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
-    "ElementwiseSpecialExpm1IntModule_basic",
-    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "ElementwiseWhereScalarOtherStaticModule_basic",

test/Conversion/TorchToTosa/basic.mlir (+33)
@@ -3024,3 +3024,36 @@ func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.
 }
 
 // -----
+
+// CHECK-LABEL: func.func @torch.aten.expm1$basic(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
+// CHECK:         %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:         %[[VAL_3:.*]] = tosa.exp %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK:         %[[VAL_4:.*]] = tosa.sub %[[VAL_3]], %[[VAL_2]] : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xf32>
+// CHECK:         %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK:         return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
+// CHECK:       }
+func.func @torch.aten.expm1$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.expm1$int(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32>
+// CHECK:         %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
+// CHECK:         %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:         %[[VAL_4:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK:         %[[VAL_5:.*]] = tosa.sub %[[VAL_4]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xf32>
+// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK:         return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK:       }
+func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> {
+  %0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
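These CHECK lines are exercised by LIT and FileCheck. The RUN line at the top of basic.mlir is outside this hunk; assuming it follows the file's usual pattern, it pipes the test through the conversion pass, along the lines of:

// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s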
