
Commit d933e03

Add aten.logaddexp2 decomposition
Signed-off-by: Zahid Wakeel <[email protected]>
1 parent da14307 commit d933e03

8 files changed, +127 -3 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions

@@ -8789,6 +8789,30 @@ def Torch_AtenLogaddexpOp : Torch_Op<"aten.logaddexp", [
   }];
 }
 
+def Torch_AtenLogaddexp2Op : Torch_Op<"aten.logaddexp2", [
+  AllowsTypeRefinement,
+  HasValueSemantics,
+  ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logaddexp2 : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogaddexp2Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogaddexp2Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenMeanDimOp : Torch_Op<"aten.mean.dim", [
   AllowsTypeRefinement,
   HasValueSemantics,
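The new TableGen definition mirrors the registered ATen signature `aten::logaddexp2 : (Tensor, Tensor) -> (Tensor)`, so the op can be exercised from eager PyTorch before any lowering. A minimal sketch (the example values are illustrative, not part of the patch):

    import torch

    a = torch.tensor([0.5, 1.0, -2.0])
    b = torch.tensor([1.5, -1.0, 0.0])
    # aten::logaddexp2 : (Tensor, Tensor) -> (Tensor)
    out = torch.ops.aten.logaddexp2(a, b)  # same result as torch.logaddexp2(a, b)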

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 33 additions & 2 deletions

@@ -9280,6 +9280,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logaddexp2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.masked_fill.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12793,10 +12797,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_dtype_fn.aten.logaddexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
 "    %int11 = torch.constant.int 11\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
-"    torch.prim.If %1 -> () {\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+"      %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %4 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logaddexp2\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int10 = torch.constant.int 10\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int8 = torch.constant.int 8\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %int11, %int8, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
 "      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 31 additions & 0 deletions

@@ -3006,6 +3006,36 @@ class DecomposeAtenLogAddExpOp : public OpRewritePattern<AtenLogaddexpOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenLogAddExp2Op : public OpRewritePattern<AtenLogaddexp2Op> {
+public:
+  using OpRewritePattern<AtenLogaddexp2Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLogaddexp2Op op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+
+    auto outTy = dyn_cast<ValueTensorType>(op.getType());
+    if (!outTy || !outTy.hasDtype() || !outTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "output should have dtype and size");
+    }
+
+    Value constantOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value expSelf = rewriter.create<AtenExp2Op>(loc, outTy, self);
+    Value expOther = rewriter.create<AtenExp2Op>(loc, outTy, other);
+    Value addValue = rewriter.create<AtenAddTensorOp>(loc, outTy, expSelf,
+                                                      expOther, constantOne);
+    Value logValue = rewriter.create<AtenLog2Op>(loc, outTy, addValue);
+
+    rewriter.replaceOp(op, logValue);
+    return success();
+  }
+};
+} // namespace
+
 // SoftShrink(x, lambda) function:
 // Applies a shrinkage function where:
 //  - If x > lambda, returns x - lambda
@@ -12051,6 +12081,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogAddExpOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenLogAddExp2Op>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
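The rewrite lowers aten.logaddexp2 via the identity logaddexp2(x, y) = log2(2^x + 2^y), built from exp2, add, and log2 ops. A minimal PyTorch sketch of the emitted computation (for illustration only; the straight exp2/add/log2 form can overflow for large-magnitude inputs, unlike a max-shifted formulation):

    import torch

    def decomposed_logaddexp2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # exp2 -> add -> log2, matching the ops the pattern above creates
        return torch.log2(torch.exp2(x) + torch.exp2(y))

    x, y = torch.randn(3, 2, 4), torch.randn(3, 2, 4)
    torch.testing.assert_close(decomposed_logaddexp2(x, y), torch.logaddexp2(x, y))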

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions

@@ -582,6 +582,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenFliplrOp>();
   target.addIllegalOp<AtenFlipudOp>();
   target.addIllegalOp<AtenLogaddexpOp>();
+  target.addIllegalOp<AtenLogaddexp2Op>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions

@@ -2923,6 +2923,7 @@
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
     "ElementwiseLogAddExpModule_basic",
+    "ElementwiseLogAddExp2Module_basic",
     "ElementwiseSpecialExpm1IntModule_basic",
     "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseFmodTensor_Int_basic",
@@ -3931,6 +3932,7 @@
     "L1LossNoReductionModule_basic",
     "L1LossSumReductionModule_basic",
     "ElementwiseLogAddExpModule_basic",
+    "ElementwiseLogAddExp2Module_basic",
     "FloatPowerTensorTensorStaticModule_basic",
     "IsInfiniteModule_basic",
     "ElementwiseCopysignModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 11 additions & 1 deletion

@@ -1479,6 +1479,9 @@ def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout
 def aten〇logaddexp〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇logaddexp2〡shape(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇masked_fill〇Scalar〡shape(self: List[int], mask: List[int], value: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -3475,7 +3478,14 @@ def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int,
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool}))
 def aten〇logaddexp〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
-    assert self_dtype != torch.bool
+    other_rank, other_dtype = other_rank_dtype
+    assert self_dtype != torch.bool and other_dtype != torch.bool
+    return self_dtype
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, torch.complex32, torch.complex64, torch.complex128}))
+def aten〇logaddexp2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype not in [torch.bool, torch.complex32, torch.complex64, torch.complex128]
     return self_dtype
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0))
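The new dtype function simply propagates self's dtype once bool and complex inputs are ruled out. A quick eager check of that propagation for matching floating-point inputs (a sketch, not part of the test suite):

    import torch

    for dt in (torch.float32, torch.float64):
        x, y = torch.rand(4, dtype=dt), torch.rand(4, dtype=dt)
        assert torch.logaddexp2(x, y).dtype == dt  # result dtype follows the inputs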

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions

@@ -723,6 +723,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::logaddexp : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::logaddexp2 : (Tensor, Tensor) -> (Tensor)")
     emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
     emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 24 additions & 0 deletions

@@ -2604,6 +2604,30 @@ def ElementwiseLogAddExpModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseLogAddExp2Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.logaddexp2(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLogAddExp2Module())
+def ElementwiseLogAddExp2Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))
+
+
+# ==============================================================================
+
+
 class ElementwiseLogSigmoidModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
