Skip to content

Do not skip validation between consecutive Elemwise inplace replacements #1494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions pytensor/graph/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import itertools
from collections import deque

import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.graph.features import AlreadyThere, Bookkeeper
Expand Down Expand Up @@ -223,7 +222,7 @@ def _build_droot_impact(destroy_handler):
return droot, impact, root_destroyer


def fast_inplace_check(fgraph, inputs):
def inplace_candidates(fgraph, inputs, protected_inputs=None):
"""
Return the variables in inputs that are possible candidates to be used as
inputs of an inplace operation.
Expand All @@ -234,22 +233,49 @@ def fast_inplace_check(fgraph, inputs):
Input variables that you want to use as inplace destinations.

"""
Supervisor = pytensor.compile.function.types.Supervisor
protected_inputs = list(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
if protected_inputs is None:
from pytensor.compile.function.types import Supervisor

protected_inputs = set(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
)
)
)
protected_inputs.extend(fgraph.outputs)

inputs = [
i
for i in inputs
if not isinstance(i, Constant)
and not fgraph.has_destroyers([i])
and i not in protected_inputs
]
return inputs
protected_inputs.update(fgraph.outputs)

has_destroyers = fgraph.has_destroyers
view_i = fgraph.destroy_handler.view_i
candidate_roots = {}
candidate_inputs = []
for inp in inputs:
if isinstance(inp, Constant):
# Can't inplace on constants.
continue

# Find the root of the view chain and, while traversing it, check whether it passes through any protected input.
view_of_protected = False
root = inp
try:
while True:
if root in protected_inputs:
view_of_protected = True
root = view_i[root]
except KeyError:
pass

if root in candidate_roots:
# Another input shares the same view root, so we can't destroy either of them
if (invalid_candidate := candidate_roots[root]) is not None:
# Invalidate the previous candidate
candidate_inputs.remove(invalid_candidate)
candidate_roots[root] = None
elif not view_of_protected and not has_destroyers([inp]):
candidate_inputs.append(inp)
candidate_roots[root] = inp
else:
candidate_roots[root] = None

return candidate_inputs


class DestroyHandler(Bookkeeper):
Expand Down
119 changes: 60 additions & 59 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import itertools

from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
Expand All @@ -13,6 +11,7 @@
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
Expand Down Expand Up @@ -262,74 +261,76 @@
return [x[(*none_slices, *core_idxs)]]


@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
blockwise_op = node.op
class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
op = Blockwise

if blockwise_op.destroy_map:
# Op already has inplace
return
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
blockwise_op = node.op
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
inputs = node.inputs

# Find out valid inputs for inplacing
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]

protected_inputs = [
f.protected for f in fgraph._features if isinstance(f, Supervisor)
]
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
protected_inputs.extend(fgraph.outputs)
allowed_inplace_inputs = [
idx
for idx, inp in enumerate(node.inputs)
if
(
# Constants would need to be recreated every time if inplaced
not isinstance(inp, Constant)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and not fgraph.has_destroyers([inp])
and inp not in protected_inputs
candidate_inputs = set(
inplace_candidates(
fgraph,
[
inp
for inp in inputs
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
],
)
)
]

if not allowed_inplace_inputs:
return None
allowed_inplace_inputs = [
i for i, inp in enumerate(inputs) if inp in candidate_inputs
]
destroy_map = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
).destroy_map

if not destroy_map:
return []

outputs = node.outputs
return [
((out_idx, outputs[out_idx]), (inp_idx, inputs[inp_idx]))
for out_idx, inp_idxs in destroy_map.items()
for inp_idx in inp_idxs
]

inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
)
def create_inplace_node(self, node, inplace_pattern):
blockwise_op = node.op
allowed_inplace_inputs = tuple(v[0] for v in inplace_pattern.values())
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
)

if not inplace_core_op.destroy_map:
return None
if not inplace_core_op.destroy_map:
return node

Check warning on line 309 in pytensor/tensor/rewriting/blockwise.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/blockwise.py#L309

Added line #L309 was not covered by tests

# Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in inplace_core_op.destroy_map.values():
for destroyed_input in destroyed_inputs:
if destroyed_input not in allowed_inplace_inputs:
raise ValueError(
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
)
# Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in inplace_core_op.destroy_map.values():
for destroyed_input in destroyed_inputs:
if destroyed_input not in allowed_inplace_inputs:
raise ValueError(

Check warning on line 315 in pytensor/tensor/rewriting/blockwise.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/blockwise.py#L315

Added line #L315 was not covered by tests
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
)

# Recreate core_op with inplace
inplace_blockwise_op = Blockwise(
core_op=inplace_core_op,
signature=blockwise_op.signature,
name=blockwise_op.name,
gufunc_spec=blockwise_op.gufunc_spec,
destroy_map=inplace_core_op.destroy_map,
)
# Recreate core_op with inplace
inplace_blockwise_op = type(blockwise_op)(
core_op=inplace_core_op,
signature=blockwise_op.signature,
name=blockwise_op.name,
gufunc_spec=blockwise_op.gufunc_spec,
destroy_map=inplace_core_op.destroy_map,
)

out = inplace_blockwise_op.make_node(*node.inputs).outputs
copy_stack_trace(node.outputs, out)
return out
return inplace_blockwise_op.make_node(*node.inputs)


optdb.register(
"blockwise_inplace",
in2out(blockwise_inplace),
InplaceBlockwiseOptimizer(),
"fast_run",
"inplace",
position=50.1,
Expand Down
Loading