Commit de5aead: Implement concat for XTensorVariables
Parent: 095eef4

4 files changed: 164 additions, 3 deletions

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
     linalg,
     special,
 )
+from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
     XTensorType,
     as_xtensor,

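With this re-export, the new helper is intended to be importable from the package root alongside the names already exposed there (a one-line sketch, assuming no other import-time changes):

from pytensor.xtensor import concat
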
pytensor/xtensor/rewriting/shape.py

Lines changed: 45 additions & 2 deletions
@@ -1,8 +1,8 @@
 from pytensor.graph import node_rewriter
-from pytensor.tensor import moveaxis
+from pytensor.tensor import broadcast_to, join, moveaxis
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_lower_xtensor
-from pytensor.xtensor.shape import Stack
+from pytensor.xtensor.shape import Concat, Stack
 
 
 @register_lower_xtensor
@@ -27,3 +27,46 @@ def lower_stack(fgraph, node):
 
     new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
     return [new_out]
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[Concat])
+def lower_concat(fgraph, node):
+    out_dims = node.outputs[0].type.dims
+    concat_dim = node.op.dim
+    concat_axis = out_dims.index(concat_dim)
+
+    # Convert input XTensors to Tensors and align batch dimensions
+    tensor_inputs = []
+    for inp in node.inputs:
+        inp_dims = inp.type.dims
+        order = [
+            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
+            for out_dim in out_dims
+        ]
+        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
+        tensor_inputs.append(tensor_inp)
+
+    # Broadcast non-concatenated dimensions of each input
+    non_concat_shape = [None] * len(out_dims)
+    for tensor_inp in tensor_inputs:
+        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
+        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
+        for i, (bcast, sh) in enumerate(
+            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
+        ):
+            if bcast or i == concat_axis or non_concat_shape[i] is not None:
+                continue
+            non_concat_shape[i] = sh
+
+    assert non_concat_shape.count(None) == 1
+
+    bcast_tensor_inputs = []
+    for tensor_inp in tensor_inputs:
+        # We modify the concat_axis in place, as we don't need the list anywhere else
+        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
+        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
+
+    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
+    new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
+    return [new_out]

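The lowering strategy is: transpose each input to the output dim order (inserting broadcastable "x" axes for dims it lacks), broadcast the non-concatenated dims to a common shape, and join on the concat axis. A rough NumPy analogue of the graph this rewrite builds, with made-up example shapes (illustrative only, not part of the commit):

import numpy as np

# Concatenate along "a" (axis 0), with output dims ("a", "b")
x = np.arange(6).reshape(2, 3)   # an input with dims ("a", "b")
y = np.arange(3.0)               # an input with dims ("b",) only -- it has no "a"

# dimshuffle step: insert a length-1 axis where "a" is missing
y_aligned = y[None, :]           # shape (1, 3)

# broadcast step: non-concat dims already agree here ("b" has length 3 in both);
# along the concat axis each input keeps its own length (2 and 1)
joined = np.concatenate([x, y_aligned], axis=0)   # shape (3, 3), the analogue of join(concat_axis, ...)
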
pytensor/xtensor/shape.py

Lines changed: 53 additions & 0 deletions
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 
 from pytensor.graph import Apply
+from pytensor.scalar import upcast
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor
 
@@ -69,3 +70,55 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
             )
         y = Stack(new_dim_name, tuple(stacked_dims))(y)
     return y
+
+
+class Concat(XOp):
+    __props__ = ("dim",)
+
+    def __init__(self, dim: str):
+        self.dim = dim
+        super().__init__()
+
+    def make_node(self, *inputs):
+        inputs = [as_xtensor(inp) for inp in inputs]
+        concat_dim = self.dim
+
+        dims_and_shape: dict[str, int | None] = {}
+        for inp in inputs:
+            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
+                if dim not in dims_and_shape:
+                    dims_and_shape[dim] = dim_length
+                else:
+                    if dim == concat_dim:
+                        if dim_length is None:
+                            dims_and_shape[dim] = None
+                        elif dims_and_shape[dim] is not None:
+                            dims_and_shape[dim] += dim_length
+                    elif dim_length is not None:
+                        # Check for conflicting non-concatenated shapes
+                        if (dims_and_shape[dim] is not None) and (
+                            dims_and_shape[dim] != dim_length
+                        ):
+                            raise ValueError(
+                                f"Non-concatenated dimension {dim} has conflicting shapes"
+                            )
+                        # Keep the non-None shape
+                        dims_and_shape[dim] = dim_length
+
+        if concat_dim not in dims_and_shape:
+            # It's a new dim that should be located at the start
+            dims_and_shape = {concat_dim: len(inputs)} | dims_and_shape
+        elif dims_and_shape[concat_dim] is not None:
+            # We need to add +1 for every input that doesn't have this dimension
+            for inp in inputs:
+                if concat_dim not in inp.type.dims:
+                    dims_and_shape[concat_dim] += 1
+
+        dims, shape = zip(*dims_and_shape.items())
+        dtype = upcast(*[x.type.dtype for x in inputs])
+        output = xtensor(dtype=dtype, dims=dims, shape=shape)
+        return Apply(self, inputs, [output])
+
+
+def concat(xtensors, dim: str):
+    return Concat(dim=dim)(*xtensors)

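A small usage sketch of the resulting helper, showing the static dim/shape inference that make_node performs; the expected dims and shapes in the comments follow the logic above rather than the commit's own tests:

from pytensor.xtensor import concat
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "a"), shape=(3, 2))

# Concatenating along an existing dim adds up its static lengths: 2 + 2 = 4
out = concat([x, y], dim="a")
# expected: out.type.dims == ("a", "b"), out.type.shape == (4, 3)

# Concatenating along a dim no input has prepends it, with length equal to the number of inputs
stacked = concat([x, y], dim="new")
# expected: stacked.type.dims == ("new", "a", "b"), stacked.type.shape == (2, 2, 3)
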
tests/xtensor/test_shape.py

Lines changed: 65 additions & 1 deletion
@@ -6,12 +6,16 @@
 
 from itertools import chain, combinations
 
-from pytensor.xtensor.shape import stack
+import numpy as np
+from xarray import concat as xr_concat
+
+from pytensor.xtensor.shape import concat, stack
 from pytensor.xtensor.type import xtensor
 from tests.xtensor.util import (
     xr_arange_like,
     xr_assert_allclose,
     xr_function,
+    xr_random_like,
 )
 
 
@@ -65,3 +69,63 @@ def test_multiple_stacks():
     res = fn(x_test)
     expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
     xr_assert_allclose(res[0], expected_res)
+
+
+@pytest.mark.parametrize("dim", ("a", "b", "new"))
+def test_concat(dim):
+    rng = np.random.default_rng(sum(map(ord, dim)))
+
+    x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3))
+    x2 = xtensor("x2", dims=("b", "a"), shape=(3, 2))
+
+    x3_shape0 = 4 if dim == "a" else 2
+    x3_shape1 = 5 if dim == "b" else 3
+    x3 = xtensor("x3", dims=("a", "b"), shape=(x3_shape0, x3_shape1))
+
+    out = concat([x1, x2, x3], dim=dim)
+
+    fn = xr_function([x1, x2, x3], out)
+    x1_test = xr_random_like(x1, rng)
+    x2_test = xr_random_like(x2, rng)
+    x3_test = xr_random_like(x3, rng)
+
+    res = fn(x1_test, x2_test, x3_test)
+    expected_res = xr_concat([x1_test, x2_test, x3_test], dim=dim)
+    xr_assert_allclose(res, expected_res)
+
+
+@pytest.mark.parametrize("dim", ("a", "b", "c", "d", "new"))
+def test_concat_with_broadcast(dim):
+    rng = np.random.default_rng(sum(map(ord, dim)) + 1)
+
+    x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3))
+    x2 = xtensor("x2", dims=("b", "c"), shape=(3, 5))
+    x3 = xtensor("x3", dims=("c", "d"), shape=(5, 7))
+    x4 = xtensor("x4", dims=(), shape=())
+
+    out = concat([x1, x2, x3, x4], dim=dim)
+
+    fn = xr_function([x1, x2, x3, x4], out)
+
+    x1_test = xr_random_like(x1, rng)
+    x2_test = xr_random_like(x2, rng)
+    x3_test = xr_random_like(x3, rng)
+    x4_test = xr_random_like(x4, rng)
+    res = fn(x1_test, x2_test, x3_test, x4_test)
+    expected_res = xr_concat([x1_test, x2_test, x3_test, x4_test], dim=dim)
+    xr_assert_allclose(res, expected_res)
+
+
+def test_concat_scalar():
+    x1 = xtensor("x1", dims=(), shape=())
+    x2 = xtensor("x2", dims=(), shape=())
+
+    out = concat([x1, x2], dim="new_dim")
+
+    fn = xr_function([x1, x2], out)
+
+    x1_test = xr_random_like(x1)
+    x2_test = xr_random_like(x2)
+    res = fn(x1_test, x2_test)
+    expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
+    xr_assert_allclose(res, expected_res)
