Skip to content

Commit 51e17c0

Browse files
[CIR][ThroughMLIR] Handle ContinueOp directly under a WhileOp (#1669)
Currently we can't handle continues nested under `IfOp`, because if we replace it with a yield, then it only breaks out of that `if`-statement, rather than continuing the whole loop. Perhaps that should be done by changing the whole structure of the while loop. Co-authored-by: Yue Huang <[email protected]>
1 parent bc6df70 commit 51e17c0

File tree

2 files changed

+71
-3
lines changed

2 files changed

+71
-3
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,12 @@ class SCFWhileLoop {
6666
SCFWhileLoop(cir::WhileOp op, cir::WhileOp::Adaptor adaptor,
6767
mlir::ConversionPatternRewriter *rewriter)
6868
: whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
69-
void transferToSCFWhileOp();
69+
mlir::scf::WhileOp transferToSCFWhileOp();
7070

7171
private:
7272
cir::WhileOp whileOp;
7373
cir::WhileOp::Adaptor adaptor;
74+
mlir::scf::WhileOp scfWhileOp;
7475
mlir::ConversionPatternRewriter *rewriter;
7576
};
7677

@@ -356,7 +357,7 @@ void SCFLoop::transformToSCFWhileOp() {
356357
scfWhileOp.getAfterBody()->end());
357358
}
358359

359-
void SCFWhileLoop::transferToSCFWhileOp() {
360+
mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp() {
360361
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
361362
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
362363
rewriter->createBlock(&scfWhileOp.getBefore());
@@ -367,6 +368,7 @@ void SCFWhileLoop::transferToSCFWhileOp() {
367368
rewriter->inlineBlockBefore(&whileOp.getBody().front(),
368369
scfWhileOp.getAfterBody(),
369370
scfWhileOp.getAfterBody()->end());
371+
return scfWhileOp;
370372
}
371373

372374
void SCFDoLoop::transferToSCFWhileOp() {
@@ -412,14 +414,53 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
412414
};
413415

414416
class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
417+
void rewriteContinue(mlir::scf::WhileOp whileOp,
418+
mlir::ConversionPatternRewriter &rewriter) const {
419+
// Collect all ContinueOp inside this while.
420+
llvm::SmallVector<cir::ContinueOp> continues;
421+
whileOp->walk([&](mlir::Operation *op) {
422+
if (auto continueOp = dyn_cast<ContinueOp>(op))
423+
continues.push_back(continueOp);
424+
});
425+
426+
if (continues.empty())
427+
return;
428+
429+
for (auto continueOp : continues) {
430+
// When the break is under an IfOp, a direct replacement of `scf.yield`
431+
// won't work: the yield would jump out of that IfOp instead. We might
432+
// need to change the whileOp itself to achieve the same effect.
433+
for (mlir::Operation *parent = continueOp->getParentOp();
434+
parent != whileOp; parent = parent->getParentOp()) {
435+
if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
436+
llvm_unreachable("NYI");
437+
}
438+
439+
// Operations after this break has to be removed.
440+
for (mlir::Operation *runner = continueOp->getNextNode(); runner;) {
441+
mlir::Operation *next = runner->getNextNode();
442+
runner->erase();
443+
runner = next;
444+
}
445+
446+
// Blocks after this break also has to be removed.
447+
for (mlir::Block *block = continueOp->getBlock()->getNextNode(); block;) {
448+
mlir::Block *next = block->getNextNode();
449+
block->erase();
450+
block = next;
451+
}
452+
}
453+
}
454+
415455
public:
416456
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
417457

418458
mlir::LogicalResult
419459
matchAndRewrite(cir::WhileOp op, OpAdaptor adaptor,
420460
mlir::ConversionPatternRewriter &rewriter) const override {
421461
SCFWhileLoop loop(op, adaptor, &rewriter);
422-
loop.transferToSCFWhileOp();
462+
auto whileOp = loop.transferToSCFWhileOp();
463+
rewriteContinue(whileOp, rewriter);
423464
rewriter.eraseOp(op);
424465
return mlir::success();
425466
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void for_with_break() {
5+
int i = 0;
6+
while (i < 100) {
7+
i++;
8+
continue;
9+
i++;
10+
}
11+
// Only the first `i++` will be emitted.
12+
13+
// CHECK: scf.while : () -> () {
14+
// CHECK: %[[TMP0:.+]] = memref.load %alloca[]
15+
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
16+
// CHECK: %[[TMP1:.+]] = arith.cmpi slt, %[[TMP0]], %[[HUNDRED]]
17+
// CHECK: scf.condition(%[[TMP1]])
18+
// CHECK: } do {
19+
// CHECK: memref.alloca_scope {
20+
// CHECK: %[[TMP2:.+]] = memref.load %alloca[]
21+
// CHECK: %[[ONE:.+]] = arith.constant 1
22+
// CHECK: %[[TMP3:.+]] = arith.addi %[[TMP2]], %[[ONE]]
23+
// CHECK: memref.store %[[TMP3]], %alloca[]
24+
// CHECK: }
25+
// CHECK: scf.yield
26+
// CHECK: }
27+
}

0 commit comments

Comments
 (0)