Skip to content

Commit b05557e

Browse files
committed
Reuse Elemwise inplace machinery for Blockwise
1 parent 070f5e6 commit b05557e

File tree

3 files changed

+195
-100
lines changed

3 files changed

+195
-100
lines changed

pytensor/tensor/rewriting/blockwise.py

Lines changed: 59 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.graph import Constant, node_rewriter
33
from pytensor.graph.destroyhandler import inplace_candidates
44
from pytensor.graph.replace import vectorize_node
5-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
5+
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
66
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
77
from pytensor.tensor.blockwise import Blockwise
88
from pytensor.tensor.math import Dot
@@ -11,6 +11,7 @@
1111
register_specialize,
1212
register_stabilize,
1313
)
14+
from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer
1415
from pytensor.tensor.shape import Reshape
1516
from pytensor.tensor.subtensor import (
1617
AdvancedIncSubtensor,
@@ -260,68 +261,76 @@ def local_blockwise_of_subtensor(fgraph, node):
260261
return [x[(*none_slices, *core_idxs)]]
261262

262263

263-
@node_rewriter(tracks=[Blockwise], inplace=True)
264-
def blockwise_inplace(fgraph, node):
265-
blockwise_op = node.op
264+
class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
265+
op = Blockwise
266266

267-
if blockwise_op.destroy_map:
268-
# Op already has inplace
269-
return
270-
271-
# Find out valid inputs for inplacing
272-
batch_ndim = blockwise_op.batch_ndim(node)
273-
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
267+
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
268+
blockwise_op = node.op
269+
batch_ndim = blockwise_op.batch_ndim(node)
270+
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
271+
inputs = node.inputs
274272

275-
inputs = node.inputs
276-
candidate_inputs = set(
277-
inplace_candidates(
278-
fgraph,
279-
[
280-
inp
281-
for inp in inputs
282-
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
283-
],
273+
candidate_inputs = set(
274+
inplace_candidates(
275+
fgraph,
276+
[
277+
inp
278+
for inp in inputs
279+
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
280+
],
281+
)
284282
)
285-
)
286-
allowed_inplace_inputs = [
287-
i for i, inp in enumerate(inputs) if inp in candidate_inputs
288-
]
289283

290-
if not allowed_inplace_inputs:
291-
return None
284+
allowed_inplace_inputs = [
285+
i for i, inp in enumerate(inputs) if inp in candidate_inputs
286+
]
287+
destroy_map = blockwise_op.core_op.inplace_on_inputs(
288+
allowed_inplace_inputs=allowed_inplace_inputs
289+
).destroy_map
290+
291+
if not destroy_map:
292+
return []
293+
294+
outputs = node.outputs
295+
return [
296+
((out_idx, outputs[out_idx]), (inp_idx, inputs[inp_idx]))
297+
for out_idx, inp_idxs in destroy_map.items()
298+
for inp_idx in inp_idxs
299+
]
292300

293-
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
294-
allowed_inplace_inputs=allowed_inplace_inputs
295-
)
301+
def create_inplace_node(self, node, inplace_pattern):
302+
blockwise_op = node.op
303+
allowed_inplace_inputs = tuple(v[0] for v in inplace_pattern.values())
304+
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
305+
allowed_inplace_inputs=allowed_inplace_inputs
306+
)
296307

297-
if not inplace_core_op.destroy_map:
298-
return None
308+
if not inplace_core_op.destroy_map:
309+
return node
299310

300-
# Check Op is not trying to inplace on non-candidate inputs
301-
for destroyed_inputs in inplace_core_op.destroy_map.values():
302-
for destroyed_input in destroyed_inputs:
303-
if destroyed_input not in allowed_inplace_inputs:
304-
raise ValueError(
305-
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
306-
)
311+
# Check Op is not trying to inplace on non-candidate inputs
312+
for destroyed_inputs in inplace_core_op.destroy_map.values():
313+
for destroyed_input in destroyed_inputs:
314+
if destroyed_input not in allowed_inplace_inputs:
315+
raise ValueError(
316+
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
317+
)
307318

308-
# Recreate core_op with inplace
309-
inplace_blockwise_op = Blockwise(
310-
core_op=inplace_core_op,
311-
signature=blockwise_op.signature,
312-
name=blockwise_op.name,
313-
gufunc_spec=blockwise_op.gufunc_spec,
314-
destroy_map=inplace_core_op.destroy_map,
315-
)
319+
# Recreate core_op with inplace
320+
inplace_blockwise_op = type(blockwise_op)(
321+
core_op=inplace_core_op,
322+
signature=blockwise_op.signature,
323+
name=blockwise_op.name,
324+
gufunc_spec=blockwise_op.gufunc_spec,
325+
destroy_map=inplace_core_op.destroy_map,
326+
)
316327

317-
out = inplace_blockwise_op.make_node(*node.inputs).outputs
318-
copy_stack_trace(node.outputs, out)
319-
return out
328+
return inplace_blockwise_op.make_node(*node.inputs)
320329

321330

322331
optdb.register(
323332
"blockwise_inplace",
324-
in2out(blockwise_inplace),
333+
InplaceBlockwiseOptimizer(),
325334
"fast_run",
326335
"inplace",
327336
position=50.1,

pytensor/tensor/rewriting/elemwise.py

Lines changed: 70 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import abc
12
import itertools
23
import operator
34
import sys
45
from collections import defaultdict, deque
5-
from collections.abc import Generator
6+
from collections.abc import Generator, Sequence
67
from functools import cache, reduce
78
from typing import TypeVar
89
from warnings import warn
@@ -12,7 +13,7 @@
1213
from pytensor.compile.function.types import Supervisor
1314
from pytensor.compile.mode import get_target_language
1415
from pytensor.configdefaults import config
15-
from pytensor.graph import FunctionGraph
16+
from pytensor.graph import FunctionGraph, Op
1617
from pytensor.graph.basic import Apply, Variable, ancestors
1718
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
1819
from pytensor.graph.features import ReplaceValidate
@@ -47,14 +48,28 @@
4748
from pytensor.tensor.variable import TensorConstant, TensorVariable
4849

4950

50-
class InplaceElemwiseOptimizer(GraphRewriter):
51+
class InplaceGraphOptimizer(GraphRewriter):
5152
r"""
5253
This is parameterized so that it works for `Elemwise` `Op`\s.
5354
"""
5455

56+
op: type[Op]
57+
5558
def add_requirements(self, fgraph):
5659
fgraph.attach_feature(DestroyHandler())
5760

61+
@abc.abstractmethod
62+
def filter_candidate_pairs(
63+
self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable]
64+
) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
65+
pass
66+
67+
@abc.abstractmethod
68+
def create_inplace_node(
69+
self, node: Apply, inplace_pattern: dict[int, Sequence[int]]
70+
) -> Apply:
71+
pass
72+
5873
def apply(self, fgraph):
5974
r"""
6075
@@ -93,30 +108,6 @@ def apply(self, fgraph):
93108
# tackle them in a more general way. The whole try/except approach is probably suboptimal.
94109
# We can consider restricting inputs with static shapes that are large enough.
95110

96-
def create_inplace_node(node, inplace_pattern):
97-
op = node.op
98-
scalar_op = op.scalar_op
99-
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
100-
if hasattr(scalar_op, "make_new_inplace"):
101-
new_scalar_op = scalar_op.make_new_inplace(
102-
ps.transfer_type(
103-
*[
104-
inplace_pattern.get(i, o.dtype)
105-
for i, o in enumerate(node.outputs)
106-
]
107-
)
108-
)
109-
else:
110-
new_scalar_op = type(scalar_op)(
111-
ps.transfer_type(
112-
*[
113-
inplace_pattern.get(i, None)
114-
for i in range(len(node.outputs))
115-
]
116-
)
117-
)
118-
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
119-
120111
if config.tensor__insert_inplace_optimizer_validate_nb != -1:
121112
warn(
122113
"tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
@@ -140,43 +131,30 @@ def create_inplace_node(node, inplace_pattern):
140131
protected_inputs.update(fgraph.outputs)
141132
root_destroyer = fgraph.destroy_handler.root_destroyer
142133

134+
self_op = self.op
143135
update_mapping = fgraph.update_mapping or {}
144136
op_updates: dict[TensorVariable, TensorVariable] = {
145137
out: fgraph.inputs[update_mapping[out_idx]]
146138
for out_idx, out in enumerate(fgraph.outputs)
147139
if (
148140
out_idx in update_mapping
149141
and out.owner
150-
and isinstance(out.owner.op, Elemwise)
142+
and isinstance(out.owner.op, self_op)
151143
)
152144
}
153145
set_op_updates = set(op_updates.keys())
154146

155147
for node in fgraph.toposort():
156-
if not isinstance(node.op, Elemwise) or node.op.destroy_map:
148+
if not isinstance(node.op, self_op) or node.op.destroy_map:
157149
continue
158150

159151
# If big graph and the outputs are scalar, do not make it inplace.
160152
if large_graph and all(node.outputs[0].type.broadcastable):
161153
continue
162154

163-
candidate_inputs = [
164-
(node.inputs.index(inp), inp)
165-
for inp in inplace_candidates(
166-
fgraph,
167-
node.inputs,
168-
protected_inputs=protected_inputs,
169-
)
170-
]
171-
if not candidate_inputs:
172-
return []
173-
174-
candidate_pairs = [
175-
((o, out), (i, inp))
176-
for o, out in enumerate(node.outputs)
177-
for i, inp in candidate_inputs
178-
if inp.type == out.type
179-
]
155+
candidate_pairs = self.filter_candidate_pairs(
156+
fgraph, node, protected_inputs
157+
)
180158

181159
if not candidate_pairs:
182160
continue
@@ -216,7 +194,7 @@ def create_inplace_node(node, inplace_pattern):
216194
inplace_pattern[o] = [i]
217195
tried_inputs.add(i)
218196

219-
inplace_node = create_inplace_node(node, inplace_pattern)
197+
inplace_node = self.create_inplace_node(node, inplace_pattern)
220198
if inplace_node.op.destroy_map == inplace_pattern:
221199
replacements = tuple(zip(node.outputs, inplace_node.outputs))
222200
try:
@@ -238,7 +216,7 @@ def create_inplace_node(node, inplace_pattern):
238216
inplace_pattern[o] = [i]
239217
tried_inputs.add(i)
240218

241-
inplace_node = create_inplace_node(node, inplace_pattern)
219+
inplace_node = self.create_inplace_node(node, inplace_pattern)
242220
if inplace_node.op.destroy_map != inplace_pattern:
243221
# This Op can't respect this partial inplace pattern,
244222
# We assume it can't support any other cases
@@ -277,6 +255,50 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
277255
)
278256

279257

258+
class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
259+
op = Elemwise
260+
261+
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
262+
candidate_inputs = [
263+
(node.inputs.index(inp), inp)
264+
for inp in inplace_candidates(
265+
fgraph,
266+
node.inputs,
267+
protected_inputs=protected_inputs,
268+
)
269+
]
270+
if not candidate_inputs:
271+
return []
272+
273+
return [
274+
((o, out), (i, inp))
275+
for o, out in enumerate(node.outputs)
276+
for i, inp in candidate_inputs
277+
if inp.type == out.type
278+
]
279+
280+
def create_inplace_node(self, node, inplace_pattern):
281+
op = node.op
282+
scalar_op = op.scalar_op
283+
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
284+
if hasattr(scalar_op, "make_new_inplace"):
285+
new_scalar_op = scalar_op.make_new_inplace(
286+
ps.transfer_type(
287+
*[
288+
inplace_pattern.get(i, o.dtype)
289+
for i, o in enumerate(node.outputs)
290+
]
291+
)
292+
)
293+
else:
294+
new_scalar_op = type(scalar_op)(
295+
ps.transfer_type(
296+
*[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
297+
)
298+
)
299+
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
300+
301+
280302
compile.optdb.register(
281303
"inplace_elemwise",
282304
InplaceElemwiseOptimizer(),

0 commit comments

Comments
 (0)