Skip to content

Commit 480ef20

Browse files
committed
Less myopic check for protected inputs
1 parent 38231c7 commit 480ef20

File tree

1 file changed

+31
-10
lines changed

1 file changed

+31
-10
lines changed

pytensor/graph/destroyhandler.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -244,17 +244,38 @@ def inplace_candidates(fgraph, inputs, protected_inputs=None):
244244
protected_inputs.update(fgraph.outputs)
245245

246246
has_destroyers = fgraph.has_destroyers
247+
view_i = fgraph.destroy_handler.view_i
248+
candidate_roots = {}
249+
candidate_inputs = []
250+
for inp in inputs:
251+
if isinstance(inp, Constant):
252+
# Can't inplace on constants.
253+
continue
254+
255+
# Find the root of the view chain, and while traversing check if it passes on any protected inputs.
256+
view_of_protected = False
257+
root = inp
258+
try:
259+
while True:
260+
if root in protected_inputs:
261+
view_of_protected = True
262+
root = view_i[root]
263+
except KeyError:
264+
pass
247265

248-
return [
249-
inp
250-
# Remove duplicates, while preserving order by using dict.fromkeys
251-
for inp in dict.fromkeys(inputs)
252-
if (
253-
not isinstance(inp, Constant)
254-
and inp not in protected_inputs
255-
and not has_destroyers([inp])
256-
)
257-
]
266+
if root in candidate_roots:
267+
# Another input views on the same root, we can't destroy either
268+
if (invalid_candidate := candidate_roots[root]) is not None:
269+
# Invalidate the previous candidate
270+
candidate_inputs.remove(invalid_candidate)
271+
candidate_roots[root] = None
272+
elif not view_of_protected and not has_destroyers([inp]):
273+
candidate_inputs.append(inp)
274+
candidate_roots[root] = inp
275+
else:
276+
candidate_roots[root] = None
277+
278+
return candidate_inputs
258279

259280

260281
class DestroyHandler(Bookkeeper):

0 commit comments

Comments
 (0)