Skip to content

[mlir][Transforms] Add 1:N support to replaceUsesOfBlockArgument #145171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/matthias-springer/simplify_replace_op
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

This commit adds 1:N support to ConversionPatternRewriter::replaceUsesOfBlockArgument. This was one of the few remaining dialect conversion APIs that does not support 1:N conversions yet.

This commit also reuses replaceUsesOfBlockArgument in the implementation of applySignatureConversion. This is in preparation of the One-Shot Dialect Conversion refactoring. The goal is to bring the applySignatureConversion implementation into a state where it works both with and without rollbacks. To that end, applySignatureConversion should not directly access the mapping.

Depends on #145155.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jun 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 21, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit adds 1:N support to ConversionPatternRewriter::replaceUsesOfBlockArgument. This was one of the few remaining dialect conversion APIs that does not support 1:N conversions yet.

This commit also reuses replaceUsesOfBlockArgument in the implementation of applySignatureConversion. This is in preparation of the One-Shot Dialect Conversion refactoring. The goal is to bring the applySignatureConversion implementation into a state where it works both with and without rollbacks. To that end, applySignatureConversion should not directly access the mapping.

Depends on #145155.


Full diff: https://github.com/llvm/llvm-project/pull/145171.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+3-2)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+25-15)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+24-7)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+29-22)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5a5f116073a9a..81858812d2623 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -763,8 +763,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Replace all the uses of the block argument `from` with value `to`.
-  void replaceUsesOfBlockArgument(BlockArgument from, Value to);
+  /// Replace all the uses of the block argument `from` with `to`. This
+  /// function supports both 1:1 and 1:N replacements.
+  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
 
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 538016927256b..9e8e746507557 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -294,7 +294,7 @@ static void restoreByValRefArgumentType(
     Type resTy = typeConverter.convertType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
-    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+    Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
     rewriter.replaceUsesOfBlockArgument(arg, valueArg);
   }
 }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 774d58973eb91..9cb6f2ba1eaae 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// uses.
   void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
 
+  /// Replace the given block argument with the given values. The specified
+  /// converter is used to build materializations (if necessary).
+  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
+                                  const TypeConverter *converter);
+
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
 
@@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (!inputMap) {
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
-      buildUnresolvedMaterialization(
-          MaterializationKind::Source,
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
-          /*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+      Value mat =
+          buildUnresolvedMaterialization(
+              MaterializationKind::Source,
+              OpBuilder::InsertPoint(newBlock, newBlock->begin()),
+              origArg.getLoc(),
+              /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+              .front();
+      replaceUsesOfBlockArgument(origArg, mat, converter);
       continue;
     }
 
@@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, inputMap->replacementValues);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+      replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
+                                 converter);
       continue;
     }
 
     // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
-    mapping.map(origArg, std::move(replArgVals));
-    appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+    replaceUsesOfBlockArgument(origArg, replArgs, converter);
   }
 
   appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
@@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp(
   op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
+void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
+    BlockArgument from, ValueRange to, const TypeConverter *converter) {
+  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
+  mapping.map(from, to);
+}
+
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   assert(!wasOpReplaced(block->getParentOp()) &&
          "attempting to erase a block within a replaced/erased op");
@@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           Value to) {
+                                                           ValueRange to) {
   LLVM_DEBUG({
     impl->logger.startLine() << "** Replace Argument : '" << from << "'";
     if (Operation *parentOp = from.getOwner()->getParentOp()) {
@@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
       impl->logger.getOStream() << " (unlinked block)\n";
     }
   });
-  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
-                                              impl->currentTypeConverter);
-  impl->mapping.map(from, to);
+  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 204c8c1456826..79518b04e7158 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -300,18 +300,35 @@ func.func @create_illegal_block() {
 // -----
 
 // CHECK-LABEL: @undo_block_arg_replace
+// expected-remark@+1{{applyPartialConversion failed}}
+module {
 func.func @undo_block_arg_replace() {
-  // expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}}
-  "test.undo_block_arg_replace"() ({
-  ^bb0(%arg0: i32):
-    // CHECK: ^bb0(%[[ARG:.*]]: i32):
-    // CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
+  // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
+  "test.block_arg_replace"() ({
+  ^bb0(%arg0: i32, %arg1: i16):
+    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
 
     "test.return"(%arg0) : (i32) -> ()
-  }) : () -> ()
-  // expected-remark@+1 {{op 'func.return' is not legalizable}}
+  }) {trigger_rollback} : () -> ()
   return
 }
+}
+
+// -----
+
+// CHECK-LABEL: @replace_block_arg_1_to_n
+func.func @replace_block_arg_1_to_n() {
+  // CHECK: "test.block_arg_replace"
+  "test.block_arg_replace"() ({
+  ^bb0(%arg0: i32, %arg1: i16):
+    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+    "test.return"(%arg0) : (i32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
 
 // -----
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..588e529665dd1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
   }
 };
 
-/// A simple pattern that tests the undo mechanism when replacing the uses of a
-/// block argument.
-struct TestUndoBlockArgReplace : public ConversionPattern {
-  TestUndoBlockArgReplace(MLIRContext *ctx)
-      : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
+/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
+struct TestBlockArgReplace : public ConversionPattern {
+  TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
+                          ctx) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto illegalOp =
-        rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+    // Replace the first block argument with 2x the second block argument.
+    Value repl = op->getRegion(0).getArgument(1);
     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
-                                        illegalOp->getResult(0));
-    rewriter.modifyOpInPlace(op, [] {});
+                                        {repl, repl});
+    rewriter.modifyOpInPlace(op, [&] {
+      // If the "trigger_rollback" attribute is set, keep the op illegal, so
+      // that a rollback is triggered.
+      if (!op->hasAttr("trigger_rollback"))
+        op->setAttr("is_legal", rewriter.getUnitAttr());
+    });
     return success();
   }
 };
@@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns
-        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
-             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
-             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
-             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-             TestNonRootReplacement, TestBoundedRecursiveRewrite,
-             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-             TestUndoPropertiesModification, TestEraseOp,
-             TestRepetitive1ToNConsumer>(&getContext());
+    patterns.add<
+        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+        TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
+        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
+        TestUpdateConsumerType, TestNonRootReplacement,
+        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
+        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+        TestUndoPropertiesModification, TestEraseOp,
+        TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
-                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
-        &getContext(), converter);
+                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
+                 TestBlockArgReplace>(&getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
     });
     target.addDynamicallyLegalOp<func::CallOp>(
         [&](func::CallOp op) { return converter.isLegal(op); });
+    target.addDynamicallyLegalOp(
+        OperationName("test.block_arg_replace", &getContext()),
+        [](Operation *op) { return op->hasAttr("is_legal"); });
 
     // TestCreateUnregisteredOp creates `arith.constant` operation,
     // which was not added to target intentionally to test

@llvmbot
Copy link
Member

llvmbot commented Jun 21, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds 1:N support to ConversionPatternRewriter::replaceUsesOfBlockArgument. This was one of the few remaining dialect conversion APIs that does not support 1:N conversions yet.

This commit also reuses replaceUsesOfBlockArgument in the implementation of applySignatureConversion. This is in preparation of the One-Shot Dialect Conversion refactoring. The goal is to bring the applySignatureConversion implementation into a state where it works both with and without rollbacks. To that end, applySignatureConversion should not directly access the mapping.

Depends on #145155.


Full diff: https://github.com/llvm/llvm-project/pull/145171.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+3-2)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+25-15)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+24-7)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+29-22)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5a5f116073a9a..81858812d2623 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -763,8 +763,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Replace all the uses of the block argument `from` with value `to`.
-  void replaceUsesOfBlockArgument(BlockArgument from, Value to);
+  /// Replace all the uses of the block argument `from` with `to`. This
+  /// function supports both 1:1 and 1:N replacements.
+  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
 
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 538016927256b..9e8e746507557 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -294,7 +294,7 @@ static void restoreByValRefArgumentType(
     Type resTy = typeConverter.convertType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
-    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+    Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
     rewriter.replaceUsesOfBlockArgument(arg, valueArg);
   }
 }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 774d58973eb91..9cb6f2ba1eaae 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// uses.
   void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
 
+  /// Replace the given block argument with the given values. The specified
+  /// converter is used to build materializations (if necessary).
+  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
+                                  const TypeConverter *converter);
+
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
 
@@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (!inputMap) {
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
-      buildUnresolvedMaterialization(
-          MaterializationKind::Source,
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
-          /*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+      Value mat =
+          buildUnresolvedMaterialization(
+              MaterializationKind::Source,
+              OpBuilder::InsertPoint(newBlock, newBlock->begin()),
+              origArg.getLoc(),
+              /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+              .front();
+      replaceUsesOfBlockArgument(origArg, mat, converter);
       continue;
     }
 
@@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, inputMap->replacementValues);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+      replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
+                                 converter);
       continue;
     }
 
     // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
-    mapping.map(origArg, std::move(replArgVals));
-    appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+    replaceUsesOfBlockArgument(origArg, replArgs, converter);
   }
 
   appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
@@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp(
   op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
+void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
+    BlockArgument from, ValueRange to, const TypeConverter *converter) {
+  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
+  mapping.map(from, to);
+}
+
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   assert(!wasOpReplaced(block->getParentOp()) &&
          "attempting to erase a block within a replaced/erased op");
@@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           Value to) {
+                                                           ValueRange to) {
   LLVM_DEBUG({
     impl->logger.startLine() << "** Replace Argument : '" << from << "'";
     if (Operation *parentOp = from.getOwner()->getParentOp()) {
@@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
       impl->logger.getOStream() << " (unlinked block)\n";
     }
   });
-  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
-                                              impl->currentTypeConverter);
-  impl->mapping.map(from, to);
+  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 204c8c1456826..79518b04e7158 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -300,18 +300,35 @@ func.func @create_illegal_block() {
 // -----
 
 // CHECK-LABEL: @undo_block_arg_replace
+// expected-remark@+1{{applyPartialConversion failed}}
+module {
 func.func @undo_block_arg_replace() {
-  // expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}}
-  "test.undo_block_arg_replace"() ({
-  ^bb0(%arg0: i32):
-    // CHECK: ^bb0(%[[ARG:.*]]: i32):
-    // CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
+  // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
+  "test.block_arg_replace"() ({
+  ^bb0(%arg0: i32, %arg1: i16):
+    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
 
     "test.return"(%arg0) : (i32) -> ()
-  }) : () -> ()
-  // expected-remark@+1 {{op 'func.return' is not legalizable}}
+  }) {trigger_rollback} : () -> ()
   return
 }
+}
+
+// -----
+
+// CHECK-LABEL: @replace_block_arg_1_to_n
+func.func @replace_block_arg_1_to_n() {
+  // CHECK: "test.block_arg_replace"
+  "test.block_arg_replace"() ({
+  ^bb0(%arg0: i32, %arg1: i16):
+    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+    "test.return"(%arg0) : (i32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
 
 // -----
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..588e529665dd1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
   }
 };
 
-/// A simple pattern that tests the undo mechanism when replacing the uses of a
-/// block argument.
-struct TestUndoBlockArgReplace : public ConversionPattern {
-  TestUndoBlockArgReplace(MLIRContext *ctx)
-      : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
+/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
+struct TestBlockArgReplace : public ConversionPattern {
+  TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
+                          ctx) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto illegalOp =
-        rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+    // Replace the first block argument with 2x the second block argument.
+    Value repl = op->getRegion(0).getArgument(1);
     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
-                                        illegalOp->getResult(0));
-    rewriter.modifyOpInPlace(op, [] {});
+                                        {repl, repl});
+    rewriter.modifyOpInPlace(op, [&] {
+      // If the "trigger_rollback" attribute is set, keep the op illegal, so
+      // that a rollback is triggered.
+      if (!op->hasAttr("trigger_rollback"))
+        op->setAttr("is_legal", rewriter.getUnitAttr());
+    });
     return success();
   }
 };
@@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns
-        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
-             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
-             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
-             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-             TestNonRootReplacement, TestBoundedRecursiveRewrite,
-             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-             TestUndoPropertiesModification, TestEraseOp,
-             TestRepetitive1ToNConsumer>(&getContext());
+    patterns.add<
+        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+        TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
+        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
+        TestUpdateConsumerType, TestNonRootReplacement,
+        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
+        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+        TestUndoPropertiesModification, TestEraseOp,
+        TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
-                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
-        &getContext(), converter);
+                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
+                 TestBlockArgReplace>(&getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
     });
     target.addDynamicallyLegalOp<func::CallOp>(
         [&](func::CallOp op) { return converter.isLegal(op); });
+    target.addDynamicallyLegalOp(
+        OperationName("test.block_arg_replace", &getContext()),
+        [](Operation *op) { return op->hasAttr("is_legal"); });
 
     // TestCreateUnregisteredOp creates `arith.constant` operation,
     // which was not added to target intentionally to test

OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was it a bug here that this step did not update mapping? (and in this PR now it does via replaceUsesOfBlockArgument)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants