@@ -10471,6 +10471,89 @@ class DecomposeAtenNllLossForwardOp
1047110471};
1047210472} // namespace
1047310473
10474+ namespace {
10475+ // Decompostion of aten.hinge_embedding_loss op
10476+ // Ref:
10477+ // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L182
10478+ // The Hinge Embedding Loss:
10479+ // | input, if target == 1
10480+ // loss(x) = |
10481+ // | max(0, margin - input), if target == -1
10482+ class DecomposeHingeEmbeddingLoss
10483+ : public OpRewritePattern<AtenHingeEmbeddingLossOp> {
10484+ using OpRewritePattern<AtenHingeEmbeddingLossOp>::OpRewritePattern;
10485+ LogicalResult matchAndRewrite(AtenHingeEmbeddingLossOp op,
10486+ PatternRewriter &rewriter) const override {
10487+ Location loc = op.getLoc();
10488+ auto input = op.getSelf();
10489+ auto target = op.getTarget();
10490+
10491+ auto inputTy = dyn_cast<ValueTensorType>(input.getType());
10492+ if (!inputTy.hasDtype() || !inputTy.hasSizes())
10493+ return rewriter.notifyMatchFailure(op, "input must have dtype and size");
10494+
10495+ auto targetTy = dyn_cast<ValueTensorType>(target.getType());
10496+ if (!targetTy.hasDtype() || !targetTy.hasSizes())
10497+ return rewriter.notifyMatchFailure(op, "target must have dtype and size");
10498+ auto resultTy = dyn_cast<ValueTensorType>(op.getType());
10499+ Value minusOne = getConstantWithGivenDtypeAndValue(rewriter, loc, -1,
10500+ targetTy.getDtype());
10501+ Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
10502+ targetTy.getDtype());
10503+ Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
10504+ targetTy.getDtype());
10505+ Value alpha =
10506+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
10507+ auto boolType = targetTy.getWithSizesAndDtype(targetTy.getSizes(),
10508+ rewriter.getI1Type());
10509+ // input - margin
10510+ auto inputMinusMargin = rewriter.create<AtenSubScalarOp>(
10511+ loc, inputTy, input, op.getMargin(), alpha);
10512+ // multiply by -1 to get margin - input
10513+ auto marginDiff = rewriter.create<AtenMulScalarOp>(
10514+ loc, inputTy, inputMinusMargin, minusOne);
10515+ // max(0, margin - input) => clamping the minimum value of margin - input at
10516+ // 0
10517+ auto marginClamp =
10518+ rewriter.create<AtenClampMinOp>(loc, inputTy, marginDiff, zero);
10519+ // Compute mask: target != 1
10520+ auto targetNotOne =
10521+ rewriter.create<AtenNeScalarOp>(loc, boolType, target, one);
10522+ // If target != -1 use marginClamp otherwise 0.
10523+ auto outputMargin = rewriter.create<AtenWhereScalarOtherOp>(
10524+ loc, inputTy, targetNotOne, marginClamp, zero);
10525+ // Compute mask: target != 1
10526+ auto targetNotMinusOne =
10527+ rewriter.create<AtenNeScalarOp>(loc, boolType, target, minusOne);
10528+ // If target != 1 use the original input. Otherwise 0.
10529+ auto outputSelf = rewriter.create<AtenWhereScalarOtherOp>(
10530+ loc, inputTy, targetNotMinusOne, input, zero);
10531+ // Add : outputMargin + outputSelf
10532+ auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
10533+ outputSelf, /*alpha=*/alpha);
10534+ int64_t reduction;
10535+ if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) {
10536+ return rewriter.notifyMatchFailure(op,
10537+ "reduction should be a constant int!");
10538+ }
10539+ Value loss;
10540+ Value none = rewriter.create<ConstantNoneOp>(loc);
10541+ // reduction: mean
10542+ if (reduction == 1) {
10543+ loss = rewriter.create<AtenMeanOp>(loc, resultTy, output, none);
10544+ } else if (reduction == 2) {
10545+ // reduction: sum
10546+ loss = rewriter.create<AtenSumOp>(loc, resultTy, output, none);
10547+ } else {
10548+ // reduction: none
10549+ loss = output;
10550+ }
10551+ rewriter.replaceOp(op, loss);
10552+ return success();
10553+ }
10554+ };
10555+ } // namespace
10556+
1047410557namespace {
1047510558class DecomposeAtenBinaryCrossEntropyWithLogitsOp
1047610559 : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12384,6 +12467,7 @@ class DecomposeComplexOpsPass
1238412467 addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
1238512468 addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
1238612469 addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
12470+ addPatternIfTargetOpIsIllegal<DecomposeHingeEmbeddingLoss>(patterns);
1238712471 addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
1238812472 patterns);
1238912473 addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
0 commit comments