Skip to content

Commit 38231c7

Browse files
committed
Do not skip validation between consecutive Elemwise inplace replacements
1 parent 236e50d commit 38231c7

File tree

4 files changed

+232
-273
lines changed

4 files changed

+232
-273
lines changed

pytensor/graph/destroyhandler.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import itertools
88
from collections import deque
99

10-
import pytensor
1110
from pytensor.configdefaults import config
1211
from pytensor.graph.basic import Constant
1312
from pytensor.graph.features import AlreadyThere, Bookkeeper
@@ -223,7 +222,7 @@ def _build_droot_impact(destroy_handler):
223222
return droot, impact, root_destroyer
224223

225224

226-
def fast_inplace_check(fgraph, inputs):
225+
def inplace_candidates(fgraph, inputs, protected_inputs=None):
227226
"""
228227
Return the variables in inputs that are possible candidate for as inputs of
229228
inplace operation.
@@ -234,22 +233,28 @@ def fast_inplace_check(fgraph, inputs):
234233
Inputs Variable that you want to use as inplace destination.
235234
236235
"""
237-
Supervisor = pytensor.compile.function.types.Supervisor
238-
protected_inputs = list(
239-
itertools.chain.from_iterable(
240-
f.protected for f in fgraph._features if isinstance(f, Supervisor)
236+
if protected_inputs is None:
237+
from pytensor.compile.function.types import Supervisor
238+
239+
protected_inputs = set(
240+
itertools.chain.from_iterable(
241+
f.protected for f in fgraph._features if isinstance(f, Supervisor)
242+
)
243+
)
244+
protected_inputs.update(fgraph.outputs)
245+
246+
has_destroyers = fgraph.has_destroyers
247+
248+
return [
249+
inp
250+
# Remove duplicates, while preserving order by using dict.fromkeys
251+
for inp in dict.fromkeys(inputs)
252+
if (
253+
not isinstance(inp, Constant)
254+
and inp not in protected_inputs
255+
and not has_destroyers([inp])
241256
)
242-
)
243-
protected_inputs.extend(fgraph.outputs)
244-
245-
inputs = [
246-
i
247-
for i in inputs
248-
if not isinstance(i, Constant)
249-
and not fgraph.has_destroyers([i])
250-
and i not in protected_inputs
251257
]
252-
return inputs
253258

254259

255260
class DestroyHandler(Bookkeeper):

pytensor/tensor/rewriting/blockwise.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import itertools
2-
3-
from pytensor.compile import Supervisor
41
from pytensor.compile.mode import optdb
52
from pytensor.graph import Constant, node_rewriter
3+
from pytensor.graph.destroyhandler import inplace_candidates
64
from pytensor.graph.replace import vectorize_node
75
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
86
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
@@ -274,25 +272,19 @@ def blockwise_inplace(fgraph, node):
274272
batch_ndim = blockwise_op.batch_ndim(node)
275273
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
276274

277-
protected_inputs = [
278-
f.protected for f in fgraph._features if isinstance(f, Supervisor)
279-
]
280-
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
281-
protected_inputs.extend(fgraph.outputs)
282-
allowed_inplace_inputs = [
283-
idx
284-
for idx, inp in enumerate(node.inputs)
285-
if
286-
(
287-
# Constants would need to be recreated every time if inplaced
288-
not isinstance(inp, Constant)
289-
# We can only inplace on inputs that are not being broadcasted
290-
# As those are reused across iterations of Blockwise
291-
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
292-
# Inputs that are marked as protected or destroyed can't be inplaced
293-
and not fgraph.has_destroyers([inp])
294-
and inp not in protected_inputs
275+
inputs = node.inputs
276+
candidate_inputs = set(
277+
inplace_candidates(
278+
fgraph,
279+
[
280+
inp
281+
for inp in inputs
282+
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
283+
],
295284
)
285+
)
286+
allowed_inplace_inputs = [
287+
i for i, inp in enumerate(inputs) if inp in candidate_inputs
296288
]
297289

298290
if not allowed_inplace_inputs:

0 commit comments

Comments
 (0)