Skip to content

Commit 6008cd4

Browse files
[mlir][Transforms] Dialect conversion: Assert when accessing erased ops (#83132)
The dialect conversion maintains sets of "ignored" and "replaced" ops. This change simplifies the two sets, such that all nested ops are included. (This was previously not the case and sometimes only the parent op was included.) This change allows for more aggressive assertions to prevent incorrect rewriter API usage. E.g., accessing ops/blocks/regions within an erased op. A concrete example: I have seen conversion patterns in downstream projects where an op is replaced with a new op, and the region of the old op is afterwards inlined into the newly created op. This is invalid rewriter API usage: ops that were replaced/erased should not be accessed. Nested ops will be considered "ignored", even if they are moved to a different region after the region's parent op was erased (which is illegal API usage). Instead, create a new op, inline the regions, then replace the old op with the new op.
1 parent 26b8be2 commit 6008cd4

File tree

2 files changed

+55
-39
lines changed

2 files changed

+55
-39
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

+55-38
Original file line numberDiff line numberDiff line change
@@ -798,13 +798,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
798798
PatternRewriter &rewriter, ValueRange values,
799799
SmallVectorImpl<Value> &remapped);
800800

801-
/// Returns true if the given operation is ignored, and does not need to be
801+
/// Return "true" if the given operation is ignored, and does not need to be
802802
/// converted.
803803
bool isOpIgnored(Operation *op) const;
804804

805-
/// Recursively marks the nested operations under 'op' as ignored. This
806-
/// removes them from being considered for legalization.
807-
void markNestedOpsIgnored(Operation *op);
805+
/// Return "true" if the given operation was replaced or erased.
806+
bool wasOpReplaced(Operation *op) const;
808807

809808
//===--------------------------------------------------------------------===//
810809
// Type Conversion
@@ -946,18 +945,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
946945
/// Ordered list of block operations (creations, splits, motions).
947946
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
948947

949-
/// A set of operations that should no longer be considered for legalization,
950-
/// but were not directly replace/erased/etc. by a pattern. These are
951-
/// generally child operations of other operations who were
952-
/// replaced/erased/etc. This is not meant to be an exhaustive list of all
953-
/// operations, but the minimal set that can be used to detect if a given
954-
/// operation should be `ignored`. For example, we may add the operations that
955-
/// define non-empty regions to the set, but not any of the others. This
956-
/// simplifies the amount of memory needed as we can query if the parent
957-
/// operation was ignored.
948+
/// A set of operations that should no longer be considered for legalization.
949+
/// E.g., ops that are recursively legal. Ops that were replaced/erased are
950+
/// tracked separately.
958951
SetVector<Operation *> ignoredOps;
959952

960-
// A set of operations that were erased.
953+
/// A set of operations that were replaced/erased. Such ops are not erased
954+
/// immediately but only when the dialect conversion succeeds. In the mean
955+
/// time, they should no longer be considered for legalization and any attempt
956+
/// to modify/access them is invalid rewriter API usage.
961957
SetVector<Operation *> replacedOps;
962958

963959
/// The current type converter, or nullptr if no type converter is currently
@@ -1237,24 +1233,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12371233
return success();
12381234
}
12391235

1240-
// TODO: This function is a misnomer. It does not actually check if `op` is in
1241-
// `ignoredOps`.
12421236
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1243-
// Check to see if this operation or the parent operation is ignored.
1244-
return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
1237+
// Check to see if this operation is ignored or was replaced.
1238+
return replacedOps.count(op) || ignoredOps.count(op);
12451239
}
12461240

1247-
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1248-
// Walk this operation and collect nested operations that define non-empty
1249-
// regions. We mark such operations as 'ignored' so that we know we don't have
1250-
// to convert them, or their nested ops.
1251-
if (op->getNumRegions() == 0)
1252-
return;
1253-
op->walk([&](Operation *op) {
1254-
if (llvm::any_of(op->getRegions(),
1255-
[](Region &region) { return !region.empty(); }))
1256-
ignoredOps.insert(op);
1257-
});
1241+
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
1242+
// Check to see if this operation was replaced.
1243+
return replacedOps.count(op);
12581244
}
12591245

12601246
//===----------------------------------------------------------------------===//
@@ -1476,6 +1462,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14761462
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
14771463
<< ")\n";
14781464
});
1465+
assert(!wasOpReplaced(op->getParentOp()) &&
1466+
"attempting to insert into a block within a replaced/erased op");
1467+
14791468
if (!previous.isSet()) {
14801469
// This is a newly created op.
14811470
appendRewrite<CreateOperationRewrite>(op);
@@ -1490,7 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14901479
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
14911480
ValueRange newValues) {
14921481
assert(newValues.size() == op->getNumResults());
1493-
assert(!replacedOps.contains(op) && "operation was already replaced");
1482+
assert(!ignoredOps.contains(op) && "operation was already replaced");
14941483

14951484
// Track if any of the results changed, e.g. erased and replaced with null.
14961485
bool resultChanged = false;
@@ -1509,10 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
15091498
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
15101499
resultChanged);
15111500

1512-
// Mark this operation as recursively ignored so that we don't need to
1513-
// convert any nested operations.
1514-
replacedOps.insert(op);
1515-
markNestedOpsIgnored(op);
1501+
// Mark this operation and all nested ops as replaced.
1502+
op->walk([&](Operation *op) { replacedOps.insert(op); });
15161503
}
15171504

15181505
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
@@ -1523,6 +1510,9 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
15231510

15241511
void ConversionPatternRewriterImpl::notifyBlockInserted(
15251512
Block *block, Region *previous, Region::iterator previousIt) {
1513+
assert(!wasOpReplaced(block->getParentOp()) &&
1514+
"attempting to insert into a region within a replaced/erased op");
1515+
15261516
if (!previous) {
15271517
// This is a newly created block.
15281518
appendRewrite<CreateBlockRewrite>(block);
@@ -1604,6 +1594,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
16041594
}
16051595

16061596
void ConversionPatternRewriter::eraseBlock(Block *block) {
1597+
assert(!impl->wasOpReplaced(block->getParentOp()) &&
1598+
"attempting to erase a block within a replaced/erased op");
1599+
16071600
// Mark all ops for erasure.
16081601
for (Operation &op : *block)
16091602
eraseOp(&op);
@@ -1619,18 +1612,27 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
16191612
Block *ConversionPatternRewriter::applySignatureConversion(
16201613
Region *region, TypeConverter::SignatureConversion &conversion,
16211614
const TypeConverter *converter) {
1615+
assert(!impl->wasOpReplaced(region->getParentOp()) &&
1616+
"attempting to apply a signature conversion to a block within a "
1617+
"replaced/erased op");
16221618
return impl->applySignatureConversion(region, conversion, converter);
16231619
}
16241620

16251621
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
16261622
Region *region, const TypeConverter &converter,
16271623
TypeConverter::SignatureConversion *entryConversion) {
1624+
assert(!impl->wasOpReplaced(region->getParentOp()) &&
1625+
"attempting to apply a signature conversion to a block within a "
1626+
"replaced/erased op");
16281627
return impl->convertRegionTypes(region, converter, entryConversion);
16291628
}
16301629

16311630
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
16321631
Region *region, const TypeConverter &converter,
16331632
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1633+
assert(!impl->wasOpReplaced(region->getParentOp()) &&
1634+
"attempting to apply a signature conversion to a block within a "
1635+
"replaced/erased op");
16341636
return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
16351637
}
16361638

@@ -1665,6 +1667,8 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
16651667

16661668
Block *ConversionPatternRewriter::splitBlock(Block *block,
16671669
Block::iterator before) {
1670+
assert(!impl->wasOpReplaced(block->getParentOp()) &&
1671+
"attempting to split a block within a replaced/erased op");
16681672
auto *continuation = block->splitBlock(before);
16691673
impl->notifySplitBlock(block, continuation);
16701674
return continuation;
@@ -1673,15 +1677,19 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
16731677
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
16741678
Block::iterator before,
16751679
ValueRange argValues) {
1680+
#ifndef NDEBUG
16761681
assert(argValues.size() == source->getNumArguments() &&
16771682
"incorrect # of argument replacement values");
1678-
#ifndef NDEBUG
1683+
assert(!impl->wasOpReplaced(source->getParentOp()) &&
1684+
"attempting to inline a block from a replaced/erased op");
1685+
assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1686+
"attempting to inline a block into a replaced/erased op");
16791687
auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1680-
#endif // NDEBUG
16811688
// The source block will be deleted, so it should not have any users (i.e.,
16821689
// there should be no predecessors).
16831690
assert(llvm::all_of(source->getUsers(), opIgnored) &&
16841691
"expected 'source' to have no predecessors");
1692+
#endif // NDEBUG
16851693

16861694
impl->notifyBlockBeingInlined(dest, source, before);
16871695
for (auto it : llvm::zip(source->getArguments(), argValues))
@@ -1691,13 +1699,17 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
16911699
}
16921700

16931701
void ConversionPatternRewriter::startOpModification(Operation *op) {
1702+
assert(!impl->wasOpReplaced(op) &&
1703+
"attempting to modify a replaced/erased op");
16941704
#ifndef NDEBUG
16951705
impl->pendingRootUpdates.insert(op);
16961706
#endif
16971707
impl->appendRewrite<ModifyOperationRewrite>(op);
16981708
}
16991709

17001710
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
1711+
assert(!impl->wasOpReplaced(op) &&
1712+
"attempting to modify a replaced/erased op");
17011713
PatternRewriter::finalizeOpModification(op);
17021714
// There is nothing to do here, we only need to track the operation at the
17031715
// start of the update.
@@ -1912,8 +1924,13 @@ OperationLegalizer::legalize(Operation *op,
19121924

19131925
// If this operation is recursively legal, mark its children as ignored so
19141926
// that we don't consider them for legalization.
1915-
if (legalityInfo->isRecursivelyLegal)
1916-
rewriter.getImpl().markNestedOpsIgnored(op);
1927+
if (legalityInfo->isRecursivelyLegal) {
1928+
op->walk([&](Operation *nested) {
1929+
if (op != nested)
1930+
rewriter.getImpl().ignoredOps.insert(nested);
1931+
});
1932+
}
1933+
19171934
return success();
19181935
}
19191936

mlir/test/lib/Dialect/Test/TestPatterns.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -1768,7 +1768,6 @@ struct TestMergeSingleBlockOps
17681768
rewriter.inlineBlockBefore(&innerBlock, op);
17691769
rewriter.eraseOp(innerTerminator);
17701770
rewriter.eraseOp(op);
1771-
rewriter.modifyOpInPlace(op, [] {});
17721771
return success();
17731772
}
17741773
};

0 commit comments

Comments
 (0)