+import abc
 import itertools
 import operator
 import sys
 from collections import defaultdict, deque
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from functools import cache, reduce
 from typing import TypeVar
 from warnings import warn
@@ ... @@
 from pytensor.compile.function.types import Supervisor
 from pytensor.compile.mode import get_target_language
 from pytensor.configdefaults import config
-from pytensor.graph import FunctionGraph
+from pytensor.graph import FunctionGraph, Op
 from pytensor.graph.basic import Apply, Variable, ancestors
 from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
 from pytensor.graph.features import ReplaceValidate
@@ ... @@
 from pytensor.tensor.variable import TensorConstant, TensorVariable


-class InplaceElemwiseOptimizer(GraphRewriter):
+class InplaceGraphOptimizer(GraphRewriter):
     r"""
     This is parameterized so that it works for `Elemwise` `Op`\s.
     """

+    op: type[Op]
+
     def add_requirements(self, fgraph):
         fgraph.attach_feature(DestroyHandler())

+    @abc.abstractmethod
+    def filter_candidate_pairs(
+        self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable]
+    ) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
+        pass
+
+    @abc.abstractmethod
+    def create_inplace_node(
+        self, node: Apply, inplace_pattern: dict[int, Sequence[int]]
+    ) -> Apply:
+        pass
+
     def apply(self, fgraph):
         r"""
@@ -93,30 +108,6 @@ def apply(self, fgraph):
         # tackle them in a more general way. The whole try/except approach is probably suboptimal.
         # We can consider restricting inputs with static shapes that are large enough.

-        def create_inplace_node(node, inplace_pattern):
-            op = node.op
-            scalar_op = op.scalar_op
-            inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
-            if hasattr(scalar_op, "make_new_inplace"):
-                new_scalar_op = scalar_op.make_new_inplace(
-                    ps.transfer_type(
-                        *[
-                            inplace_pattern.get(i, o.dtype)
-                            for i, o in enumerate(node.outputs)
-                        ]
-                    )
-                )
-            else:
-                new_scalar_op = type(scalar_op)(
-                    ps.transfer_type(
-                        *[
-                            inplace_pattern.get(i, None)
-                            for i in range(len(node.outputs))
-                        ]
-                    )
-                )
-            return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
-
         if config.tensor__insert_inplace_optimizer_validate_nb != -1:
             warn(
                 "tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
@@ -140,43 +131,30 @@ def create_inplace_node(node, inplace_pattern):
         protected_inputs.update(fgraph.outputs)
         root_destroyer = fgraph.destroy_handler.root_destroyer

+        self_op = self.op
         update_mapping = fgraph.update_mapping or {}
         op_updates: dict[TensorVariable, TensorVariable] = {
             out: fgraph.inputs[update_mapping[out_idx]]
             for out_idx, out in enumerate(fgraph.outputs)
             if (
                 out_idx in update_mapping
                 and out.owner
-                and isinstance(out.owner.op, Elemwise)
+                and isinstance(out.owner.op, self_op)
             )
         }
         set_op_updates = set(op_updates.keys())

         for node in fgraph.toposort():
-            if not isinstance(node.op, Elemwise) or node.op.destroy_map:
+            if not isinstance(node.op, self_op) or node.op.destroy_map:
                 continue

             # If big graph and the outputs are scalar, do not make it inplace.
             if large_graph and all(node.outputs[0].type.broadcastable):
                 continue

-            candidate_inputs = [
-                (node.inputs.index(inp), inp)
-                for inp in inplace_candidates(
-                    fgraph,
-                    node.inputs,
-                    protected_inputs=protected_inputs,
-                )
-            ]
-            if not candidate_inputs:
-                return []
-
-            candidate_pairs = [
-                ((o, out), (i, inp))
-                for o, out in enumerate(node.outputs)
-                for i, inp in candidate_inputs
-                if inp.type == out.type
-            ]
+            candidate_pairs = self.filter_candidate_pairs(
+                fgraph, node, protected_inputs
+            )

             if not candidate_pairs:
                 continue
@@ -216,7 +194,7 @@ def create_inplace_node(node, inplace_pattern):
                     inplace_pattern[o] = [i]
                     tried_inputs.add(i)

-            inplace_node = create_inplace_node(node, inplace_pattern)
+            inplace_node = self.create_inplace_node(node, inplace_pattern)
             if inplace_node.op.destroy_map == inplace_pattern:
                 replacements = tuple(zip(node.outputs, inplace_node.outputs))
                 try:
@@ -238,7 +216,7 @@ def create_inplace_node(node, inplace_pattern):
                         inplace_pattern[o] = [i]
                         tried_inputs.add(i)

-                        inplace_node = create_inplace_node(node, inplace_pattern)
+                        inplace_node = self.create_inplace_node(node, inplace_pattern)
                         if inplace_node.op.destroy_map != inplace_pattern:
                             # This Op can't respect this partial inplace pattern,
                             # We assume it can't support any other cases
@@ -249,6 +227,7 @@ def create_inplace_node(node, inplace_pattern):
                             fgraph.replace_all_validate(
                                 replacements, reason="inplace_elemwise_optimizer"
                             )
+                            node = inplace_node
                             replaced = True
                             node = inplace_node
                         except InconsistencyError:
@@ -278,6 +257,50 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
         )


+class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
+    op = Elemwise
+
+    def filter_candidate_pairs(self, fgraph, node, protected_inputs):
+        candidate_inputs = [
+            (node.inputs.index(inp), inp)
+            for inp in inplace_candidates(
+                fgraph,
+                node.inputs,
+                protected_inputs=protected_inputs,
+            )
+        ]
+        if not candidate_inputs:
+            return []
+
+        return [
+            ((o, out), (i, inp))
+            for o, out in enumerate(node.outputs)
+            for i, inp in candidate_inputs
+            if inp.type == out.type
+        ]
+
+    def create_inplace_node(self, node, inplace_pattern):
+        op = node.op
+        scalar_op = op.scalar_op
+        inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
+        if hasattr(scalar_op, "make_new_inplace"):
+            new_scalar_op = scalar_op.make_new_inplace(
+                ps.transfer_type(
+                    *[
+                        inplace_pattern.get(i, o.dtype)
+                        for i, o in enumerate(node.outputs)
+                    ]
+                )
+            )
+        else:
+            new_scalar_op = type(scalar_op)(
+                ps.transfer_type(
+                    *[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
+                )
+            )
+        return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
+
+
 compile.optdb.register(
     "inplace_elemwise",
     InplaceElemwiseOptimizer(),
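Context for the change above: the rewriter decides whether an in-place replacement is valid by comparing the `destroy_map` of the node returned by `create_inplace_node` against the requested `inplace_pattern`. The following is a minimal sketch of that convention (not part of the diff; variable names are illustrative), using the public `Elemwise` constructor, which takes `{output_index: input_index}` while the optimizer tracks `{output_index: [input_index]}`, as in the dict comprehension inside `create_inplace_node`.

import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.tensor.elemwise import Elemwise

x = pt.vector("x")
y = pt.vector("y")
node = (x + y).owner  # plain (non-destructive) Elemwise add

# Build an in-place variant where output 0 reuses input 0's storage.
inplace_add = Elemwise(ps.add, inplace_pattern={0: 0})
inplace_node = inplace_add.make_node(*node.inputs)

print(node.op.destroy_map)          # {} -- nothing is overwritten
print(inplace_node.op.destroy_map)  # {0: [0]} -- output 0 destroys input 0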