Skip to content

Commit 9cccc34

Browse files
Add early mux hoisting optimization
1 parent 3186676 commit 9cccc34

File tree

7 files changed

+153
-17
lines changed

7 files changed

+153
-17
lines changed

zirgen/Dialect/ZHLT/Transforms/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ cc_library(
3636
"ElideRedundantMembers.cpp",
3737
"GenerateSteps.cpp",
3838
"HoistAllocs.cpp",
39+
"HoistCommonMuxCode.cpp",
3940
"LowerAssumeRange.cpp",
4041
"LowerFuncs.cpp",
4142
"OptimizeParWitgen.cpp",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright 2024 RISC Zero, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "zirgen/Dialect/ZHLT/IR/ZHLT.h"
16+
#include "zirgen/Dialect/ZHLT/Transforms/PassDetail.h"
17+
#include "llvm/Support/Debug.h"
18+
#include "llvm/ADT/Statistic.h"
19+
20+
#define DEBUG_TYPE "mux-hoisting"
21+
22+
using namespace mlir;
23+
using namespace zirgen::ZStruct;
24+
25+
namespace zirgen::Zhlt {
26+
27+
namespace {
28+
29+
bool isHoistable(Operation* op) {
30+
// An op may leave its enclosing region only if it has no regions, is not a
31+
// block terminator or AliasLayoutOp, and takes no operand defined in that region.
32+
// LoadOps are also kept in place so they are not reordered; this could be more
33+
// precise by checking for writes within the block.
34+
if (op->getNumRegions() != 0 || isa<LoadOp>(op) || isa<AliasLayoutOp>(op) ||
35+
op->hasTrait<OpTrait::IsTerminator>())
36+
return false;
37+
Region* arm = op->getParentRegion();
38+
return !llvm::any_of(op->getOperands(), [arm](Value operand) {
39+
return operand.getParentRegion() == arm;
40+
});
41+
}
42+
43+
bool compare(Operation* op1, Operation* op2) {
44+
// Returns true when op2 can stand in for op1: same operation name, same
45+
// attributes, the very same operand values and the same result types.
46+
if (op1->getName() != op2->getName() ||
47+
op1->getAttrDictionary() != op2->getAttrDictionary())
48+
return false;
49+
// llvm::equal also fails on ranges of different length, so no size check.
50+
if (!llvm::equal(op1->getOperands(), op2->getOperands()))
51+
return false;
52+
// Result types are not implied by the checks above; without this, uses of
53+
// op2 could be rewired to results of a different type when it is replaced.
54+
if (!llvm::equal(op1->getResultTypes(), op2->getResultTypes()))
55+
return false;
56+
return true;
57+
}
58+
59+
} // namespace
60+
61+
// Hoists operations out of muxes (SwitchOps): an operation from the first mux
62+
// arm whose operands are not defined inside that arm, and which has a matching
63+
// operation in every other arm, is moved in front of the mux; the matching
64+
// operations in the other arms are replaced by it and erased.
65+
struct HoistCommonMuxCodePass : public HoistCommonMuxCodeBase<HoistCommonMuxCodePass> {
66+
HoistCommonMuxCodePass() = default;
67+
HoistCommonMuxCodePass(const HoistCommonMuxCodePass& pass) {}
68+
69+
void runOnOperation() override {
70+
getOperation().walk<WalkOrder::PostOrder>([&](SwitchOp mux) {
71+
// If code in the mux is shared by multiple but not all mux arms, then
72+
// hoisting it reduces code size but increases execution cost. Since we
73+
// want code shared by all arms, search the first mux arm for hoistable
74+
// operations, and then search the other mux arms for matching operations.
75+
auto it = mux.getRegion(0).op_begin();
76+
while (it != mux->getRegion(0).op_end()) {
77+
Operation& op = *(it++); // advance first: op may be moved out of this region below
78+
if (isHoistable(&op)) {
79+
SmallVector<Operation*> toHoist;
80+
toHoist.reserve(mux.getArms().size());
81+
toHoist.push_back(&op);
82+
83+
bool hoistable = llvm::all_of(mux.getRegions(), [&](Region* region) {
84+
if (region->getRegionNumber() == 0) // the first arm is where op came from
85+
return true;
86+
87+
return llvm::any_of(region->getOps(), [&](Operation& op2) {
88+
if (compare(&op, &op2)) {
89+
toHoist.push_back(&op2);
90+
return true;
91+
}
92+
return false;
93+
});
94+
});
95+
96+
if (hoistable) {
97+
LLVM_DEBUG(llvm::dbgs() << "hoist: " << *toHoist[0] << "\n");
98+
toHoist[0]->moveBefore(mux);
99+
for (size_t i = 1; i < toHoist.size(); ++i) {
100+
toHoist[i]->replaceAllUsesWith(toHoist[0]->getResults());
101+
toHoist[i]->erase();
102+
++opsDeleted;
103+
}
104+
}
105+
}
106+
}
107+
});
108+
}
109+
110+
Statistic opsDeleted{this, "opsDeleted", "number of operations saved by mux hoisting"};
111+
};
112+
113+
// Creates the mux hoisting pass (declared in Passes.h, named as the constructor in Passes.td).
std::unique_ptr<OperationPass<ModuleOp>> createHoistCommonMuxCodePass() {
114+
return std::make_unique<HoistCommonMuxCodePass>();
115+
}
116+
117+
} // namespace zirgen::Zhlt

zirgen/Dialect/ZHLT/Transforms/Passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace zirgen::Zhlt {
2121

2222
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createElideRedundantMembersPass();
2323
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createHoistAllocsPass();
24+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createHoistCommonMuxCodePass();
2425
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createStripTestsPass();
2526
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createGenerateStepsPass();
2627
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createStripAliasLayoutOpsPass();

zirgen/Dialect/ZHLT/Transforms/Passes.td

+6
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ def HoistAllocs : Pass<"hoist-allocs", "mlir::ModuleOp"> {
3838
let dependentDialects = ["zirgen::Zhlt::ZhltDialect"];
3939
}
4040

41+
def HoistCommonMuxCode : Pass<"hoist-from-mux", "mlir::ModuleOp"> {
42+
let summary = "Hoist code shared across all mux arms out of the mux";
43+
let constructor = "zirgen::Zhlt::createHoistCommonMuxCodePass()";
44+
let dependentDialects = ["zirgen::Zhlt::ZhltDialect"];
45+
}
46+
4147
def StripTests : Pass<"strip-tests", "mlir::ModuleOp"> {
4248
let summary = "Strip all tests for smaller generated code.";
4349
let constructor = "zirgen::Zhlt::createStripTestsPass()";
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: zirgen %s --emit=zhltopt --debug-only=mux-hoisting 2>&1 | FileCheck %s
2+
3+
extern IsFirstCycle() : Val;
4+
5+
component Top() {
6+
a := Reg(5);
7+
first := NondetReg(IsFirstCycle());
8+
// Check that a@1 gets hoisted (SSA number matched by pattern, not pinned)
9+
// CHECK: hoist: %{{[0-9]+}} = zhlt.back @Reg(1,
10+
if (first) {
11+
Reg(a@1)
12+
} else {
13+
Reg(a@1)
14+
}
15+
}

zirgen/dsl/driver.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ int main(int argc, char* argv[]) {
188188
// pm.addPass(zirgen::ZStruct::createOptimizeLayoutPass());
189189
pm.addPass(zirgen::Zhlt::createElideRedundantMembersPass());
190190
pm.addPass(zirgen::dsl::createFieldDCEPass());
191+
pm.addPass(zirgen::Zhlt::createHoistCommonMuxCodePass());
191192
pm.addPass(mlir::createCanonicalizerPass());
192193
pm.addPass(mlir::createCSEPass());
193194
if (failed(pm.run(typedModule.value()))) {

zirgen/dsl/examples/fibonacci.rs.inc

+12-17
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,10 @@ pub fn exec_cycle_counter<'a>(
293293
let x6: NondetRegStruct = exec_is_zero(ctx, x5._super, (layout0.map(|c| c.is_first_cycle)))?;
294294
// CycleCounter(zirgen/dsl/examples/fibonacci.zir:34)
295295
let x7: Val = exec_sub(ctx, Val::new(1), x6._super)?;
296-
let x8: ComponentStruct;
296+
let x8: ComponentStruct = exec_component(ctx)?;
297+
let x9: ComponentStruct;
297298
if is_true(x6._super) {
298-
let x9: ComponentStruct = exec_component(ctx)?;
299-
x8 = x9;
299+
x9 = x8;
300300
} else if is_true(x7) {
301301
// CycleCounter(zirgen/dsl/examples/fibonacci.zir:39)
302302
let x10: NondetRegStruct = back_nondet_reg(ctx, 1, (layout0.map(|c| c._super)))?;
@@ -305,9 +305,7 @@ pub fn exec_cycle_counter<'a>(
305305
(x5._super - x11),
306306
"CycleCounter(zirgen/dsl/examples/fibonacci.zir:39)"
307307
);
308-
// CycleCounter(zirgen/dsl/examples/fibonacci.zir:37)
309-
let x12: ComponentStruct = exec_component(ctx)?;
310-
x8 = x12;
308+
x9 = x8;
311309
} else {
312310
bail!("Reached unreachable mux arm")
313311
}
@@ -366,24 +364,21 @@ pub fn exec_top<'a>(
366364
let x19: NondetRegStruct = exec_is_zero(ctx, x18, (layout0.map(|c| c.terminate)))?;
367365
// Top(zirgen/dsl/examples/fibonacci.zir:63)
368366
let x20: Val = exec_sub(ctx, Val::new(1), x19._super)?;
369-
let x21: ComponentStruct;
367+
let x21: ComponentStruct = exec_component(ctx)?;
368+
let x22: ComponentStruct;
370369
if is_true(x19._super) {
371370
// Top(zirgen/dsl/examples/fibonacci.zir:64)
372-
let x22: NondetRegStruct = exec_reg(ctx, x15._super, (x2.map(|c| c.f_last)))?;
373-
let x23: NondetRegStruct = back_reg(ctx, 0, (x2.map(|c| c.f_last)))?;
371+
let x23: NondetRegStruct = exec_reg(ctx, x15._super, (x2.map(|c| c.f_last)))?;
372+
let x24: NondetRegStruct = back_reg(ctx, 0, (x2.map(|c| c.f_last)))?;
374373
// Top(zirgen/dsl/examples/fibonacci.zir:65)
375-
let x24: LogStruct = exec_log(ctx, "f_last = %u", &[x23._super])?;
376-
// Top(zirgen/dsl/examples/fibonacci.zir:63)
377-
let x25: ComponentStruct = exec_component(ctx)?;
378-
x21 = x25;
374+
let x25: LogStruct = exec_log(ctx, "f_last = %u", &[x24._super])?;
375+
x22 = x21;
379376
} else if is_true(x20) {
380-
// Top(zirgen/dsl/examples/fibonacci.zir:66)
381-
let x26: ComponentStruct = exec_component(ctx)?;
382-
x21 = x26;
377+
x22 = x21;
383378
} else {
384379
bail!("Reached unreachable mux arm")
385380
} // Top(zirgen/dsl/examples/fibonacci.zir:44)
386-
let x27: ComponentStruct = exec_component(ctx)?;
381+
let x26: ComponentStruct = exec_component(ctx)?;
387382
return Ok(TopStruct { terminate: x19 });
388383
}
389384
pub fn step_top<'a>(

0 commit comments

Comments
 (0)