Skip to content

Commit e7d1c50

Browse files
committed
[mlir][emitc] Add conversion from 'scf::index_switch' to 'emitc::switch'
1 parent 07f8db0 commit e7d1c50

File tree

2 files changed

+172
-1
lines changed

2 files changed

+172
-1
lines changed

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,59 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
177177
return success();
178178
}
179179

180+
// Lower scf::index_switch to emitc::switch, implementing result values as
181+
// emitc::variable's updated within the case and default regions.
182+
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
183+
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
184+
185+
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
186+
PatternRewriter &rewriter) const override;
187+
};
188+
189+
LogicalResult
190+
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
191+
PatternRewriter &rewriter) const {
192+
Location loc = indexSwitchOp.getLoc();
193+
194+
// Create an emitc::variable op for each result. These variables will be
195+
// assigned to by emitc::assign ops within the case and default regions.
196+
SmallVector<Value> resultVariables =
197+
createVariablesForResults(indexSwitchOp, rewriter);
198+
199+
// Utility function to lower the contents of an scf::index_switch regions to
200+
// an emitc::switch regions. The contents of the scf::index_switch regions is
201+
// moved into the respective emitc::switch regions, but the scf::yield is
202+
// replaced not only with an emitc::yield, but also with a sequence of
203+
// emitc::assign ops that set the yielded values into the result variables.
204+
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
205+
Region &loweredRegion) {
206+
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
207+
Operation *terminator = loweredRegion.back().getTerminator();
208+
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
209+
};
210+
211+
auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
212+
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
213+
indexSwitchOp.getNumCases());
214+
215+
// Lowering all case regions.
216+
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
217+
loweredSwitch.getCaseRegions())) {
218+
lowerRegion(std::get<0>(pair), std::get<1>(pair));
219+
}
220+
221+
// Lowering default region.
222+
lowerRegion(indexSwitchOp.getDefaultRegion(),
223+
loweredSwitch.getDefaultRegion());
224+
225+
rewriter.replaceOp(indexSwitchOp, resultVariables);
226+
return success();
227+
}
228+
180229
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
181230
patterns.add<ForLowering>(patterns.getContext());
182231
patterns.add<IfLowering>(patterns.getContext());
232+
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
183233
}
184234

185235
void SCFToEmitCPass::runOnOperation() {
@@ -188,7 +238,7 @@ void SCFToEmitCPass::runOnOperation() {
188238

189239
// Configure conversion to lower out SCF operations.
190240
ConversionTarget target(getContext());
191-
target.addIllegalOp<scf::ForOp, scf::IfOp>();
241+
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
192242
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
193243
if (failed(
194244
applyPartialConversion(getOperation(), target, std::move(patterns))))
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @switch_no_result(
4+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
5+
// CHECK: emitc.switch %[[VAL_0]]
6+
// CHECK: case 2: {
7+
// CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32
8+
// CHECK: emitc.yield
9+
// CHECK: }
10+
// CHECK: case 5: {
11+
// CHECK: %[[VAL_2:.*]] = arith.constant 20 : i32
12+
// CHECK: emitc.yield
13+
// CHECK: }
14+
// CHECK: default {
15+
// CHECK: %[[VAL_3:.*]] = arith.constant 30 : i32
16+
// CHECK: emitc.yield
17+
// CHECK: }
18+
// CHECK: return
19+
// CHECK: }
20+
func.func @switch_no_result(%arg0 : index) {
21+
scf.index_switch %arg0
22+
case 2 {
23+
%1 = arith.constant 10 : i32
24+
scf.yield
25+
}
26+
case 5 {
27+
%2 = arith.constant 20 : i32
28+
scf.yield
29+
}
30+
default {
31+
%3 = arith.constant 30 : i32
32+
scf.yield
33+
}
34+
return
35+
}
36+
37+
// CHECK-LABEL: func.func @switch_one_result(
38+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
39+
// CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #[[?]]<"">}> : () -> i32
40+
// CHECK: emitc.switch %[[VAL_0]]
41+
// CHECK: case 2: {
42+
// CHECK: %[[VAL_2:.*]] = arith.constant 10 : i32
43+
// CHECK: emitc.assign %[[VAL_2]] : i32 to %[[VAL_1]] : i32
44+
// CHECK: emitc.yield
45+
// CHECK: }
46+
// CHECK: case 5: {
47+
// CHECK: %[[VAL_3:.*]] = arith.constant 20 : i32
48+
// CHECK: emitc.assign %[[VAL_3]] : i32 to %[[VAL_1]] : i32
49+
// CHECK: emitc.yield
50+
// CHECK: }
51+
// CHECK: default {
52+
// CHECK: %[[VAL_4:.*]] = arith.constant 30 : i32
53+
// CHECK: emitc.assign %[[VAL_4]] : i32 to %[[VAL_1]] : i32
54+
// CHECK: emitc.yield
55+
// CHECK: }
56+
// CHECK: return
57+
// CHECK: }
58+
func.func @switch_one_result(%arg0 : index) {
59+
%0 = scf.index_switch %arg0 -> i32
60+
case 2 {
61+
%1 = arith.constant 10 : i32
62+
scf.yield %1 : i32
63+
}
64+
case 5 {
65+
%2 = arith.constant 20 : i32
66+
scf.yield %2 : i32
67+
}
68+
default {
69+
%3 = arith.constant 30 : i32
70+
scf.yield %3 : i32
71+
}
72+
return
73+
}
74+
75+
// CHECK-LABEL: func.func @switch_two_results(
76+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
77+
// CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #[[?]]<"">}> : () -> i32
78+
// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #[[?]]<"">}> : () -> f32
79+
// CHECK: emitc.switch %[[VAL_0]]
80+
// CHECK: case 2: {
81+
// CHECK: %[[VAL_3:.*]] = arith.constant 10 : i32
82+
// CHECK: %[[VAL_4:.*]] = arith.constant 1.200000e+00 : f32
83+
// CHECK: emitc.assign %[[VAL_3]] : i32 to %[[VAL_1]] : i32
84+
// CHECK: emitc.assign %[[VAL_4]] : f32 to %[[VAL_2]] : f32
85+
// CHECK: emitc.yield
86+
// CHECK: }
87+
// CHECK: case 5: {
88+
// CHECK: %[[VAL_5:.*]] = arith.constant 20 : i32
89+
// CHECK: %[[VAL_6:.*]] = arith.constant 2.400000e+00 : f32
90+
// CHECK: emitc.assign %[[VAL_5]] : i32 to %[[VAL_1]] : i32
91+
// CHECK: emitc.assign %[[VAL_6]] : f32 to %[[VAL_2]] : f32
92+
// CHECK: emitc.yield
93+
// CHECK: }
94+
// CHECK: default {
95+
// CHECK: %[[VAL_7:.*]] = arith.constant 30 : i32
96+
// CHECK: %[[VAL_8:.*]] = arith.constant 3.600000e+00 : f32
97+
// CHECK: emitc.assign %[[VAL_7]] : i32 to %[[VAL_1]] : i32
98+
// CHECK: emitc.assign %[[VAL_8]] : f32 to %[[VAL_2]] : f32
99+
// CHECK: emitc.yield
100+
// CHECK: }
101+
// CHECK: return
102+
// CHECK: }
103+
func.func @switch_two_results(%arg0 : index) {
104+
%0, %1 = scf.index_switch %arg0 -> i32, f32
105+
case 2 {
106+
%2 = arith.constant 10 : i32
107+
%3 = arith.constant 1.2 : f32
108+
scf.yield %2, %3 : i32, f32
109+
}
110+
case 5 {
111+
%4 = arith.constant 20 : i32
112+
%5 = arith.constant 2.4 : f32
113+
scf.yield %4, %5 : i32, f32
114+
}
115+
default {
116+
%6 = arith.constant 30 : i32
117+
%7 = arith.constant 3.6 : f32
118+
scf.yield %6, %7 : i32, f32
119+
}
120+
return
121+
}

0 commit comments

Comments
 (0)