Skip to content

Commit cbc69f9

Browse files
committed
[mlir][emitc][cf] add 'cf.switch' support in CppEmitter
1 parent 2d36550 commit cbc69f9

File tree

2 files changed

+103
-3
lines changed

2 files changed

+103
-3
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,65 @@ static LogicalResult printOperation(CppEmitter &emitter,
579579
return success();
580580
}
581581

582+
static LogicalResult printOperation(CppEmitter &emitter,
583+
cf::SwitchOp switchOp) {
584+
raw_indented_ostream &os = emitter.ostream();
585+
auto iteratorCaseValues = (*switchOp.getCaseValues()).begin();
586+
auto iteratorCaseValuesEnd = (*switchOp.getCaseValues()).end();
587+
size_t caseIndex = 0;
588+
589+
os << "\nswitch(" << emitter.getOrCreateName(switchOp.getFlag()) << ") {";
590+
591+
for (const auto caseBlock : switchOp.getCaseDestinations()) {
592+
if (iteratorCaseValues == iteratorCaseValuesEnd)
593+
return switchOp.emitOpError("case's value is absent for case block");
594+
595+
os << "\ncase "
596+
<< "(" << *(iteratorCaseValues++) << ")"
597+
<< ": {\n";
598+
os.indent();
599+
600+
for (auto pair : llvm::zip(switchOp.getCaseOperands(caseIndex++),
601+
caseBlock->getArguments())) {
602+
Value &operand = std::get<0>(pair);
603+
BlockArgument &argument = std::get<1>(pair);
604+
os << emitter.getOrCreateName(argument) << " = "
605+
<< emitter.getOrCreateName(operand) << ";\n";
606+
}
607+
608+
os << "goto ";
609+
610+
if (!(emitter.hasBlockLabel(*caseBlock)))
611+
return switchOp.emitOpError("unable to find label for case block");
612+
os << emitter.getOrCreateName(*caseBlock) << ";\n";
613+
614+
os.unindent() << "}";
615+
}
616+
617+
os << "\ndefault: {\n";
618+
os.indent();
619+
620+
for (auto pair :
621+
llvm::zip(switchOp.getDefaultOperands(),
622+
(switchOp.getDefaultDestination())->getArguments())) {
623+
Value &operand = std::get<0>(pair);
624+
BlockArgument &argument = std::get<1>(pair);
625+
os << emitter.getOrCreateName(argument) << " = "
626+
<< emitter.getOrCreateName(operand) << ";\n";
627+
}
628+
629+
os << "goto ";
630+
631+
if (!(emitter.hasBlockLabel(*switchOp.getDefaultDestination())))
632+
return switchOp.emitOpError("unable to find label for default block");
633+
os << emitter.getOrCreateName(*switchOp.getDefaultDestination()) << ";\n";
634+
635+
os.unindent() << "}\n";
636+
os << "}\n";
637+
638+
return success();
639+
}
640+
582641
static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
583642
StringRef callee) {
584643
if (failed(emitter.emitAssignPrefix(*callOp)))
@@ -997,8 +1056,8 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
9971056
// When generating code for an emitc.for and emitc.verbatim op, printing a
9981057
// trailing semicolon is handled within the printOperation function.
9991058
bool trailingSemicolon =
1000-
!isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
1001-
emitc::IfOp, emitc::VerbatimOp>(op);
1059+
!isa<cf::CondBranchOp, cf::SwitchOp, emitc::DeclareFuncOp,
1060+
emitc::ForOp, emitc::IfOp, emitc::VerbatimOp>(op);
10021061

10031062
if (failed(emitter.emitOperation(
10041063
op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1496,7 +1555,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14961555
// Builtin ops.
14971556
.Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
14981557
// CF ops.
1499-
.Case<cf::BranchOp, cf::CondBranchOp>(
1558+
.Case<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
15001559
[&](auto op) { return printOperation(*this, op); })
15011560
// EmitC ops.
15021561
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,

mlir/test/Target/Cpp/switch.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
2+
3+
func.func @switch_func(%a: i32, %b: i32, %c: i32) -> () {
4+
cf.switch %b : i32, [
5+
default: ^bb1(%a : i32),
6+
42: ^bb1(%b : i32),
7+
43: ^bb2(%c : i32),
8+
44: ^bb3(%c : i32)
9+
]
10+
11+
^bb1(%x1 : i32) :
12+
%y1 = "emitc.add" (%x1, %x1) : (i32, i32) -> i32
13+
return
14+
15+
^bb2(%x2 : i32) :
16+
%y2 = "emitc.sub" (%x2, %x2) : (i32, i32) -> i32
17+
return
18+
19+
^bb3(%x3 : i32) :
20+
%y3 = "emitc.mul" (%x3, %x3) : (i32, i32) -> i32
21+
return
22+
}
23+
// CHECK: void switch_func(int32_t [[V0:[^ ]*]], int32_t [[V1:[^ ]*]], int32_t [[V2:[^ ]*]]) {
24+
// CHECK: switch([[V1:[^ ]*]]) {
25+
// CHECK-NEXT: case (42): {
26+
// CHECK-NEXT: v7 = v2;
27+
// CHECK-NEXT: goto label2;
28+
// CHECK-NEXT: }
29+
// CHECK-NEXT: case (43): {
30+
// CHECK-NEXT: v8 = v3;
31+
// CHECK-NEXT: goto label3;
32+
// CHECK-NEXT: }
33+
// CHECK-NEXT: case (44): {
34+
// CHECK-NEXT: v9 = v3;
35+
// CHECK-NEXT: goto label4;
36+
// CHECK-NEXT: }
37+
// CHECK-NEXT: default: {
38+
// CHECK-NEXT: v7 = v1;
39+
// CHECK-NEXT: goto label2;
40+
// CHECK-NEXT: }
41+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)