@@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);

     // Rank 0 item op prop
-    if (selfTy.getSizes().empty()) {
+    if (selfTy.getSizes().size() == 0) {
       auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
       auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
       if (!squeezeDim && !numToTensor)
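For context, this rank-0 branch ends by dispatching on which producer matched. A minimal sketch of the prim.NumToTensor fast path, assuming the standard torch-mlir accessor getA() for the scalar operand:

    // item(prim.NumToTensor.Scalar(%x)) is the identity on %x, so the
    // aten.item op can be replaced with the original scalar directly.
    if (numToTensor) {
      rewriter.replaceOp(op, numToTensor.getA());
      return success();
    }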
@@ -746,109 +746,6 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
   };
 } // namespace

-namespace {
-
-LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
-                                   SmallVector<OpFoldResult> &converted,
-                                   SmallVector<OpFoldResult> &elements,
-                                   Type inputDtype, Type resultDtype) {
-  auto inputIsInt = dyn_cast<mlir::IntegerType>(inputDtype);
-  auto resultIsInt = dyn_cast<mlir::IntegerType>(resultDtype);
-  if (!inputIsInt && !isa<mlir::FloatType>(inputDtype))
-    return failure();
-  if (!resultIsInt && !isa<mlir::FloatType>(resultDtype))
-    return failure();
-
-  // if dtypes are both int or both float, no conversion needed
-  if (static_cast<bool>(inputIsInt) == static_cast<bool>(resultIsInt)) {
-    converted = elements;
-    return success();
-  }
-
-  if (resultIsInt) {
-    for (auto &e : elements) {
-      auto eValue = dyn_cast<Value>(e);
-      if (eValue) {
-        converted.push_back(b.createOrFold<AtenIntScalarOp>(eValue));
-        continue;
-      }
-      auto eAttr = dyn_cast<Attribute>(e);
-      auto eFloatAttr = dyn_cast_or_null<FloatAttr>(eAttr);
-      if (!eFloatAttr)
-        return failure();
-
-      converted.push_back(IntegerAttr::get(
-          resultDtype, static_cast<int64_t>(eFloatAttr.getValueAsDouble())));
-    }
-    return success();
-  }
-
-  // result is float
-  for (auto &e : elements) {
-    auto eValue = dyn_cast<Value>(e);
-    if (eValue) {
-      converted.push_back(b.createOrFold<AtenFloatScalarOp>(eValue));
-      continue;
-    }
-    auto eAttr = dyn_cast<Attribute>(e);
-    auto eIntAttr = dyn_cast<IntegerAttr>(eAttr);
-    if (!eIntAttr)
-      return failure();
-
-    auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
-                                        : eIntAttr.getValue().getZExtValue();
-    converted.push_back(FloatAttr::get(resultDtype, static_cast<double>(eInt)));
-  }
-  return success();
-}
-
-class PropagateAtenToDtypePattern : public OpRewritePattern<AtenToDtypeOp> {
-public:
-  using OpRewritePattern<AtenToDtypeOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenToDtypeOp op,
-                                PatternRewriter &rewriter) const override {
-    bool nonBlocking, copyArg;
-    // The non_blocking arg must be `False`.
-    if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
-        nonBlocking)
-      return failure();
-    // The copy arg must be `False`.
-    if (!matchPattern(op.getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
-      return failure();
-    // The memory_format arg must be `none`.
-    if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
-      return failure();
-
-    auto inputType = dyn_cast<ValueTensorType>(op.getSelf().getType());
-    auto resultType = dyn_cast<ValueTensorType>(op.getType());
-    if (!inputType || !resultType || !inputType.hasDtype() ||
-        !resultType.hasDtype())
-      return failure();
-    auto inputDtype = inputType.getDtype();
-    auto resultDtype = resultType.getDtype();
-
-    SmallVector<OpFoldResult> elements;
-    if (failed(getListFromTensor(op.getSelf(), elements)))
-      return failure();
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    SmallVector<OpFoldResult> converted;
-    if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
-                                    resultDtype)))
-      return rewriter.notifyMatchFailure(
-          op, "Unhandled attribute type encountered.");
-
-    SmallVector<Value> vals;
-    if (failed(materializeFolds(b, converted, vals)))
-      return failure();
-
-    Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 template <typename AtenViewLikeOp>
 class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
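A note on the deleted convertOpFoldResults: each element is an OpFoldResult, a PointerUnion of Attribute (compile-time constant) and Value (runtime SSA value), which is why both conversion loops branch twice. Condensing the int-result case removed above into one step:

    // Runtime element: insert an aten.Int.Scalar cast; constant element:
    // rebuild the FloatAttr as an IntegerAttr of the result dtype.
    if (auto v = dyn_cast<Value>(e)) {
      converted.push_back(b.createOrFold<AtenIntScalarOp>(v));
    } else {
      auto f = cast<FloatAttr>(cast<Attribute>(e));
      converted.push_back(IntegerAttr::get(
          resultDtype, static_cast<int64_t>(f.getValueAsDouble())));
    }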
@@ -931,49 +828,6 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
     if (failed(materializeFolds(b, resultFolds, resultVals)))
       return failure();

-    if (resultTy.getSizes().empty()) {
-      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
-          op, resultTy, resultVals.front());
-      return success();
-    }
-
-    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-template <typename OpTy, typename ScalarOpTy>
-class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Check type
-    auto resultTy = cast<ValueTensorType>(op.getType());
-    if (resultTy.getSizes().size() > 1)
-      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
-    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
-      return rewriter.notifyMatchFailure(op, "not an int type");
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    SmallVector<OpFoldResult> selfFold;
-    if (failed(getListFromTensor(op.getSelf(), selfFold)))
-      return failure();
-    SmallVector<Value> selfVals;
-    if (failed(materializeFolds(b, selfFold, selfVals)))
-      return failure();
-    SmallVector<OpFoldResult> resultFolds;
-    for (uint64_t i = 0; i < selfVals.size(); i++) {
-      resultFolds.push_back(
-          b.createOrFold<ScalarOpTy>(selfVals[i].getType(), selfVals[i]));
-    }
-    SmallVector<Value> resultVals;
-    if (failed(materializeFolds(b, resultFolds, resultVals)))
-      return failure();
-
     if (resultTy.getSizes().size() == 0) {
       rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
           op, resultTy, resultVals.front());
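The deleted PropagateAtenUnaryPattern pairs a tensor op with its scalar counterpart through its two template parameters. A sketch of the one instantiation this patch also removes from the registration list further down:

    // Registers the unary propagation so that aten.neg applied to a
    // shape tensor is recomputed as aten.neg.int on each element.
    patterns.insert<PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>>(
        patterns.getContext());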
@@ -986,6 +840,7 @@ class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
   }
 };
 } // namespace
+
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns

@@ -1060,22 +915,19 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
     auto resultTy = cast<BaseTensorType>(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
       return rewriter.notifyMatchFailure(op, "dynamic output shape");
-    if (resultTy.getSizes().size() == 0) {
-      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
-          op, op.getType(), elements.front());
-      return success();
-    }

     auto loc = op.getLoc();
     SmallVector<Value> sizes;
     for (auto size : resultTy.getSizes())
       sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(size)));

+    Value one = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getType<Torch::IntType>(), 1);
     Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
         loc,
         rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
-        sizes);
+        one);

     Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
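The none and cstFalse constants built at the end of this hunk feed the replacement emitted just past the hunk boundary. A hedged sketch of that continuation, assuming the usual aten.full operand order (size list, fill value, dtype, layout, device, pin_memory):

    // Plausible continuation (not shown in this diff): materialize the
    // splat as aten.full from the size list built above.
    rewriter.replaceOpWithNewOp<AtenFullOp>(
        op, resultTy, sizeList, elements.front(), none, none, none, cstFalse);
    return success();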
@@ -1179,24 +1031,6 @@ class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 };
 } // namespace

-namespace {
-// fold ridiculous patterns like size.int -> float.scalar -> int.scalar
-class FoldAtenIntScalarPattern : public OpRewritePattern<AtenIntScalarOp> {
-public:
-  using OpRewritePattern<AtenIntScalarOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenIntScalarOp op,
-                                PatternRewriter &rewriter) const override {
-    auto floatScalarOp = op.getA().getDefiningOp<AtenFloatScalarOp>();
-    if (!floatScalarOp)
-      return failure();
-    auto sizeOp = floatScalarOp.getA().getDefiningOp<AtenSizeIntOp>();
-    if (!sizeOp)
-      return failure();
-    rewriter.replaceOp(op, floatScalarOp.getA());
-    return success();
-  }
-};
-} // namespace
 namespace {
 class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
 public:
@@ -1348,29 +1182,8 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
     if (inputUnmatched == 1 && outputUnmatched > 1) {
       Value dimVal =
           rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
-      SmallVector<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
-                                        viewSizes.end() - rightMatchEnd);
-      // try to convert a single dynamic size input to -1
-      int64_t dynCount = 0;
-      int64_t dynIdx = 0;
-      for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
-        int64_t szeInt;
-        if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
-          dynCount++;
-          dynIdx = i;
-          continue;
-        }
-        // if we have a -1 already, make dynCount invalid and break
-        if (szeInt == -1) {
-          dynCount = -1;
-          break;
-        }
-      }
-      // if only one size is dynamic, make it -1
-      if (dynCount == 1)
-        unflattenSizes[dynIdx] =
-            rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
-
+      ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
+                                     viewSizes.end() - rightMatchEnd);
       Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
           op.getLoc(), op.getSize().getType(), unflattenSizes);
       rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
@@ -1414,18 +1227,6 @@ template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {

 namespace {

-bool isItemForSliceOp(Operation *op) {
-  auto itemOp = dyn_cast_or_null<AtenItemOp>(op);
-  if (!itemOp)
-    return false;
-  for (OpOperand &use : op->getUses()) {
-    Operation *userOp = use.getOwner();
-    if (isa<AtenSliceTensorOp>(userOp))
-      return true;
-  }
-  return false;
-}
-
 bool isSourceOpForShapeScalarization(Operation *op) {
   return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
                    Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
@@ -1443,7 +1244,7 @@ bool isPrimListOfInts(Operation *op) {

 bool isAnchorOp(Operation *op) {
   return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
-         isPrimListOfInts(op) || isItemForSliceOp(op);
+         isPrimListOfInts(op);
 }

 // The argument to this function, op, is the use of some source op, srcOp. If
@@ -1477,9 +1278,9 @@ bool isInvalidValidViewConsumer(Operation *op,
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
   patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
                   FoldAtenSqueezePattern<AtenSqueezeDimOp>,
-                  FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
-                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
-                  FoldAtenEqIntPattern>(patterns.getContext());
+                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
+                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
+      patterns.getContext());
 }

 void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
void populateScalarizationCanonicalizePatterns (RewritePatternSet &patterns) {
@@ -1502,29 +1303,24 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
1502
1303
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
1503
1304
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
1504
1305
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
1505
- PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
1506
- PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
1306
+ PropagateAtenTransposeIntPattern,
1507
1307
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
1508
1308
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
1509
1309
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
1510
- PropagateAtenArithmeticPattern<AtenRemainderTensorOp, AtenRemainderIntOp>,
1511
1310
PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
1512
1311
patterns.getContext ());
1513
1312
}
1514
1313
1515
1314
void populateScalarizationRemovePatterns (RewritePatternSet &patterns) {
1516
1315
patterns.insert <RemoveUnusedPattern<Torch::AtenIntBoolOp>,
1517
1316
RemoveUnusedPattern<Torch::AtenEqIntOp>,
1518
- RemoveUnusedPattern<Torch::AtenToDtypeOp>,
1519
1317
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
1520
1318
RemoveUnusedPattern<Torch::AtenFullOp>,
1521
1319
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
1522
1320
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
1523
1321
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
1524
1322
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
1525
1323
RemoveUnusedPattern<Torch::AtenTensorOp>,
1526
- RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
1527
- RemoveUnusedPattern<Torch::AtenIntScalarOp>,
1528
1324
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
1529
1325
patterns.getContext ());
1530
1326
}
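For orientation, these populate* entry points are consumed by the scalarize-shapes pass in the usual MLIR fashion; a minimal usage sketch with the standard greedy driver (pass boilerplate assumed, not part of this diff):

    // Inside the pass's runOnOperation(): collect all four pattern sets
    // and drive them to a fixed point with the greedy rewrite driver.
    RewritePatternSet patterns(&getContext());
    populateScalarizationFoldPatterns(patterns);
    populateScalarizationCanonicalizePatterns(patterns);
    populateScalarizationPropagationPatterns(patterns);
    populateScalarizationRemovePatterns(patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();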