Skip to content

Commit 68b68c4

Browse files
feat[venom]: improve phi elimination pass (#4635)
this commit improves the phi elimination pass (introduced in 6b670e7) when there are multiple phis in the chain. if the innermost phi has multiple origins, it would block elimination of outer phis which use that phi. for example, ```llvm %a = 1 %b = 2 %c = phi %a, %b ; has two origins %d = %c %e = %d %f = phi %d, %e ``` prior to this commit, `%f = phi %d, %e` would not reduce, because its origins were detected as `%a = 1` *and* `%b = 2`. this commit modifies the graph traversal so that further traversal of `%c`'s origins is blocked from the perspective of `%f`, allowing `%f` to be reduced to `%c`. --------- Co-authored-by: Charles Cooper <[email protected]>
1 parent 4b2e1db commit 68b68c4

File tree

5 files changed

+130
-63
lines changed

5 files changed

+130
-63
lines changed

tests/unit/compiler/venom/test_phi_elimination.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def test_phi_elim_cannot_remove():
212212
main:
213213
%p = param
214214
%rand = param
215+
jmp @cond
215216
cond:
216217
%1 = phi @main, %p, @body, %3
217218
%cond = iszero %1
@@ -264,3 +265,57 @@ def test_phi_elim_direct_loop():
264265

265266
_check_pre_post(pre1, post)
266267
_check_pre_post(pre2, post)
268+
269+
270+
def test_phi_elim_two_phi_merges():
271+
pre = """
272+
main:
273+
%cond = param
274+
%cond2 = param
275+
jnz %cond, @1_then, @2_then
276+
1_then:
277+
%1 = 100
278+
jmp @3_join
279+
2_then:
280+
%2 = 101
281+
jmp @3_join
282+
3_join:
283+
%3 = phi @1_then, %1, @2_then, %2
284+
jnz %cond2, @4_then, @5_then
285+
4_then:
286+
%4 = %3
287+
jmp @6_join
288+
5_then:
289+
%5 = %3
290+
jmp @6_join
291+
6_join:
292+
%6 = phi @4_then, %4, @5_then, %5 ; should be reduced to %3!
293+
sink %6
294+
"""
295+
296+
post = """
297+
main:
298+
%cond = param
299+
%cond2 = param
300+
jnz %cond, @1_then, @2_then
301+
1_then:
302+
%1 = 100
303+
jmp @3_join
304+
2_then:
305+
%2 = 101
306+
jmp @3_join
307+
3_join:
308+
%3 = phi @1_then, %1, @2_then, %2
309+
jnz %cond2, @4_then, @5_then
310+
4_then:
311+
%4 = %3
312+
jmp @6_join
313+
5_then:
314+
%5 = %3
315+
jmp @6_join
316+
6_join:
317+
%6 = %3
318+
sink %6
319+
"""
320+
321+
_check_pre_post(pre, post, hevm=True)

vyper/venom/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,13 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel, ac: IRAnalysesCache
6060
SimplifyCFGPass(ac, fn).run_pass()
6161

6262
MakeSSA(ac, fn).run_pass()
63+
PhiEliminationPass(ac, fn).run_pass()
6364
# run algebraic opts before mem2var to reduce some pointer arithmetic
6465
AlgebraicOptimizationPass(ac, fn).run_pass()
6566
AssignElimination(ac, fn).run_pass()
6667
Mem2Var(ac, fn).run_pass()
6768
MakeSSA(ac, fn).run_pass()
69+
PhiEliminationPass(ac, fn).run_pass()
6870
SCCP(ac, fn).run_pass()
6971

7072
SimplifyCFGPass(ac, fn).run_pass()

vyper/venom/analysis/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,5 @@
66
from .liveness import LivenessAnalysis
77
from .mem_alias import MemoryAliasAnalysis
88
from .mem_ssa import MemSSA
9-
from .phi_reach import PhiReachingAnalysis
109
from .reachable import ReachableAnalysis
1110
from .var_definition import VarDefinition

vyper/venom/analysis/phi_reach.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

vyper/venom/passes/phi_elimination.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,97 @@
1-
from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis, PhiReachingAnalysis
2-
from vyper.venom.basicblock import IRInstruction
1+
from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis
2+
from vyper.venom.basicblock import IRInstruction, IRVariable
33
from vyper.venom.passes.base_pass import InstUpdater, IRPass
44

55

66
class PhiEliminationPass(IRPass):
7-
phi_reach: PhiReachingAnalysis
7+
phi_to_origins: dict[IRInstruction, set[IRInstruction]]
88

99
def run_pass(self):
1010
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
1111
self.updater = InstUpdater(self.dfg)
12-
self.phi_reach = self.analyses_cache.request_analysis(PhiReachingAnalysis)
12+
self._calculate_phi_origins()
1313

1414
for _, inst in self.dfg.outputs.copy().items():
1515
if inst.opcode != "phi":
1616
continue
1717
self._process_phi(inst)
1818

19+
# sort phis to top of basic block
1920
for bb in self.function.get_basic_blocks():
2021
bb.ensure_well_formed()
2122

2223
self.analyses_cache.invalidate_analysis(LivenessAnalysis)
2324

2425
def _process_phi(self, inst: IRInstruction):
25-
srcs = self.phi_reach.phi_to_origins[inst]
26+
srcs = self.phi_to_origins[inst]
2627

2728
if len(srcs) == 1:
28-
src = next(iter(srcs))
29+
src = srcs.pop()
30+
if src == inst:
31+
return
2932
assert src.output is not None
3033
self.updater.store(inst, src.output)
34+
35+
def _calculate_phi_origins(self):
36+
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
37+
self.phi_to_origins = dict()
38+
39+
for bb in self.function.get_basic_blocks():
40+
for inst in bb.instructions:
41+
if inst.opcode != "phi":
42+
break
43+
self._get_phi_origins(inst)
44+
45+
def _get_phi_origins(self, inst: IRInstruction):
46+
assert inst.opcode == "phi" # sanity
47+
visited: set[IRInstruction] = set()
48+
self.phi_to_origins[inst] = self._get_phi_origins_r(inst, visited)
49+
50+
# traverse chains of phis and stores to get the "root" instructions
51+
# for phis.
52+
def _get_phi_origins_r(
53+
self, inst: IRInstruction, visited: set[IRInstruction]
54+
) -> set[IRInstruction]:
55+
if inst.opcode == "phi":
56+
if inst in self.phi_to_origins:
57+
return self.phi_to_origins[inst]
58+
59+
if inst in visited:
60+
# we have hit a dfg cycle. break the recursion.
61+
# if it is only visited we have found a self
62+
# reference, and we won't find anything more by
63+
# continuing the recursion.
64+
return set()
65+
66+
visited.add(inst)
67+
68+
res: set[IRInstruction] = set()
69+
70+
for _, var in inst.phi_operands:
71+
next_inst = self.dfg.get_producing_instruction(var)
72+
assert next_inst is not None, (inst, var)
73+
res |= self._get_phi_origins_r(next_inst, visited)
74+
75+
if len(res) > 1:
76+
# if this phi has more than one origin, then for future
77+
# phis, it is better to treat this as a barrier in the
78+
# graph traversal. for example (without basic blocks)
79+
# %a = 1
80+
# %b = 2
81+
# %c = phi %a, %b ; has two origins
82+
# %d = %c
83+
# %e = %d
84+
# %f = phi %d, %e
85+
# in this case, %f should reduce to %c.
86+
return set([inst])
87+
return res
88+
89+
if inst.opcode == "store" and isinstance(inst.operands[0], IRVariable):
90+
# traverse store chain
91+
var = inst.operands[0]
92+
next_inst = self.dfg.get_producing_instruction(var)
93+
assert next_inst is not None
94+
return self._get_phi_origins_r(next_inst, visited)
95+
96+
# root of the phi/store chain
97+
return set([inst])

0 commit comments

Comments
 (0)