Description
We have some other rewrites that will push Alloc below Elemwise, so that we don't compute on repeated inputs, but this won't happen if there's an expand_dims in the way. As of now, the following graph does not get lifted:
import pytensor
import pytensor.tensor as pt
x = pt.vector("x", shape=(3,))
y = pt.alloc(x, 1000, 3)[None]
out = pt.exp(y)
pytensor.function([x], out).dprint(print_type=True)
# Exp [id A] <Tensor3(float64, shape=(1, 1000, 3))> 1
# └─ Alloc [id B] <Tensor3(float64, shape=(1, 1000, 3))> 0
# ├─ x [id C] <Vector(float64, shape=(3,))>
# ├─ 1 [id D] <Scalar(int8, shape=())>
# ├─ 1000 [id E] <Scalar(int16, shape=())>
# └─ 3 [id F] <Scalar(int8, shape=())>
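For comparison, here is a hand-written sketch of the graph we would ideally end up with for this example, with the expand_dims lifted above the Alloc and the Alloc pushed below the Elemwise so that Exp only runs on the 3-element input (illustrative only, not what the current rewrites produce):
import pytensor.tensor as pt

x = pt.vector("x", shape=(3,))
desired = pt.alloc(pt.exp(x), 1, 1000, 3)  # Exp on 3 values, then broadcast to (1, 1000, 3)
desired.dprint(print_type=True)
# Alloc is applied after Exp, so the elementwise op is computed once per input element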
There is actually an "uncanonicalize" rewrite that allows "lifting" expand_dims above some Alloc, which would have helped here.
pytensor/pytensor/tensor/rewriting/uncanonicalize.py, lines 125 to 150 in e6e6d69:
@register_uncanonicalize
@node_rewriter([DimShuffle])
def local_dimshuffle_alloc(fgraph, node):
    """
    If an alloc is inside a dimshuffle which only adds dimension to the left,
    scrap the dimshuffle and adds 1 into the alloc

    dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2)
    """
    if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
        input_ = node.inputs[0]
        if isinstance(input_.owner.op, Alloc):
            # check if it only adds dimension to the left
            new_order = node.op.new_order
            expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple(
                range(input_.ndim)
            )
            if new_order != expected_new_order:
                return False

            # count numbers of 'x'
            nb_new_dims = len(new_order) - input_.ndim
            new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:])
            return [alloc(input_.owner.inputs[0], *new_shape_input)]
    return False
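To make the intended effect concrete, here is the transformation written out by hand (a sketch that builds both forms directly rather than invoking the rewrite): an expand_dims that only adds leading dimensions gets folded into the Alloc's shape as extra 1s.
import pytensor.tensor as pt

v = pt.vector("v", shape=(2,))
before = pt.alloc(v, 3, 2)[None]  # DimShuffle{x,0,1}(Alloc(v, 3, 2))
after = pt.alloc(v, 1, 3, 2)      # equivalent graph with the new dim folded into the Alloc
assert before.type == after.type  # both are tensors with static shape (1, 3, 2)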
However, this is at odds with the opposite canonical rewrite local_alloc_sink_dimshuffle:
pytensor/pytensor/tensor/rewriting/basic.py, lines 462 to 467 in d62f4b1:
@register_specialize
@register_stabilize
@register_canonicalize
@node_rewriter([Alloc])
def local_alloc_sink_dimshuffle(fgraph, node):
    r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s."""
It's not obvious to me why the latter should be given preference. In general, it seems like we can always lift expand_dims towards the inputs of the function (as it does not affect the number of operations) and sink alloc towards the outputs. But here we are not allowing the "swap" when an expand_dims meets an alloc.