@@ -177,9 +177,59 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
177177 return success ();
178178}
179179
180+ // Lower scf::index_switch to emitc::switch, implementing result values as
181+ // emitc::variable's updated within the case and default regions.
182+ struct IndexSwitchOpLowering : public OpRewritePattern <IndexSwitchOp> {
183+ using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
184+
185+ LogicalResult matchAndRewrite (IndexSwitchOp indexSwitchOp,
186+ PatternRewriter &rewriter) const override ;
187+ };
188+
189+ LogicalResult
190+ IndexSwitchOpLowering::matchAndRewrite (IndexSwitchOp indexSwitchOp,
191+ PatternRewriter &rewriter) const {
192+ Location loc = indexSwitchOp.getLoc ();
193+
194+ // Create an emitc::variable op for each result. These variables will be
195+ // assigned to by emitc::assign ops within the case and default regions.
196+ SmallVector<Value> resultVariables =
197+ createVariablesForResults (indexSwitchOp, rewriter);
198+
199+ // Utility function to lower the contents of an scf::index_switch regions to
200+ // an emitc::switch regions. The contents of the scf::index_switch regions is
201+ // moved into the respective emitc::switch regions, but the scf::yield is
202+ // replaced not only with an emitc::yield, but also with a sequence of
203+ // emitc::assign ops that set the yielded values into the result variables.
204+ auto lowerRegion = [&resultVariables, &rewriter](Region ®ion,
205+ Region &loweredRegion) {
206+ rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
207+ Operation *terminator = loweredRegion.back ().getTerminator ();
208+ lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
209+ };
210+
211+ auto loweredSwitch = rewriter.create <emitc::SwitchOp>(
212+ loc, indexSwitchOp.getArg (), indexSwitchOp.getCases (),
213+ indexSwitchOp.getNumCases ());
214+
215+ // Lowering all case regions.
216+ for (auto pair : llvm::zip (indexSwitchOp.getCaseRegions (),
217+ loweredSwitch.getCaseRegions ())) {
218+ lowerRegion (std::get<0 >(pair), std::get<1 >(pair));
219+ }
220+
221+ // Lowering default region.
222+ lowerRegion (indexSwitchOp.getDefaultRegion (),
223+ loweredSwitch.getDefaultRegion ());
224+
225+ rewriter.replaceOp (indexSwitchOp, resultVariables);
226+ return success ();
227+ }
228+
180229void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns) {
181230 patterns.add <ForLowering>(patterns.getContext ());
182231 patterns.add <IfLowering>(patterns.getContext ());
232+ patterns.add <IndexSwitchOpLowering>(patterns.getContext ());
183233}
184234
185235void SCFToEmitCPass::runOnOperation () {
@@ -188,7 +238,7 @@ void SCFToEmitCPass::runOnOperation() {
188238
189239 // Configure conversion to lower out SCF operations.
190240 ConversionTarget target (getContext ());
191- target.addIllegalOp <scf::ForOp, scf::IfOp>();
241+ target.addIllegalOp <scf::ForOp, scf::IfOp, scf::IndexSwitchOp >();
192242 target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
193243 if (failed (
194244 applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments