Skip to content

Commit 6475c85

Browse files
committed
Reuse Elemwise inplace machinery for Blockwise
1 parent 480ef20 commit 6475c85

File tree

3 files changed

+196
-100
lines changed

3 files changed

+196
-100
lines changed

pytensor/tensor/rewriting/blockwise.py

Lines changed: 59 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.graph import Constant, node_rewriter
33
from pytensor.graph.destroyhandler import inplace_candidates
44
from pytensor.graph.replace import vectorize_node
5-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
5+
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
66
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
77
from pytensor.tensor.blockwise import Blockwise
88
from pytensor.tensor.math import Dot
@@ -11,6 +11,7 @@
1111
register_specialize,
1212
register_stabilize,
1313
)
14+
from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer
1415
from pytensor.tensor.shape import Reshape
1516
from pytensor.tensor.subtensor import (
1617
AdvancedIncSubtensor,
@@ -260,68 +261,76 @@ def local_blockwise_of_subtensor(fgraph, node):
260261
return [x[(*none_slices, *core_idxs)]]
261262

262263

263-
@node_rewriter(tracks=[Blockwise], inplace=True)
264-
def blockwise_inplace(fgraph, node):
265-
blockwise_op = node.op
264+
class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
265+
op = Blockwise
266266

267-
if blockwise_op.destroy_map:
268-
# Op already has inplace
269-
return
270-
271-
# Find out valid inputs for inplacing
272-
batch_ndim = blockwise_op.batch_ndim(node)
273-
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
267+
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
268+
blockwise_op = node.op
269+
batch_ndim = blockwise_op.batch_ndim(node)
270+
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
271+
inputs = node.inputs
274272

275-
inputs = node.inputs
276-
candidate_inputs = set(
277-
inplace_candidates(
278-
fgraph,
279-
[
280-
inp
281-
for inp in inputs
282-
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
283-
],
273+
candidate_inputs = set(
274+
inplace_candidates(
275+
fgraph,
276+
[
277+
inp
278+
for inp in inputs
279+
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
280+
],
281+
)
284282
)
285-
)
286-
allowed_inplace_inputs = [
287-
i for i, inp in enumerate(inputs) if inp in candidate_inputs
288-
]
289283

290-
if not allowed_inplace_inputs:
291-
return None
284+
allowed_inplace_inputs = [
285+
i for i, inp in enumerate(inputs) if inp in candidate_inputs
286+
]
287+
destroy_map = blockwise_op.core_op.inplace_on_inputs(
288+
allowed_inplace_inputs=allowed_inplace_inputs
289+
).destroy_map
290+
291+
if not destroy_map:
292+
return []
293+
294+
outputs = node.outputs
295+
return [
296+
((out_idx, outputs[out_idx]), (inp_idx, inputs[inp_idx]))
297+
for out_idx, inp_idxs in destroy_map.items()
298+
for inp_idx in inp_idxs
299+
]
292300

293-
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
294-
allowed_inplace_inputs=allowed_inplace_inputs
295-
)
301+
def create_inplace_node(self, node, inplace_pattern):
302+
blockwise_op = node.op
303+
allowed_inplace_inputs = tuple(v[0] for v in inplace_pattern.values())
304+
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
305+
allowed_inplace_inputs=allowed_inplace_inputs
306+
)
296307

297-
if not inplace_core_op.destroy_map:
298-
return None
308+
if not inplace_core_op.destroy_map:
309+
return node
299310

300-
# Check Op is not trying to inplace on non-candidate inputs
301-
for destroyed_inputs in inplace_core_op.destroy_map.values():
302-
for destroyed_input in destroyed_inputs:
303-
if destroyed_input not in allowed_inplace_inputs:
304-
raise ValueError(
305-
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
306-
)
311+
# Check Op is not trying to inplace on non-candidate inputs
312+
for destroyed_inputs in inplace_core_op.destroy_map.values():
313+
for destroyed_input in destroyed_inputs:
314+
if destroyed_input not in allowed_inplace_inputs:
315+
raise ValueError(
316+
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
317+
)
307318

308-
# Recreate core_op with inplace
309-
inplace_blockwise_op = Blockwise(
310-
core_op=inplace_core_op,
311-
signature=blockwise_op.signature,
312-
name=blockwise_op.name,
313-
gufunc_spec=blockwise_op.gufunc_spec,
314-
destroy_map=inplace_core_op.destroy_map,
315-
)
319+
# Recreate core_op with inplace
320+
inplace_blockwise_op = type(blockwise_op)(
321+
core_op=inplace_core_op,
322+
signature=blockwise_op.signature,
323+
name=blockwise_op.name,
324+
gufunc_spec=blockwise_op.gufunc_spec,
325+
destroy_map=inplace_core_op.destroy_map,
326+
)
316327

317-
out = inplace_blockwise_op.make_node(*node.inputs).outputs
318-
copy_stack_trace(node.outputs, out)
319-
return out
328+
return inplace_blockwise_op.make_node(*node.inputs)
320329

321330

322331
optdb.register(
323332
"blockwise_inplace",
324-
in2out(blockwise_inplace),
333+
InplaceBlockwiseOptimizer(),
325334
"fast_run",
326335
"inplace",
327336
position=50.1,

pytensor/tensor/rewriting/elemwise.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import abc
12
import itertools
23
import operator
34
import sys
45
from collections import defaultdict, deque
5-
from collections.abc import Generator
6+
from collections.abc import Generator, Sequence
67
from functools import cache, reduce
78
from typing import TypeVar
89
from warnings import warn
@@ -12,7 +13,7 @@
1213
from pytensor.compile.function.types import Supervisor
1314
from pytensor.compile.mode import get_target_language
1415
from pytensor.configdefaults import config
15-
from pytensor.graph import FunctionGraph
16+
from pytensor.graph import FunctionGraph, Op
1617
from pytensor.graph.basic import Apply, Variable, ancestors
1718
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
1819
from pytensor.graph.features import ReplaceValidate
@@ -47,14 +48,28 @@
4748
from pytensor.tensor.variable import TensorConstant, TensorVariable
4849

4950

50-
class InplaceElemwiseOptimizer(GraphRewriter):
51+
class InplaceGraphOptimizer(GraphRewriter):
5152
r"""
5253
This is parameterized so that it works for `Elemwise` `Op`\s.
5354
"""
5455

56+
op: type[Op]
57+
5558
def add_requirements(self, fgraph):
5659
fgraph.attach_feature(DestroyHandler())
5760

61+
@abc.abstractmethod
62+
def filter_candidate_pairs(
63+
self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable]
64+
) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
65+
pass
66+
67+
@abc.abstractmethod
68+
def create_inplace_node(
69+
self, node: Apply, inplace_pattern: dict[int, Sequence[int]]
70+
) -> Apply:
71+
pass
72+
5873
def apply(self, fgraph):
5974
r"""
6075
@@ -93,30 +108,6 @@ def apply(self, fgraph):
93108
# tackle them in a more general way. The whole try/except approach is probably suboptimal.
94109
# We can consider restricting inputs with static shapes that are large enough.
95110

96-
def create_inplace_node(node, inplace_pattern):
97-
op = node.op
98-
scalar_op = op.scalar_op
99-
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
100-
if hasattr(scalar_op, "make_new_inplace"):
101-
new_scalar_op = scalar_op.make_new_inplace(
102-
ps.transfer_type(
103-
*[
104-
inplace_pattern.get(i, o.dtype)
105-
for i, o in enumerate(node.outputs)
106-
]
107-
)
108-
)
109-
else:
110-
new_scalar_op = type(scalar_op)(
111-
ps.transfer_type(
112-
*[
113-
inplace_pattern.get(i, None)
114-
for i in range(len(node.outputs))
115-
]
116-
)
117-
)
118-
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
119-
120111
if config.tensor__insert_inplace_optimizer_validate_nb != -1:
121112
warn(
122113
"tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
@@ -140,43 +131,30 @@ def create_inplace_node(node, inplace_pattern):
140131
protected_inputs.update(fgraph.outputs)
141132
root_destroyer = fgraph.destroy_handler.root_destroyer
142133

134+
self_op = self.op
143135
update_mapping = fgraph.update_mapping or {}
144136
op_updates: dict[TensorVariable, TensorVariable] = {
145137
out: fgraph.inputs[update_mapping[out_idx]]
146138
for out_idx, out in enumerate(fgraph.outputs)
147139
if (
148140
out_idx in update_mapping
149141
and out.owner
150-
and isinstance(out.owner.op, Elemwise)
142+
and isinstance(out.owner.op, self_op)
151143
)
152144
}
153145
set_op_updates = set(op_updates.keys())
154146

155147
for node in fgraph.toposort():
156-
if not isinstance(node.op, Elemwise) or node.op.destroy_map:
148+
if not isinstance(node.op, self_op) or node.op.destroy_map:
157149
continue
158150

159151
# If big graph and the outputs are scalar, do not make it inplace.
160152
if large_graph and all(node.outputs[0].type.broadcastable):
161153
continue
162154

163-
candidate_inputs = [
164-
(node.inputs.index(inp), inp)
165-
for inp in inplace_candidates(
166-
fgraph,
167-
node.inputs,
168-
protected_inputs=protected_inputs,
169-
)
170-
]
171-
if not candidate_inputs:
172-
return []
173-
174-
candidate_pairs = [
175-
((o, out), (i, inp))
176-
for o, out in enumerate(node.outputs)
177-
for i, inp in candidate_inputs
178-
if inp.type == out.type
179-
]
155+
candidate_pairs = self.filter_candidate_pairs(
156+
fgraph, node, protected_inputs
157+
)
180158

181159
if not candidate_pairs:
182160
continue
@@ -216,7 +194,7 @@ def create_inplace_node(node, inplace_pattern):
216194
inplace_pattern[o] = [i]
217195
tried_inputs.add(i)
218196

219-
inplace_node = create_inplace_node(node, inplace_pattern)
197+
inplace_node = self.create_inplace_node(node, inplace_pattern)
220198
if inplace_node.op.destroy_map == inplace_pattern:
221199
replacements = tuple(zip(node.outputs, inplace_node.outputs))
222200
try:
@@ -238,7 +216,7 @@ def create_inplace_node(node, inplace_pattern):
238216
inplace_pattern[o] = [i]
239217
tried_inputs.add(i)
240218

241-
inplace_node = create_inplace_node(node, inplace_pattern)
219+
inplace_node = self.create_inplace_node(node, inplace_pattern)
242220
if inplace_node.op.destroy_map != inplace_pattern:
243221
# This Op can't respect this partial inplace pattern,
244222
# We assume it can't support any other cases
@@ -249,6 +227,7 @@ def create_inplace_node(node, inplace_pattern):
249227
fgraph.replace_all_validate(
250228
replacements, reason="inplace_elemwise_optimizer"
251229
)
230+
node = inplace_node
252231
replaced = True
253232
node = inplace_node
254233
except InconsistencyError:
@@ -278,6 +257,50 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
278257
)
279258

280259

260+
class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
261+
op = Elemwise
262+
263+
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
264+
candidate_inputs = [
265+
(node.inputs.index(inp), inp)
266+
for inp in inplace_candidates(
267+
fgraph,
268+
node.inputs,
269+
protected_inputs=protected_inputs,
270+
)
271+
]
272+
if not candidate_inputs:
273+
return []
274+
275+
return [
276+
((o, out), (i, inp))
277+
for o, out in enumerate(node.outputs)
278+
for i, inp in candidate_inputs
279+
if inp.type == out.type
280+
]
281+
282+
def create_inplace_node(self, node, inplace_pattern):
283+
op = node.op
284+
scalar_op = op.scalar_op
285+
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
286+
if hasattr(scalar_op, "make_new_inplace"):
287+
new_scalar_op = scalar_op.make_new_inplace(
288+
ps.transfer_type(
289+
*[
290+
inplace_pattern.get(i, o.dtype)
291+
for i, o in enumerate(node.outputs)
292+
]
293+
)
294+
)
295+
else:
296+
new_scalar_op = type(scalar_op)(
297+
ps.transfer_type(
298+
*[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
299+
)
300+
)
301+
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
302+
303+
281304
compile.optdb.register(
282305
"inplace_elemwise",
283306
InplaceElemwiseOptimizer(),

0 commit comments

Comments (0)