Skip to content

Commit fdc5153

Browse files
committed
Do not skip validation between consecutive Elemwise inplace replacements
1 parent 7886cf8 commit fdc5153

File tree

4 files changed

+202
-259
lines changed

4 files changed

+202
-259
lines changed

pytensor/graph/destroyhandler.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import itertools
88
from collections import deque
99

10-
import pytensor
1110
from pytensor.configdefaults import config
1211
from pytensor.graph.basic import Constant
1312
from pytensor.graph.features import AlreadyThere, Bookkeeper
@@ -223,7 +222,7 @@ def _build_droot_impact(destroy_handler):
223222
return droot, impact, root_destroyer
224223

225224

226-
def fast_inplace_check(fgraph, inputs):
225+
def inplace_candidates(fgraph, inputs, protected_inputs=None):
227226
"""
228227
Return the variables in inputs that are possible candidate for as inputs of
229228
inplace operation.
@@ -234,22 +233,40 @@ def fast_inplace_check(fgraph, inputs):
234233
Inputs Variable that you want to use as inplace destination.
235234
236235
"""
237-
Supervisor = pytensor.compile.function.types.Supervisor
238-
protected_inputs = list(
239-
itertools.chain.from_iterable(
240-
f.protected for f in fgraph._features if isinstance(f, Supervisor)
236+
237+
def view_is_protected(view_dict, variable, protected):
238+
if variable in protected:
239+
return True
240+
try:
241+
while True:
242+
variable = view_dict[variable]
243+
if variable in protected:
244+
return True
245+
except KeyError:
246+
return False
247+
248+
if protected_inputs is None:
249+
from pytensor.compile.function.types import Supervisor
250+
251+
protected_inputs = set(
252+
itertools.chain.from_iterable(
253+
f.protected for f in fgraph._features if isinstance(f, Supervisor)
254+
)
255+
)
256+
protected_inputs.update(fgraph.outputs)
257+
258+
has_destroyers = fgraph.has_destroyers
259+
view_i = fgraph.destroy_handler.view_i
260+
261+
return [
262+
inp
263+
for inp in inputs
264+
if (
265+
not isinstance(inp, Constant)
266+
and not view_is_protected(view_i, inp, protected_inputs)
267+
and not has_destroyers([inp])
241268
)
242-
)
243-
protected_inputs.extend(fgraph.outputs)
244-
245-
inputs = [
246-
i
247-
for i in inputs
248-
if not isinstance(i, Constant)
249-
and not fgraph.has_destroyers([i])
250-
and i not in protected_inputs
251269
]
252-
return inputs
253270

254271

255272
class DestroyHandler(Bookkeeper):

pytensor/tensor/rewriting/blockwise.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import itertools
2-
3-
from pytensor.compile import Supervisor
41
from pytensor.compile.mode import optdb
52
from pytensor.graph import Constant, node_rewriter
3+
from pytensor.graph.destroyhandler import inplace_candidates
64
from pytensor.graph.replace import vectorize_node
75
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
86
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
@@ -270,28 +268,20 @@ def blockwise_inplace(fgraph, node):
270268
# Op already has inplace
271269
return
272270

273-
# Find out valid inputs for inplacing
271+
# Find out valid inputs for in-placing
274272
batch_ndim = blockwise_op.batch_ndim(node)
275273
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
274+
inputs = node.inputs
276275

277-
protected_inputs = [
278-
f.protected for f in fgraph._features if isinstance(f, Supervisor)
279-
]
280-
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
281-
protected_inputs.extend(fgraph.outputs)
282276
allowed_inplace_inputs = [
283-
idx
284-
for idx, inp in enumerate(node.inputs)
285-
if
286-
(
287-
# Constants would need to be recreated every time if inplaced
288-
not isinstance(inp, Constant)
289-
# We can only inplace on inputs that are not being broadcasted
290-
# As those are reused across iterations of Blockwise
291-
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
292-
# Inputs that are marked as protected or destroyed can't be inplaced
293-
and not fgraph.has_destroyers([inp])
294-
and inp not in protected_inputs
277+
inputs.index(inp)
278+
for inp in inplace_candidates(
279+
fgraph,
280+
[
281+
inp
282+
for inp in inputs
283+
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
284+
],
295285
)
296286
]
297287

0 commit comments

Comments
 (0)