2
2
3
3
from collections import deque
4
4
from dataclasses import dataclass
5
- from functools import cached_property
5
+ from functools import cached_property , lru_cache
6
+
7
+ import immutables
6
8
7
9
import vyper .venom .effects as effects
8
10
from vyper .venom .analysis .analysis import IRAnalysesCache , IRAnalysis
35
37
assert opcode in NONIDEMPOTENT_INSTRUCTIONS
36
38
37
39
40
+ # flag bitwise operations are somehow a perf bottleneck, cache them
41
+ @lru_cache
42
+ def _get_read_effects (opcode , ignore_msize ):
43
+ ret = effects .reads .get (opcode , effects .EMPTY )
44
+ if ignore_msize :
45
+ ret &= ~ Effects .MSIZE
46
+ return ret
47
+
48
+
49
+ @lru_cache
50
+ def _get_write_effects (opcode , ignore_msize ):
51
+ ret = effects .writes .get (opcode , effects .EMPTY )
52
+ if ignore_msize :
53
+ ret &= ~ Effects .MSIZE
54
+ return ret
55
+
56
+
57
+ @lru_cache
58
+ def _get_overlap_effects (opcode , ignore_msize ):
59
+ return _get_read_effects (opcode , ignore_msize ) & _get_write_effects (opcode , ignore_msize )
60
+
61
+
62
+ @lru_cache
63
+ def _get_effects (opcode , ignore_msize ):
64
+ return _get_read_effects (opcode , ignore_msize ) | _get_write_effects (opcode , ignore_msize )
65
+
66
+
38
67
@dataclass
39
68
class _Expression :
40
69
opcode : str
@@ -94,18 +123,6 @@ def depth(self) -> int:
94
123
max_depth = d
95
124
return max_depth + 1
96
125
97
- def get_reads (self , ignore_msize ) -> Effects :
98
- ret = effects .reads .get (self .opcode , effects .EMPTY )
99
- if ignore_msize :
100
- ret &= ~ Effects .MSIZE
101
- return ret
102
-
103
- def get_writes (self , ignore_msize ) -> Effects :
104
- ret = effects .writes .get (self .opcode , effects .EMPTY )
105
- if ignore_msize :
106
- ret &= ~ Effects .MSIZE
107
- return ret
108
-
109
126
@property
110
127
def is_commutative (self ) -> bool :
111
128
return self .opcode in COMMUTATIVE_INSTRUCTIONS
@@ -130,10 +147,10 @@ class _AvailableExpressions:
130
147
and provides API for handling them
131
148
"""
132
149
133
- exprs : dict [_Expression , list [IRInstruction ]]
150
+ exprs : immutables . Map [_Expression , list [IRInstruction ]]
134
151
135
152
def __init__ (self ):
136
- self .exprs = dict ()
153
+ self .exprs = immutables . Map ()
137
154
138
155
def __eq__ (self , other ) -> bool :
139
156
if not isinstance (other , _AvailableExpressions ):
@@ -148,23 +165,27 @@ def __repr__(self) -> str:
148
165
return res
149
166
150
167
def add (self , expr : _Expression , src_inst : IRInstruction ):
151
- if expr not in self .exprs :
152
- self .exprs [expr ] = []
153
- self .exprs [expr ].append (src_inst )
168
+ with self .exprs .mutate () as mt :
169
+ if expr not in mt :
170
+ mt [expr ] = []
171
+ else :
172
+ mt [expr ] = mt [expr ].copy ()
173
+ mt [expr ].append (src_inst )
174
+ self .exprs = mt .finish ()
154
175
155
176
def remove_effect (self , effect : Effects , ignore_msize ):
156
177
if effect == effects .EMPTY :
157
178
return
158
179
to_remove = set ()
159
180
for expr in self .exprs .keys ():
160
- read_effs = expr .get_reads (ignore_msize )
161
- write_effs = expr .get_writes (ignore_msize )
162
- op_effect = read_effs | write_effs
181
+ op_effect = _get_effects (expr .opcode , ignore_msize )
163
182
if op_effect & effect != effects .EMPTY :
164
183
to_remove .add (expr )
165
184
166
- for expr in to_remove :
167
- del self .exprs [expr ]
185
+ with self .exprs .mutate () as mt :
186
+ for expr in to_remove :
187
+ del mt [expr ]
188
+ self .exprs = mt .finish ()
168
189
169
190
def get_source_instruction (self , expr : _Expression ) -> IRInstruction | None :
170
191
"""
@@ -178,18 +199,19 @@ def get_source_instruction(self, expr: _Expression) -> IRInstruction | None:
178
199
179
200
def copy (self ) -> _AvailableExpressions :
180
201
res = _AvailableExpressions ()
181
- for k , v in self .exprs .items ():
182
- res .exprs [k ] = v .copy ()
202
+ res .exprs = self .exprs
183
203
return res
184
204
185
205
@staticmethod
186
206
def lattice_meet (lattices : list [_AvailableExpressions ]):
187
207
if len (lattices ) == 0 :
188
208
return _AvailableExpressions ()
189
209
res = lattices [0 ].copy ()
210
+ # compute intersection
190
211
for item in lattices [1 :]:
191
212
tmp = res
192
213
res = _AvailableExpressions ()
214
+ mt = res .exprs .mutate ()
193
215
for expr , insts in item .exprs .items ():
194
216
if expr not in tmp .exprs :
195
217
continue
@@ -199,7 +221,8 @@ def lattice_meet(lattices: list[_AvailableExpressions]):
199
221
new_insts .append (i )
200
222
if len (new_insts ) == 0 :
201
223
continue
202
- res .exprs [expr ] = new_insts
224
+ mt [expr ] = new_insts
225
+ res .exprs = mt .finish ()
203
226
return res
204
227
205
228
@@ -279,7 +302,7 @@ def _handle_bb(self, bb: IRBasicBlock) -> bool:
279
302
280
303
self ._update_expr (inst , expr )
281
304
282
- write_effects = expr .get_writes ( self .ignore_msize )
305
+ write_effects = _get_write_effects ( expr .opcode , self .ignore_msize )
283
306
available_exprs .remove_effect (write_effects , self .ignore_msize )
284
307
285
308
# nonidempotent instructions affect other instructions,
@@ -288,7 +311,7 @@ def _handle_bb(self, bb: IRBasicBlock) -> bool:
288
311
if inst .opcode in NONIDEMPOTENT_INSTRUCTIONS :
289
312
continue
290
313
291
- expr_effects = expr . get_writes ( self . ignore_msize ) & expr .get_reads ( self .ignore_msize )
314
+ expr_effects = _get_overlap_effects ( expr .opcode , self .ignore_msize )
292
315
if expr_effects == effects .EMPTY :
293
316
available_exprs .add (expr , inst )
294
317
0 commit comments