
Commit 612ccc3

Revert "Add Scalarization Patterns for AtenToDtypeOp, AtenNegOp, AtenRemainderTensorOp (#3861)"
This reverts commit cd38ecf.
1 parent c1c0524 commit 612ccc3

6 files changed: +30 additions, -302 deletions

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 2 additions & 24 deletions
@@ -82,25 +82,6 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
 };
 } // namespace
 
-namespace {
-class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
-public:
-  using OpConversionPattern<AtenNegIntOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenNegIntOp op,
-                  typename OpConversionPattern<AtenNegIntOp>::OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value a = adaptor.getA();
-    rewriter.replaceOpWithNewOp<arith::SubIOp>(
-        op,
-        rewriter.create<arith::ConstantIntOp>(op.getLoc(), /*value=*/0,
-                                              /*bitwidth=*/64),
-        a);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 template <typename AtenOp, typename UnaryOp>
 class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
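For context: the deleted ConvertAtenNegIntOp pattern lowered torch.aten.neg.int by subtracting the operand from a zero constant. A rough before/after sketch in MLIR (illustrative IR, not taken from this commit; %a and %a_i64 are hypothetical values):

    // before TorchToArith:
    %neg = torch.aten.neg.int %a : !torch.int -> !torch.int
    // after the removed pattern ran (!torch.int type-converted to i64):
    %c0 = arith.constant 0 : i64
    %neg_i64 = arith.subi %c0, %a_i64 : i64

With the revert, this pass no longer marks AtenNegIntOp illegal, so torch.aten.neg.int passes through unchanged.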
@@ -484,14 +465,11 @@ class ConvertTorchToArith
 
     target.addIllegalOp<AtenAddOp>();
     patterns.add<ConvertAtenAddOp>(typeConverter, context);
-    target.addIllegalOp<AtenNegIntOp>();
-    patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
+
     target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
-                        AtenMulIntOp, AtenRemainderIntOp>();
+                        AtenMulIntOp>();
     patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
         typeConverter, context);
-    patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
-        typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
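This hunk also unregisters the AtenRemainderIntOp lowering, so torch.aten.remainder.int is no longer rewritten to arith.remsi by this pass. A sketch of the removed rewrite (illustrative IR, not from this commit); note that arith.remsi truncates toward zero, while Python's % follows the sign of the divisor, so the two disagree for negative operands:

    %r = torch.aten.remainder.int %a, %b : !torch.int, !torch.int -> !torch.int
    // previously lowered (operands type-converted to i64) to:
    %r_i64 = arith.remsi %a_i64, %b_i64 : i64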

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 0 additions & 4 deletions
@@ -4068,10 +4068,6 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   int64_t lhs, rhs;
   bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
   bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
-  if (lConstant && lhs == 1)
-    return getOperand(1);
-  if (rConstant && rhs == 1)
-    return getOperand(0);
   if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
     return getI64IntegerAttr(getContext(), 0);
   if (lConstant && rConstant)
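This hunk drops the multiply-by-one identity folds from AtenMulIntOp::fold; the multiply-by-zero and two-constant folds remain. Illustrative IR (not from this commit; %x is a hypothetical value):

    %int1 = torch.constant.int 1
    %y = torch.aten.mul.int %x, %int1 : !torch.int, !torch.int -> !torch.int
    // the deleted lines folded %y directly to %x; after the revert the op remains
    %int0 = torch.constant.int 0
    %z = torch.aten.mul.int %x, %int0 : !torch.int, !torch.int -> !torch.int
    // still folds to a constant 0 via the retained zero case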

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 0 additions & 5 deletions
@@ -4587,11 +4587,6 @@ class DecomposeAtenUnflattenIntOp
     if (!isValidDim(dimInt, inputRank))
       return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
 
-    if (inputShape[dimInt] == Torch::kUnknownSize &&
-        llvm::count(sizesInts, -1) > 0)
-      return rewriter.notifyMatchFailure(
-          op, "Unimplemented: dynamic unflatten dim with an inferred size.");
-
     SmallVector<Value> sizesTorchInt;
     if (!getListConstructElements(op.getSizes(), sizesTorchInt))
       return rewriter.notifyMatchFailure(
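The removed guard made the unflatten decomposition bail out when the dimension being split has unknown extent and the requested sizes contain an inferred -1; with the revert, such ops are decomposed again. A sketch of a case the guard used to reject (illustrative IR and types, not from this commit):

    %int0 = torch.constant.int 0
    %int-1 = torch.constant.int -1
    %int4 = torch.constant.int 4
    %sizes = torch.prim.ListConstruct %int-1, %int4
           : (!torch.int, !torch.int) -> !torch.list<int>
    %out = torch.aten.unflatten.int %t, %int0, %sizes
         : !torch.vtensor<[?],f32>, !torch.int, !torch.list<int>
        -> !torch.vtensor<[?,4],f32>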

lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp

Lines changed: 12 additions & 216 deletions
@@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     // Rank 0 item op prop
-    if (selfTy.getSizes().empty()) {
+    if (selfTy.getSizes().size() == 0) {
       auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
       auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
       if (!squeezeDim && !numToTensor)
@@ -746,109 +746,6 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 };
 } // namespace
 
-namespace {
-
-LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
-                                   SmallVector<OpFoldResult> &converted,
-                                   SmallVector<OpFoldResult> &elements,
-                                   Type inputDtype, Type resultDtype) {
-  auto inputIsInt = dyn_cast<mlir::IntegerType>(inputDtype);
-  auto resultIsInt = dyn_cast<mlir::IntegerType>(resultDtype);
-  if (!inputIsInt && !isa<mlir::FloatType>(inputDtype))
-    return failure();
-  if (!resultIsInt && !isa<mlir::FloatType>(resultDtype))
-    return failure();
-
-  // if dtypes are both int or both float, no conversion needed
-  if (static_cast<bool>(inputIsInt) == static_cast<bool>(resultIsInt)) {
-    converted = elements;
-    return success();
-  }
-
-  if (resultIsInt) {
-    for (auto &e : elements) {
-      auto eValue = dyn_cast<Value>(e);
-      if (eValue) {
-        converted.push_back(b.createOrFold<AtenIntScalarOp>(eValue));
-        continue;
-      }
-      auto eAttr = dyn_cast<Attribute>(e);
-      auto eFloatAttr = dyn_cast_or_null<FloatAttr>(eAttr);
-      if (!eFloatAttr)
-        return failure();
-
-      converted.push_back(IntegerAttr::get(
-          resultDtype, static_cast<int64_t>(eFloatAttr.getValueAsDouble())));
-    }
-    return success();
-  }
-
-  // result is float
-  for (auto &e : elements) {
-    auto eValue = dyn_cast<Value>(e);
-    if (eValue) {
-      converted.push_back(b.createOrFold<AtenFloatScalarOp>(eValue));
-      continue;
-    }
-    auto eAttr = dyn_cast<Attribute>(e);
-    auto eIntAttr = dyn_cast<IntegerAttr>(eAttr);
-    if (!eIntAttr)
-      return failure();
-
-    auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
-                                        : eIntAttr.getValue().getZExtValue();
-    converted.push_back(FloatAttr::get(resultDtype, static_cast<double>(eInt)));
-  }
-  return success();
-}
-
-class PropagateAtenToDtypePattern : public OpRewritePattern<AtenToDtypeOp> {
-public:
-  using OpRewritePattern<AtenToDtypeOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenToDtypeOp op,
-                                PatternRewriter &rewriter) const override {
-    bool nonBlocking, copyArg;
-    // The non_blocking arg must be `False`.
-    if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
-        nonBlocking)
-      return failure();
-    // The copy arg must be `False`.
-    if (!matchPattern(op.getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
-      return failure();
-    // The memory_format arg must be `none`.
-    if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
-      return failure();
-
-    auto inputType = dyn_cast<ValueTensorType>(op.getSelf().getType());
-    auto resultType = dyn_cast<ValueTensorType>(op.getType());
-    if (!inputType || !resultType || !inputType.hasDtype() ||
-        !resultType.hasDtype())
-      return failure();
-    auto inputDtype = inputType.getDtype();
-    auto resultDtype = resultType.getDtype();
-
-    SmallVector<OpFoldResult> elements;
-    if (failed(getListFromTensor(op.getSelf(), elements)))
-      return failure();
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    SmallVector<OpFoldResult> converted;
-    if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
-                                    resultDtype)))
-      return rewriter.notifyMatchFailure(
-          op, "Unhandled attribute type encountered.");
-
-    SmallVector<Value> vals;
-    if (failed(materializeFolds(b, converted, vals)))
-      return failure();
-
-    Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 template <typename AtenViewLikeOp>
 class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
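The deleted PropagateAtenToDtypePattern (with its convertOpFoldResults helper) scalarized torch.aten.to.dtype on shape-like tensors: each element was converted between int and float via AtenIntScalarOp/AtenFloatScalarOp (or folded directly as an attribute), and the result tensor was rebuilt from the converted scalars. An illustrative sketch (assumed IR, not from this commit; %e0/%e1 are hypothetical list elements, dtype code 6 is float32):

    %cast = torch.aten.to.dtype %shape, %int6, %false, %false, %none
          : !torch.vtensor<[2],si64>, !torch.int, !torch.bool, !torch.bool,
            !torch.none -> !torch.vtensor<[2],f32>
    // was rewritten to per-element scalar conversions:
    %f0 = torch.aten.Float.Scalar %e0 : !torch.int -> !torch.float
    %f1 = torch.aten.Float.Scalar %e1 : !torch.int -> !torch.float
    // ...followed by rebuilding the result tensor from [%f0, %f1]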
@@ -931,49 +828,6 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
     if (failed(materializeFolds(b, resultFolds, resultVals)))
       return failure();
 
-    if (resultTy.getSizes().empty()) {
-      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
-          op, resultTy, resultVals.front());
-      return success();
-    }
-
-    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-template <typename OpTy, typename ScalarOpTy>
-class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Check type
-    auto resultTy = cast<ValueTensorType>(op.getType());
-    if (resultTy.getSizes().size() > 1)
-      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
-    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
-      return rewriter.notifyMatchFailure(op, "not an int type");
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    SmallVector<OpFoldResult> selfFold;
-    if (failed(getListFromTensor(op.getSelf(), selfFold)))
-      return failure();
-    SmallVector<Value> selfVals;
-    if (failed(materializeFolds(b, selfFold, selfVals)))
-      return failure();
-    SmallVector<OpFoldResult> resultFolds;
-    for (uint64_t i = 0; i < selfVals.size(); i++) {
-      resultFolds.push_back(
-          b.createOrFold<ScalarOpTy>(selfVals[i].getType(), selfVals[i]));
-    }
-    SmallVector<Value> resultVals;
-    if (failed(materializeFolds(b, resultFolds, resultVals)))
-      return failure();
-
     if (resultTy.getSizes().size() == 0) {
       rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
           op, resultTy, resultVals.front());
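Likewise, the deleted PropagateAtenUnaryPattern mapped a scalar op over the materialized elements of a small integer tensor; it was only instantiated as PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp> (see the registration hunk further down). Sketch of the effect (assumed IR; %e0/%e1 are hypothetical elements of %t):

    %neg = torch.aten.neg %t : !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64>
    // was rewritten to scalar negations plus a tensor reconstruction:
    %n0 = torch.aten.neg.int %e0 : !torch.int -> !torch.int
    %n1 = torch.aten.neg.int %e1 : !torch.int -> !torch.int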
@@ -986,6 +840,7 @@ class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
   }
 };
 } // namespace
+
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns
 
@@ -1060,22 +915,19 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
     auto resultTy = cast<BaseTensorType>(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
       return rewriter.notifyMatchFailure(op, "dynamic output shape");
-    if (resultTy.getSizes().size() == 0) {
-      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
-          op, op.getType(), elements.front());
-      return success();
-    }
 
     auto loc = op.getLoc();
     SmallVector<Value> sizes;
     for (auto size : resultTy.getSizes())
       sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(size)));
 
+    Value one = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getType<Torch::IntType>(), 1);
     Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
         loc,
         rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
-        sizes);
+        one);
 
     Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
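This hunk restores the pre-#3861 body of FoldAtenTensorSplatPattern: the rank-0 special case is dropped, and the size list fed to the splat is built from a single constant 1 rather than from the computed sizes vector (which the restored code leaves unused). Judging from the surrounding lines, the pattern folds an aten.tensor built from identical elements into a filled tensor; roughly (assumed IR, not from this commit; %v is a hypothetical repeated element):

    %list = torch.prim.ListConstruct %v, %v
          : (!torch.int, !torch.int) -> !torch.list<int>
    %t = torch.aten.tensor %list, %none, %none, %false
       : !torch.list<int>, !torch.none, !torch.none, !torch.bool
      -> !torch.vtensor<[2],si64>
    // folds to a splat-style construction with %v as the fill value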
@@ -1179,24 +1031,6 @@ class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 };
 } // namespace
 
-namespace {
-// fold ridiculous patterns like size.int -> float.scalar -> int.scalar
-class FoldAtenIntScalarPattern : public OpRewritePattern<AtenIntScalarOp> {
-public:
-  using OpRewritePattern<AtenIntScalarOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenIntScalarOp op,
-                                PatternRewriter &rewriter) const override {
-    auto floatScalarOp = op.getA().getDefiningOp<AtenFloatScalarOp>();
-    if (!floatScalarOp)
-      return failure();
-    auto sizeOp = floatScalarOp.getA().getDefiningOp<AtenSizeIntOp>();
-    if (!sizeOp)
-      return failure();
-    rewriter.replaceOp(op, floatScalarOp.getA());
-    return success();
-  }
-};
-} // namespace
 namespace {
 class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
 public:
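The deleted FoldAtenIntScalarPattern collapsed the int -> float -> int round trip that shows up when a size.int result is passed through Float.Scalar and back through Int.Scalar. Sketch (assumed IR, not from this commit):

    %d = torch.aten.size.int %t, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
    %f = torch.aten.Float.Scalar %d : !torch.int -> !torch.float
    %i = torch.aten.Int.Scalar %f : !torch.float -> !torch.int
    // the removed fold replaced %i with %d directly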
@@ -1348,29 +1182,8 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
     if (inputUnmatched == 1 && outputUnmatched > 1) {
       Value dimVal =
           rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
-      SmallVector<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
-                                        viewSizes.end() - rightMatchEnd);
-      // try to convert a single dynamic size input to -1
-      int64_t dynCount = 0;
-      int64_t dynIdx = 0;
-      for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
-        int64_t szeInt;
-        if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
-          dynCount++;
-          dynIdx = i;
-          continue;
-        }
-        // if we have a -1 already, make dynCount invalid and break
-        if (szeInt == -1) {
-          dynCount = -1;
-          break;
-        }
-      }
-      // if only one size is dynamic, make it -1
-      if (dynCount == 1)
-        unflattenSizes[dynIdx] =
-            rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
-
+      ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
+                                     viewSizes.end() - rightMatchEnd);
       Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
           op.getLoc(), op.getSize().getType(), unflattenSizes);
       rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
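The revert swaps the mutable SmallVector copy, which rewrote a single non-constant unflatten size to -1, for a plain ArrayRef over the matched view sizes, so the sizes are now passed through unchanged. Sketch of the -1 inference that is removed (assumed IR; %dyn is a hypothetical non-constant size):

    %sizes = torch.prim.ListConstruct %dyn, %int4
           : (!torch.int, !torch.int) -> !torch.list<int>
    %u = torch.aten.unflatten.int %self, %int1, %sizes
       : !torch.vtensor<[3,?],f32>, !torch.int, !torch.list<int>
      -> !torch.vtensor<[3,?,4],f32>
    // the deleted loop replaced %dyn with torch.constant.int -1 when it was
    // the only non-constant entry and no -1 was already present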
@@ -1414,18 +1227,6 @@ template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
 
 namespace {
 
-bool isItemForSliceOp(Operation *op) {
-  auto itemOp = dyn_cast_or_null<AtenItemOp>(op);
-  if (!itemOp)
-    return false;
-  for (OpOperand &use : op->getUses()) {
-    Operation *userOp = use.getOwner();
-    if (isa<AtenSliceTensorOp>(userOp))
-      return true;
-  }
-  return false;
-}
-
 bool isSourceOpForShapeScalarization(Operation *op) {
   return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
                    Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
@@ -1443,7 +1244,7 @@ bool isPrimListOfInts(Operation *op) {
 
 bool isAnchorOp(Operation *op) {
   return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
-         isPrimListOfInts(op) || isItemForSliceOp(op);
+         isPrimListOfInts(op);
 }
 
 // The argument to this function, op, is the use of some source op, srcOp. If
@@ -1477,9 +1278,9 @@ bool isInvalidValidViewConsumer(Operation *op,
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
   patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
                   FoldAtenSqueezePattern<AtenSqueezeDimOp>,
-                  FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
-                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
-                  FoldAtenEqIntPattern>(patterns.getContext());
+                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
+                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
+      patterns.getContext());
 }
 
 void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
@@ -1502,29 +1303,24 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
       PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
       PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
       PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
-      PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
-      PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
+      PropagateAtenTransposeIntPattern,
       PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
       PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
       PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
-      PropagateAtenArithmeticPattern<AtenRemainderTensorOp, AtenRemainderIntOp>,
       PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
       patterns.getContext());
 }
 
 void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
   patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
                   RemoveUnusedPattern<Torch::AtenEqIntOp>,
-                  RemoveUnusedPattern<Torch::AtenToDtypeOp>,
                   RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
                   RemoveUnusedPattern<Torch::AtenFullOp>,
                   RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
                   RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
                   RemoveUnusedPattern<Torch::AtenSizeIntOp>,
                   RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
                   RemoveUnusedPattern<Torch::AtenTensorOp>,
-                  RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
-                  RemoveUnusedPattern<Torch::AtenIntScalarOp>,
                   RemoveUnusedPattern<Torch::PrimListConstructOp>>(
       patterns.getContext());
 }
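The registration changes mirror the deletions above: with PropagateAtenToDtypePattern, the neg propagation, and the remainder.Tensor propagation gone, the dead-code sweep no longer needs RemoveUnusedPattern instantiations for to.dtype, Float.Scalar, or Int.Scalar. For reference, the kind of scalarization the dropped remainder entry enabled (assumed IR; %ai/%bi are hypothetical extracted elements):

    %r = torch.aten.remainder.Tensor %a, %b
       : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    // was propagated to scalar arithmetic on the extracted elements:
    %ri = torch.aten.remainder.int %ai, %bi : !torch.int, !torch.int -> !torch.int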
