@@ -798,13 +798,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
798
798
PatternRewriter &rewriter, ValueRange values,
799
799
SmallVectorImpl<Value> &remapped);
800
800
801
- // / Returns true if the given operation is ignored, and does not need to be
801
+ // / Return " true" if the given operation is ignored, and does not need to be
802
802
// / converted.
803
803
bool isOpIgnored (Operation *op) const ;
804
804
805
- // / Recursively marks the nested operations under 'op' as ignored. This
806
- // / removes them from being considered for legalization.
807
- void markNestedOpsIgnored (Operation *op);
805
+ // / Return "true" if the given operation was replaced or erased.
806
+ bool wasOpReplaced (Operation *op) const ;
808
807
809
808
// ===--------------------------------------------------------------------===//
810
809
// Type Conversion
@@ -946,18 +945,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
946
945
// / Ordered list of block operations (creations, splits, motions).
947
946
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
948
947
949
- // / A set of operations that should no longer be considered for legalization,
950
- // / but were not directly replace/erased/etc. by a pattern. These are
951
- // / generally child operations of other operations who were
952
- // / replaced/erased/etc. This is not meant to be an exhaustive list of all
953
- // / operations, but the minimal set that can be used to detect if a given
954
- // / operation should be `ignored`. For example, we may add the operations that
955
- // / define non-empty regions to the set, but not any of the others. This
956
- // / simplifies the amount of memory needed as we can query if the parent
957
- // / operation was ignored.
948
+ // / A set of operations that should no longer be considered for legalization.
949
+ // / E.g., ops that are recursively legal. Ops that were replaced/erased are
950
+ // / tracked separately.
958
951
SetVector<Operation *> ignoredOps;
959
952
960
- // A set of operations that were erased.
953
+ // / A set of operations that were replaced/erased. Such ops are not erased
954
+ // / immediately but only when the dialect conversion succeeds. In the mean
955
+ // / time, they should no longer be considered for legalization and any attempt
956
+ // / to modify/access them is invalid rewriter API usage.
961
957
SetVector<Operation *> replacedOps;
962
958
963
959
// / The current type converter, or nullptr if no type converter is currently
@@ -1237,24 +1233,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
1237
1233
return success ();
1238
1234
}
1239
1235
1240
- // TODO: This function is a misnomer. It does not actually check if `op` is in
1241
- // `ignoredOps`.
1242
1236
bool ConversionPatternRewriterImpl::isOpIgnored (Operation *op) const {
1243
- // Check to see if this operation or the parent operation is ignored .
1244
- return ignoredOps .count (op-> getParentOp ()) || replacedOps .count (op);
1237
+ // Check to see if this operation is ignored or was replaced .
1238
+ return replacedOps .count (op) || ignoredOps .count (op);
1245
1239
}
1246
1240
1247
- void ConversionPatternRewriterImpl::markNestedOpsIgnored (Operation *op) {
1248
- // Walk this operation and collect nested operations that define non-empty
1249
- // regions. We mark such operations as 'ignored' so that we know we don't have
1250
- // to convert them, or their nested ops.
1251
- if (op->getNumRegions () == 0 )
1252
- return ;
1253
- op->walk ([&](Operation *op) {
1254
- if (llvm::any_of (op->getRegions (),
1255
- [](Region ®ion) { return !region.empty (); }))
1256
- ignoredOps.insert (op);
1257
- });
1241
+ bool ConversionPatternRewriterImpl::wasOpReplaced (Operation *op) const {
1242
+ // Check to see if this operation was replaced.
1243
+ return replacedOps.count (op);
1258
1244
}
1259
1245
1260
1246
// ===----------------------------------------------------------------------===//
@@ -1476,6 +1462,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
1476
1462
logger.startLine () << " ** Insert : '" << op->getName () << " '(" << op
1477
1463
<< " )\n " ;
1478
1464
});
1465
+ assert (!wasOpReplaced (op->getParentOp ()) &&
1466
+ " attempting to insert into a block within a replaced/erased op" );
1467
+
1479
1468
if (!previous.isSet ()) {
1480
1469
// This is a newly created op.
1481
1470
appendRewrite<CreateOperationRewrite>(op);
@@ -1490,7 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
1490
1479
void ConversionPatternRewriterImpl::notifyOpReplaced (Operation *op,
1491
1480
ValueRange newValues) {
1492
1481
assert (newValues.size () == op->getNumResults ());
1493
- assert (!replacedOps .contains (op) && " operation was already replaced" );
1482
+ assert (!ignoredOps .contains (op) && " operation was already replaced" );
1494
1483
1495
1484
// Track if any of the results changed, e.g. erased and replaced with null.
1496
1485
bool resultChanged = false ;
@@ -1509,10 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1509
1498
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
1510
1499
resultChanged);
1511
1500
1512
- // Mark this operation as recursively ignored so that we don't need to
1513
- // convert any nested operations.
1514
- replacedOps.insert (op);
1515
- markNestedOpsIgnored (op);
1501
+ // Mark this operation and all nested ops as replaced.
1502
+ op->walk ([&](Operation *op) { replacedOps.insert (op); });
1516
1503
}
1517
1504
1518
1505
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased (Block *block) {
@@ -1523,6 +1510,9 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1523
1510
1524
1511
void ConversionPatternRewriterImpl::notifyBlockInserted (
1525
1512
Block *block, Region *previous, Region::iterator previousIt) {
1513
+ assert (!wasOpReplaced (block->getParentOp ()) &&
1514
+ " attempting to insert into a region within a replaced/erased op" );
1515
+
1526
1516
if (!previous) {
1527
1517
// This is a newly created block.
1528
1518
appendRewrite<CreateBlockRewrite>(block);
@@ -1604,6 +1594,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
1604
1594
}
1605
1595
1606
1596
void ConversionPatternRewriter::eraseBlock (Block *block) {
1597
+ assert (!impl->wasOpReplaced (block->getParentOp ()) &&
1598
+ " attempting to erase a block within a replaced/erased op" );
1599
+
1607
1600
// Mark all ops for erasure.
1608
1601
for (Operation &op : *block)
1609
1602
eraseOp (&op);
@@ -1619,18 +1612,27 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
1619
1612
Block *ConversionPatternRewriter::applySignatureConversion (
1620
1613
Region *region, TypeConverter::SignatureConversion &conversion,
1621
1614
const TypeConverter *converter) {
1615
+ assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1616
+ " attempting to apply a signature conversion to a block within a "
1617
+ " replaced/erased op" );
1622
1618
return impl->applySignatureConversion (region, conversion, converter);
1623
1619
}
1624
1620
1625
1621
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes (
1626
1622
Region *region, const TypeConverter &converter,
1627
1623
TypeConverter::SignatureConversion *entryConversion) {
1624
+ assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1625
+ " attempting to apply a signature conversion to a block within a "
1626
+ " replaced/erased op" );
1628
1627
return impl->convertRegionTypes (region, converter, entryConversion);
1629
1628
}
1630
1629
1631
1630
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes (
1632
1631
Region *region, const TypeConverter &converter,
1633
1632
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1633
+ assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1634
+ " attempting to apply a signature conversion to a block within a "
1635
+ " replaced/erased op" );
1634
1636
return impl->convertNonEntryRegionTypes (region, converter, blockConversions);
1635
1637
}
1636
1638
@@ -1665,6 +1667,8 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
1665
1667
1666
1668
Block *ConversionPatternRewriter::splitBlock (Block *block,
1667
1669
Block::iterator before) {
1670
+ assert (!impl->wasOpReplaced (block->getParentOp ()) &&
1671
+ " attempting to split a block within a replaced/erased op" );
1668
1672
auto *continuation = block->splitBlock (before);
1669
1673
impl->notifySplitBlock (block, continuation);
1670
1674
return continuation;
@@ -1673,15 +1677,19 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
1673
1677
void ConversionPatternRewriter::inlineBlockBefore (Block *source, Block *dest,
1674
1678
Block::iterator before,
1675
1679
ValueRange argValues) {
1680
+ #ifndef NDEBUG
1676
1681
assert (argValues.size () == source->getNumArguments () &&
1677
1682
" incorrect # of argument replacement values" );
1678
- #ifndef NDEBUG
1683
+ assert (!impl->wasOpReplaced (source->getParentOp ()) &&
1684
+ " attempting to inline a block from a replaced/erased op" );
1685
+ assert (!impl->wasOpReplaced (dest->getParentOp ()) &&
1686
+ " attempting to inline a block into a replaced/erased op" );
1679
1687
auto opIgnored = [&](Operation *op) { return impl->isOpIgnored (op); };
1680
- #endif // NDEBUG
1681
1688
// The source block will be deleted, so it should not have any users (i.e.,
1682
1689
// there should be no predecessors).
1683
1690
assert (llvm::all_of (source->getUsers (), opIgnored) &&
1684
1691
" expected 'source' to have no predecessors" );
1692
+ #endif // NDEBUG
1685
1693
1686
1694
impl->notifyBlockBeingInlined (dest, source, before);
1687
1695
for (auto it : llvm::zip (source->getArguments (), argValues))
@@ -1691,13 +1699,17 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
1691
1699
}
1692
1700
1693
1701
void ConversionPatternRewriter::startOpModification (Operation *op) {
1702
+ assert (!impl->wasOpReplaced (op) &&
1703
+ " attempting to modify a replaced/erased op" );
1694
1704
#ifndef NDEBUG
1695
1705
impl->pendingRootUpdates .insert (op);
1696
1706
#endif
1697
1707
impl->appendRewrite <ModifyOperationRewrite>(op);
1698
1708
}
1699
1709
1700
1710
void ConversionPatternRewriter::finalizeOpModification (Operation *op) {
1711
+ assert (!impl->wasOpReplaced (op) &&
1712
+ " attempting to modify a replaced/erased op" );
1701
1713
PatternRewriter::finalizeOpModification (op);
1702
1714
// There is nothing to do here, we only need to track the operation at the
1703
1715
// start of the update.
@@ -1912,8 +1924,13 @@ OperationLegalizer::legalize(Operation *op,
1912
1924
1913
1925
// If this operation is recursively legal, mark its children as ignored so
1914
1926
// that we don't consider them for legalization.
1915
- if (legalityInfo->isRecursivelyLegal )
1916
- rewriter.getImpl ().markNestedOpsIgnored (op);
1927
+ if (legalityInfo->isRecursivelyLegal ) {
1928
+ op->walk ([&](Operation *nested) {
1929
+ if (op != nested)
1930
+ rewriter.getImpl ().ignoredOps .insert (nested);
1931
+ });
1932
+ }
1933
+
1917
1934
return success ();
1918
1935
}
1919
1936
0 commit comments