Skip to content

Commit d47ea81

Browse files
perf[venom]: improve CSE elimination speed (#4607)
in the CSE elimination pass, avoid `O(n*k)` (where `n` == number of instructions, and `k` == number of available expressions) copies by replacing the dict with a HAMT. HAMT is an immutable dict with `O(1)` copy (copy by reference), and `O(log N)` insert/delete/update operations, each returning a fresh reference to a new HAMT. we pull in the external package `immutables`, which is stable, shares maintainership with CPython, and seems mature (receives bugfixes every now and again, but mostly not adding new features). we pull in the external library since HAMTs should be generally useful for the types of lattice operations we do, and we can probably use them to optimize other passes as well. this commit also adds the use of `lru_cache` for some `Flag` operations (which are bizarrely a perf hotspot)
1 parent 2fa2dcc commit d47ea81

File tree

3 files changed

+59
-31
lines changed

3 files changed

+59
-31
lines changed

.github/workflows/build.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,13 @@ jobs:
4747
cache: "pip"
4848

4949
- name: Generate Binary
50-
run: >-
51-
pip install --no-binary pycryptodome --no-binary cbor2 . &&
52-
pip install pyinstaller &&
50+
run: |
51+
pip install \
52+
--no-binary pycryptodome \
53+
--no-binary cbor2 \
54+
--no-binary immutables \
55+
. && \
56+
pip install pyinstaller && \
5357
make freeze
5458
5559

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def _global_version(version):
9696
"packaging>=23.1",
9797
"lark>=1.0.0,<2",
9898
"wheel",
99+
"immutables",
99100
],
100101
setup_requires=["setuptools_scm>=7.1.0,<8.0.0"],
101102
extras_require=extras_require,

vyper/venom/analysis/available_expression.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from collections import deque
44
from dataclasses import dataclass
5-
from functools import cached_property
5+
from functools import cached_property, lru_cache
6+
7+
import immutables
68

79
import vyper.venom.effects as effects
810
from vyper.venom.analysis.analysis import IRAnalysesCache, IRAnalysis
@@ -35,6 +37,33 @@
3537
assert opcode in NONIDEMPOTENT_INSTRUCTIONS
3638

3739

40+
# flag bitwise operations are somehow a perf bottleneck, cache them
41+
@lru_cache
42+
def _get_read_effects(opcode, ignore_msize):
43+
ret = effects.reads.get(opcode, effects.EMPTY)
44+
if ignore_msize:
45+
ret &= ~Effects.MSIZE
46+
return ret
47+
48+
49+
@lru_cache
50+
def _get_write_effects(opcode, ignore_msize):
51+
ret = effects.writes.get(opcode, effects.EMPTY)
52+
if ignore_msize:
53+
ret &= ~Effects.MSIZE
54+
return ret
55+
56+
57+
@lru_cache
58+
def _get_overlap_effects(opcode, ignore_msize):
59+
return _get_read_effects(opcode, ignore_msize) & _get_write_effects(opcode, ignore_msize)
60+
61+
62+
@lru_cache
63+
def _get_effects(opcode, ignore_msize):
64+
return _get_read_effects(opcode, ignore_msize) | _get_write_effects(opcode, ignore_msize)
65+
66+
3867
@dataclass
3968
class _Expression:
4069
opcode: str
@@ -94,18 +123,6 @@ def depth(self) -> int:
94123
max_depth = d
95124
return max_depth + 1
96125

97-
def get_reads(self, ignore_msize) -> Effects:
98-
ret = effects.reads.get(self.opcode, effects.EMPTY)
99-
if ignore_msize:
100-
ret &= ~Effects.MSIZE
101-
return ret
102-
103-
def get_writes(self, ignore_msize) -> Effects:
104-
ret = effects.writes.get(self.opcode, effects.EMPTY)
105-
if ignore_msize:
106-
ret &= ~Effects.MSIZE
107-
return ret
108-
109126
@property
110127
def is_commutative(self) -> bool:
111128
return self.opcode in COMMUTATIVE_INSTRUCTIONS
@@ -130,10 +147,10 @@ class _AvailableExpressions:
130147
and provides API for handling them
131148
"""
132149

133-
exprs: dict[_Expression, list[IRInstruction]]
150+
exprs: immutables.Map[_Expression, list[IRInstruction]]
134151

135152
def __init__(self):
136-
self.exprs = dict()
153+
self.exprs = immutables.Map()
137154

138155
def __eq__(self, other) -> bool:
139156
if not isinstance(other, _AvailableExpressions):
@@ -148,23 +165,27 @@ def __repr__(self) -> str:
148165
return res
149166

150167
def add(self, expr: _Expression, src_inst: IRInstruction):
151-
if expr not in self.exprs:
152-
self.exprs[expr] = []
153-
self.exprs[expr].append(src_inst)
168+
with self.exprs.mutate() as mt:
169+
if expr not in mt:
170+
mt[expr] = []
171+
else:
172+
mt[expr] = mt[expr].copy()
173+
mt[expr].append(src_inst)
174+
self.exprs = mt.finish()
154175

155176
def remove_effect(self, effect: Effects, ignore_msize):
156177
if effect == effects.EMPTY:
157178
return
158179
to_remove = set()
159180
for expr in self.exprs.keys():
160-
read_effs = expr.get_reads(ignore_msize)
161-
write_effs = expr.get_writes(ignore_msize)
162-
op_effect = read_effs | write_effs
181+
op_effect = _get_effects(expr.opcode, ignore_msize)
163182
if op_effect & effect != effects.EMPTY:
164183
to_remove.add(expr)
165184

166-
for expr in to_remove:
167-
del self.exprs[expr]
185+
with self.exprs.mutate() as mt:
186+
for expr in to_remove:
187+
del mt[expr]
188+
self.exprs = mt.finish()
168189

169190
def get_source_instruction(self, expr: _Expression) -> IRInstruction | None:
170191
"""
@@ -178,18 +199,19 @@ def get_source_instruction(self, expr: _Expression) -> IRInstruction | None:
178199

179200
def copy(self) -> _AvailableExpressions:
180201
res = _AvailableExpressions()
181-
for k, v in self.exprs.items():
182-
res.exprs[k] = v.copy()
202+
res.exprs = self.exprs
183203
return res
184204

185205
@staticmethod
186206
def lattice_meet(lattices: list[_AvailableExpressions]):
187207
if len(lattices) == 0:
188208
return _AvailableExpressions()
189209
res = lattices[0].copy()
210+
# compute intersection
190211
for item in lattices[1:]:
191212
tmp = res
192213
res = _AvailableExpressions()
214+
mt = res.exprs.mutate()
193215
for expr, insts in item.exprs.items():
194216
if expr not in tmp.exprs:
195217
continue
@@ -199,7 +221,8 @@ def lattice_meet(lattices: list[_AvailableExpressions]):
199221
new_insts.append(i)
200222
if len(new_insts) == 0:
201223
continue
202-
res.exprs[expr] = new_insts
224+
mt[expr] = new_insts
225+
res.exprs = mt.finish()
203226
return res
204227

205228

@@ -279,7 +302,7 @@ def _handle_bb(self, bb: IRBasicBlock) -> bool:
279302

280303
self._update_expr(inst, expr)
281304

282-
write_effects = expr.get_writes(self.ignore_msize)
305+
write_effects = _get_write_effects(expr.opcode, self.ignore_msize)
283306
available_exprs.remove_effect(write_effects, self.ignore_msize)
284307

285308
# nonidempotent instructions affect other instructions,
@@ -288,7 +311,7 @@ def _handle_bb(self, bb: IRBasicBlock) -> bool:
288311
if inst.opcode in NONIDEMPOTENT_INSTRUCTIONS:
289312
continue
290313

291-
expr_effects = expr.get_writes(self.ignore_msize) & expr.get_reads(self.ignore_msize)
314+
expr_effects = _get_overlap_effects(expr.opcode, self.ignore_msize)
292315
if expr_effects == effects.EMPTY:
293316
available_exprs.add(expr, inst)
294317

0 commit comments

Comments
 (0)