@@ -66,11 +66,12 @@ class SCFWhileLoop {
66
66
SCFWhileLoop (cir::WhileOp op, cir::WhileOp::Adaptor adaptor,
67
67
mlir::ConversionPatternRewriter *rewriter)
68
68
: whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
69
- void transferToSCFWhileOp ();
69
+ mlir::scf::WhileOp transferToSCFWhileOp ();
70
70
71
71
private:
72
72
cir::WhileOp whileOp;
73
73
cir::WhileOp::Adaptor adaptor;
74
+ mlir::scf::WhileOp scfWhileOp;
74
75
mlir::ConversionPatternRewriter *rewriter;
75
76
};
76
77
@@ -356,7 +357,7 @@ void SCFLoop::transformToSCFWhileOp() {
356
357
scfWhileOp.getAfterBody ()->end ());
357
358
}
358
359
359
- void SCFWhileLoop::transferToSCFWhileOp () {
360
+ mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp () {
360
361
auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
361
362
whileOp->getLoc (), whileOp->getResultTypes (), adaptor.getOperands ());
362
363
rewriter->createBlock (&scfWhileOp.getBefore ());
@@ -367,6 +368,7 @@ void SCFWhileLoop::transferToSCFWhileOp() {
367
368
rewriter->inlineBlockBefore (&whileOp.getBody ().front (),
368
369
scfWhileOp.getAfterBody (),
369
370
scfWhileOp.getAfterBody ()->end ());
371
+ return scfWhileOp;
370
372
}
371
373
372
374
void SCFDoLoop::transferToSCFWhileOp () {
@@ -412,14 +414,53 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
412
414
};
413
415
414
416
class CIRWhileOpLowering : public mlir ::OpConversionPattern<cir::WhileOp> {
417
+ void rewriteContinue (mlir::scf::WhileOp whileOp,
418
+ mlir::ConversionPatternRewriter &rewriter) const {
419
+ // Collect all ContinueOp inside this while.
420
+ llvm::SmallVector<cir::ContinueOp> continues;
421
+ whileOp->walk ([&](mlir::Operation *op) {
422
+ if (auto continueOp = dyn_cast<ContinueOp>(op))
423
+ continues.push_back (continueOp);
424
+ });
425
+
426
+ if (continues.empty ())
427
+ return ;
428
+
429
+ for (auto continueOp : continues) {
430
+ // When the break is under an IfOp, a direct replacement of `scf.yield`
431
+ // won't work: the yield would jump out of that IfOp instead. We might
432
+ // need to change the whileOp itself to achieve the same effect.
433
+ for (mlir::Operation *parent = continueOp->getParentOp ();
434
+ parent != whileOp; parent = parent->getParentOp ()) {
435
+ if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
436
+ llvm_unreachable (" NYI" );
437
+ }
438
+
439
+ // Operations after this break has to be removed.
440
+ for (mlir::Operation *runner = continueOp->getNextNode (); runner;) {
441
+ mlir::Operation *next = runner->getNextNode ();
442
+ runner->erase ();
443
+ runner = next;
444
+ }
445
+
446
+ // Blocks after this break also has to be removed.
447
+ for (mlir::Block *block = continueOp->getBlock ()->getNextNode (); block;) {
448
+ mlir::Block *next = block->getNextNode ();
449
+ block->erase ();
450
+ block = next;
451
+ }
452
+ }
453
+ }
454
+
415
455
public:
416
456
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
417
457
418
458
mlir::LogicalResult
419
459
matchAndRewrite (cir::WhileOp op, OpAdaptor adaptor,
420
460
mlir::ConversionPatternRewriter &rewriter) const override {
421
461
SCFWhileLoop loop (op, adaptor, &rewriter);
422
- loop.transferToSCFWhileOp ();
462
+ auto whileOp = loop.transferToSCFWhileOp ();
463
+ rewriteContinue (whileOp, rewriter);
423
464
rewriter.eraseOp (op);
424
465
return mlir::success ();
425
466
}
0 commit comments