Commit 5c62008

OriolAbril authored and ricardoV94 committed

Implement unstack for XTensorVariables

1 parent 1cb5289 · commit 5c62008

File tree

4 files changed: 155 additions & 4 deletions

pytensor/xtensor/rewriting/shape.py

Lines changed: 21 additions & 2 deletions
@@ -1,8 +1,8 @@
 from pytensor.graph import node_rewriter
-from pytensor.tensor import broadcast_to, join, moveaxis
+from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_lower_xtensor
-from pytensor.xtensor.shape import Concat, Stack, Transpose
+from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack


 @register_lower_xtensor
@@ -29,6 +29,25 @@ def lower_stack(fgraph, node):
     return [new_out]


+@register_lower_xtensor
+@node_rewriter(tracks=[UnStack])
+def lower_unstack(fgraph, node):
+    x = node.inputs[0]
+    unstacked_lengths = node.inputs[1:]
+    axis_to_unstack = x.type.dims.index(node.op.old_dim_name)
+
+    x_tensor = tensor_from_xtensor(x)
+    x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1])
+    final_tensor = x_tensor_transposed.reshape(
+        (*x_tensor_transposed.shape[:-1], *unstacked_lengths)
+    )
+    # Reintroduce any static shape information that was lost during the reshape
+    final_tensor = specify_shape(final_tensor, node.outputs[0].type.shape)
+
+    new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
+    return [new_out]
+
+
 @register_lower_xtensor
 @node_rewriter(tracks=[Concat])
 def lower_concat(fgraph, node):
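The lowering reduces UnStack to plain tensor operations: move the stacked axis to the end, reshape it into the unstacked lengths, then use specify_shape to reassert any static shape information the symbolic reshape discards. A minimal NumPy sketch of the same axis arithmetic (array and dim names are illustrative, not part of the codebase):

import numpy as np

# x has dims ("a", "bc", "d"); "bc" stacks b=3 and c=5 in row-major order
x = np.arange(2 * 15 * 7).reshape(2, 15, 7)
moved = np.moveaxis(x, 1, -1)                        # dims ("a", "d", "bc")
unstacked = moved.reshape(*moved.shape[:-1], 3, 5)   # dims ("a", "d", "b", "c")
assert unstacked.shape == (2, 7, 3, 5)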

pytensor/xtensor/shape.py

Lines changed: 86 additions & 1 deletion
@@ -5,7 +5,9 @@
 from typing import Literal

 from pytensor.graph import Apply
-from pytensor.scalar import upcast
+from pytensor.scalar import discrete_dtypes, upcast
+from pytensor.tensor import as_tensor, get_scalar_constant_value
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor

@@ -76,6 +78,89 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
     return y


+class UnStack(XOp):
+    __props__ = ("old_dim_name", "unstacked_dims")
+
+    def __init__(
+        self,
+        old_dim_name: str,
+        unstacked_dims: tuple[str, ...],
+    ):
+        super().__init__()
+        if old_dim_name in unstacked_dims:
+            raise ValueError(
+                f"Dim to be unstacked {old_dim_name} can't be in {unstacked_dims}"
+            )
+        if not unstacked_dims:
+            raise ValueError("Dims to unstack into can't be empty.")
+        if len(unstacked_dims) == 1:
+            raise ValueError("Only one dimension to unstack into, use rename instead")
+        self.old_dim_name = old_dim_name
+        self.unstacked_dims = unstacked_dims
+
+    def make_node(self, x, *unstacked_length):
+        x = as_xtensor(x)
+        if self.old_dim_name not in x.type.dims:
+            raise ValueError(
+                f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}"
+            )
+        if not set(self.unstacked_dims).isdisjoint(x.type.dims):
+            raise ValueError(
+                f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}"
+            )
+
+        if len(unstacked_length) != len(self.unstacked_dims):
+            raise ValueError(
+                f"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}"
+            )
+        unstacked_lengths = [as_tensor(length, ndim=0) for length in unstacked_length]
+        if not all(length.dtype in discrete_dtypes for length in unstacked_lengths):
+            raise TypeError("Unstacked lengths must be discrete dtypes.")
+
+        if x.type.ndim == 1:
+            batch_dims, batch_shape = (), ()
+        else:
+            batch_dims, batch_shape = zip(
+                *(
+                    (dim, shape)
+                    for dim, shape in zip(x.type.dims, x.type.shape)
+                    if dim != self.old_dim_name
+                )
+            )
+
+        static_unstacked_lengths = [None] * len(unstacked_lengths)
+        for i, length in enumerate(unstacked_lengths):
+            try:
+                static_length = get_scalar_constant_value(length)
+            except NotScalarConstantError:
+                pass
+            else:
+                static_unstacked_lengths[i] = int(static_length)
+
+        output = xtensor(
+            dtype=x.type.dtype,
+            shape=(*batch_shape, *static_unstacked_lengths),
+            dims=(*batch_dims, *self.unstacked_dims),
+        )
+        return Apply(self, [x, *unstacked_lengths], [output])
+
+
+def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
+    if dim is not None:
+        if dims:
+            raise ValueError(
+                "Cannot use both positional dim and keyword dims in unstack"
+            )
+        dims = dim
+
+    y = x
+    for old_dim_name, unstacked_dict in dims.items():
+        y = UnStack(old_dim_name, tuple(unstacked_dict.keys()))(
+            y, *tuple(unstacked_dict.values())
+        )
+    return y
+
+
 class Transpose(XOp):
     __props__ = ("dims",)
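Both call forms accepted by the unstack helper map onto the same UnStack op, with constant lengths propagated into the static output shape by make_node. A small usage sketch under those semantics (names are illustrative):

from pytensor.xtensor.shape import unstack
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "bc"), shape=(2, 15))
y = unstack(x, bc={"b": 3, "c": 5})           # keyword form
z = unstack(x, dim={"bc": {"b": 3, "c": 5}})  # positional-dict form
# Batch dims come first, unstacked dims are appended at the end
assert y.type.dims == ("a", "b", "c")
assert y.type.shape == (2, 3, 5)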

pytensor/xtensor/type.py

Lines changed: 3 additions & 0 deletions
@@ -523,6 +523,9 @@ def transpose(
     def stack(self, dim, **dims):
         return px.shape.stack(self, dim, **dims)

+    def unstack(self, dim, **dims):
+        return px.shape.unstack(self, dim, **dims)
+

 class XTensorConstantSignature(TensorConstantSignature):
     pass
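The method is thin sugar over pytensor.xtensor.shape.unstack. Note that, as written, dim is a required argument of the method, so the dict form is the direct way to unstack a single dimension through it (a sketch, mirroring the functional example above):

x = xtensor("x", dims=("a", "bc"), shape=(2, 15))
y = x.unstack({"bc": {"b": 3, "c": 5}})  # forwards to px.shape.unstack
assert y.type.dims == ("a", "b", "c")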

tests/xtensor/test_shape.py

Lines changed: 45 additions & 1 deletion
@@ -8,9 +8,10 @@
 from itertools import chain, combinations

 import numpy as np
+from xarray import DataArray
 from xarray import concat as xr_concat

-from pytensor.xtensor.shape import concat, stack, transpose
+from pytensor.xtensor.shape import concat, stack, transpose, unstack
 from pytensor.xtensor.type import xtensor
 from tests.xtensor.util import (
     xr_arange_like,
@@ -154,6 +155,49 @@ def test_multiple_stacks():
     xr_assert_allclose(res[0], expected_res)


+def test_unstack_constant_size():
+    x = xtensor("x", dims=("a", "bc", "d"), shape=(2, 3 * 5, 7))
+    y = unstack(x, bc=dict(b=3, c=5))
+    assert y.type.dims == ("a", "d", "b", "c")
+    assert y.type.shape == (2, 7, 3, 5)
+
+    fn = xr_function([x], y)
+
+    x_test = xr_arange_like(x)
+    x_np = x_test.values
+    res = fn(x_test)
+    expected = (
+        DataArray(x_np.reshape(2, 3, 5, 7), dims=("a", "b", "c", "d"))
+        .stack(bc=("b", "c"))
+        .unstack("bc")
+    )
+    xr_assert_allclose(res, expected)
+
+
+def test_unstack_symbolic_size():
+    x = xtensor(dims=("a", "b", "c"))
+    y = stack(x, bc=("b", "c"))
+    y = y / y.sum("bc")
+    z = unstack(y, bc={"b": x.sizes["b"], "c": x.sizes["c"]})
+    x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5)))
+    fn = xr_function([x], z)
+    res = fn(x_test)
+    expected_res = x_test / x_test.sum(["b", "c"])
+    xr_assert_allclose(res, expected_res)
+
+
+def test_stack_unstack():
+    x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7))
+    stack_x = stack(x, bd=("b", "d"))
+    unstack_x = unstack(stack_x, bd=dict(b=3, d=7))
+
+    x_test = xr_arange_like(x)
+    fn = xr_function([x], unstack_x)
+    res = fn(x_test)
+    expected_res = x_test.transpose("a", "c", "b", "d")
+    xr_assert_allclose(res, expected_res)
+
+
 @pytest.mark.parametrize("dim", ("a", "b", "new"))
 def test_concat(dim):
     rng = np.random.default_rng(sum(map(ord, dim)))
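For reference, the transpose expected in test_stack_unstack mirrors plain xarray behaviour: a stack/unstack roundtrip leaves the untouched dims in front and appends the unstacked dims at the end. A quick xarray-only sketch of that ordering:

import numpy as np
from xarray import DataArray

x = DataArray(np.arange(2 * 3 * 5 * 7).reshape(2, 3, 5, 7), dims=("a", "b", "c", "d"))
roundtrip = x.stack(bd=("b", "d")).unstack("bd")
# "b" and "d" end up after the remaining dims, i.e. x.transpose("a", "c", "b", "d")
assert roundtrip.dims == ("a", "c", "b", "d")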
