Skip to content

Commit 1cb5289

Browse files
AllenDowneyricardoV94
authored andcommitted
Implement transpose for XTensorVariables
1 parent 5442a2d commit 1cb5289

File tree

5 files changed

+240
-9
lines changed

5 files changed

+240
-9
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.tensor import broadcast_to, join, moveaxis
33
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
44
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
5-
from pytensor.xtensor.shape import Concat, Stack
5+
from pytensor.xtensor.shape import Concat, Stack, Transpose
66

77

88
@register_lower_xtensor
@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
7070
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
7171
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
7272
return [new_out]
73+
74+
75+
@register_lower_xtensor
76+
@node_rewriter(tracks=[Transpose])
77+
def lower_transpose(fgraph, node):
78+
[x] = node.inputs
79+
# Use the final dimensions that were already computed in make_node
80+
out_dims = node.outputs[0].type.dims
81+
in_dims = x.type.dims
82+
83+
# Compute the permutation based on the final dimensions
84+
perm = tuple(in_dims.index(d) for d in out_dims)
85+
x_tensor = tensor_from_xtensor(x)
86+
x_tensor_transposed = x_tensor.transpose(perm)
87+
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
88+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import typing
2+
import warnings
13
from collections.abc import Sequence
4+
from types import EllipsisType
5+
from typing import Literal
26

37
from pytensor.graph import Apply
48
from pytensor.scalar import upcast
@@ -72,6 +76,92 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
7276
return y
7377

7478

79+
class Transpose(XOp):
80+
__props__ = ("dims",)
81+
82+
def __init__(
83+
self,
84+
dims: Sequence[str],
85+
):
86+
super().__init__()
87+
self.dims = tuple(dims)
88+
89+
def make_node(self, x):
90+
x = as_xtensor(x)
91+
92+
transpose_dims = self.dims
93+
x_shape = x.type.shape
94+
x_dims = x.type.dims
95+
if set(transpose_dims) != set(x_dims):
96+
raise ValueError(f"{transpose_dims} must be a permuted list of {x_dims}")
97+
98+
output = xtensor(
99+
dtype=x.type.dtype,
100+
shape=tuple(x_shape[x_dims.index(d)] for d in transpose_dims),
101+
dims=transpose_dims,
102+
)
103+
return Apply(self, [x], [output])
104+
105+
106+
def transpose(
107+
x,
108+
*dims: str | EllipsisType,
109+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
110+
):
111+
"""Transpose dimensions of the tensor.
112+
113+
Parameters
114+
----------
115+
x : XTensorVariable
116+
Input tensor to transpose.
117+
*dims : str
118+
Dimensions to transpose to. Can include ellipsis (...) to represent
119+
remaining dimensions in their original order.
120+
missing_dims : {"raise", "warn", "ignore"}, optional
121+
How to handle dimensions that don't exist in the input tensor:
122+
- "raise": Raise an error if any dimensions don't exist (default)
123+
- "warn": Warn if any dimensions don't exist
124+
- "ignore": Silently ignore any dimensions that don't exist
125+
126+
Returns
127+
-------
128+
XTensorVariable
129+
Transposed tensor with reordered dimensions.
130+
131+
Raises
132+
------
133+
ValueError
134+
If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
135+
"""
136+
# Validate dimensions
137+
x = as_xtensor(x)
138+
x_dims = x.type.dims
139+
invalid_dims = set(dims) - {..., *x_dims}
140+
if invalid_dims:
141+
if missing_dims != "ignore":
142+
msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {x_dims}"
143+
if missing_dims == "raise":
144+
raise ValueError(msg)
145+
else:
146+
warnings.warn(msg)
147+
# Handle missing dimensions if not raising
148+
dims = tuple(d for d in dims if d in x_dims or d is ...)
149+
150+
if dims == () or dims == (...,):
151+
dims = tuple(reversed(x_dims))
152+
elif ... in dims:
153+
if dims.count(...) > 1:
154+
raise ValueError("Ellipsis (...) can only appear once in the dimensions")
155+
# Handle ellipsis expansion
156+
ellipsis_idx = dims.index(...)
157+
pre = dims[:ellipsis_idx]
158+
post = dims[ellipsis_idx + 1 :]
159+
middle = [d for d in x_dims if d not in pre + post]
160+
dims = (*pre, *middle, *post)
161+
162+
return Transpose(typing.cast(tuple[str], dims))(x)
163+
164+
75165
class Concat(XOp):
76166
__props__ = ("dim",)
77167

pytensor/xtensor/type.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import typing
2+
from types import EllipsisType
23

34
from pytensor.compile import (
45
DeepCopyOp,
@@ -23,7 +24,7 @@
2324
XARRAY_AVAILABLE = False
2425

2526
from collections.abc import Sequence
26-
from typing import TypeVar
27+
from typing import Literal, TypeVar
2728

2829
import numpy as np
2930

@@ -438,6 +439,19 @@ def imag(self):
438439
def real(self):
439440
return px.math.real(self)
440441

442+
@property
443+
def T(self):
444+
"""Return the full transpose of the tensor.
445+
446+
This is equivalent to calling transpose() with no arguments.
447+
448+
Returns
449+
-------
450+
XTensorVariable
451+
Fully transposed tensor.
452+
"""
453+
return self.transpose()
454+
441455
# Aggregation
442456
# https://docs.xarray.dev/en/latest/api.html#id6
443457
def all(self, dim=None):
@@ -475,6 +489,37 @@ def cumprod(self, dim=None):
475489

476490
# Reshaping and reorganizing
477491
# https://docs.xarray.dev/en/latest/api.html#id8
492+
def transpose(
493+
self,
494+
*dims: str | EllipsisType,
495+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
496+
):
497+
"""Transpose dimensions of the tensor.
498+
499+
Parameters
500+
----------
501+
*dims : str | Ellipsis
502+
Dimensions to transpose. If empty, performs a full transpose.
503+
Can use ellipsis (...) to represent remaining dimensions.
504+
missing_dims : {"raise", "warn", "ignore"}, default="raise"
505+
How to handle dimensions that don't exist in the tensor:
506+
- "raise": Raise an error if any dimensions don't exist
507+
- "warn": Warn if any dimensions don't exist
508+
- "ignore": Silently ignore any dimensions that don't exist
509+
510+
Returns
511+
-------
512+
XTensorVariable
513+
Transposed tensor with reordered dimensions.
514+
515+
Raises
516+
------
517+
ValueError
518+
If missing_dims="raise" and any dimensions don't exist.
519+
If multiple ellipsis are provided.
520+
"""
521+
return px.shape.transpose(self, *dims, missing_dims=missing_dims)
522+
478523
def stack(self, dim, **dims):
479524
return px.shape.stack(self, dim, **dims)
480525

tests/xtensor/test_shape.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
pytest.importorskip("xarray")
66

7+
import re
78
from itertools import chain, combinations
89

910
import numpy as np
1011
from xarray import concat as xr_concat
1112

12-
from pytensor.xtensor.shape import concat, stack
13+
from pytensor.xtensor.shape import concat, stack, transpose
1314
from pytensor.xtensor.type import xtensor
1415
from tests.xtensor.util import (
1516
xr_arange_like,
@@ -28,6 +29,88 @@ def powerset(iterable, min_group_size=0):
2829
)
2930

3031

32+
def test_transpose():
33+
a, b, c, d, e = "abcde"
34+
35+
x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))
36+
permutations = [
37+
(a, b, c, d, e), # identity
38+
(e, d, c, b, a), # full tranpose
39+
(), # eqivalent to full transpose
40+
(a, b, c, e, d), # swap last two dims
41+
(..., d, c), # equivalent to (a, b, e, d, c)
42+
(b, a, ..., e, d), # equivalent to (b, a, c, d, e)
43+
(c, a, ...), # equivalent to (c, a, b, d, e)
44+
]
45+
outs = [transpose(x, *perm) for perm in permutations]
46+
47+
fn = xr_function([x], outs)
48+
x_test = xr_arange_like(x)
49+
res = fn(x_test)
50+
expected_res = [x_test.transpose(*perm) for perm in permutations]
51+
for outs_i, res_i, expected_res_i in zip(outs, res, expected_res):
52+
xr_assert_allclose(res_i, expected_res_i)
53+
54+
55+
def test_xtensor_variable_transpose():
56+
"""Test the transpose() method of XTensorVariable."""
57+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
58+
59+
# Test basic transpose
60+
out = x.transpose()
61+
fn = xr_function([x], out)
62+
x_test = xr_arange_like(x)
63+
xr_assert_allclose(fn(x_test), x_test.transpose())
64+
65+
# Test transpose with specific dimensions
66+
out = x.transpose("c", "a", "b")
67+
fn = xr_function([x], out)
68+
xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))
69+
70+
# Test transpose with ellipsis
71+
out = x.transpose("c", ...)
72+
fn = xr_function([x], out)
73+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
74+
75+
# Test error cases
76+
with pytest.raises(
77+
ValueError,
78+
match=re.escape(
79+
"Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
80+
),
81+
):
82+
x.transpose("d")
83+
84+
with pytest.raises(
85+
ValueError,
86+
match=re.escape("Ellipsis (...) can only appear once in the dimensions"),
87+
):
88+
x.transpose("a", ..., "b", ...)
89+
90+
# Test missing_dims parameter
91+
# Test ignore
92+
out = x.transpose("c", ..., "d", missing_dims="ignore")
93+
fn = xr_function([x], out)
94+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
95+
96+
# Test warn
97+
with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
98+
out = x.transpose("c", ..., "d", missing_dims="warn")
99+
fn = xr_function([x], out)
100+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
101+
102+
103+
def test_xtensor_variable_T():
104+
"""Test the T property of XTensorVariable."""
105+
# Test T property with 3D tensor
106+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
107+
out = x.T
108+
109+
fn = xr_function([x], out)
110+
x_test = xr_arange_like(x)
111+
xr_assert_allclose(fn(x_test), x_test.T)
112+
113+
31114
def test_stack():
32115
dims = ("a", "b", "c", "d")
33116
x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))

tests/xtensor/test_type.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,12 @@ def test_xtensortype_filter_variable():
3333
assert x.type.filter_variable(y1) is y1
3434

3535
y2 = xtensor("y2", dims=("b", "a"), shape=(3, 2))
36-
expected_y2 = as_xtensor(y2.values.transpose(), dims=("a", "b"))
36+
expected_y2 = y2.transpose()
3737
assert equal_computations([x.type.filter_variable(y2)], [expected_y2])
3838

3939
y3 = xtensor("y3", dims=("b", "a"), shape=(3, None))
4040
expected_y3 = as_xtensor(
41-
specify_shape(
42-
as_xtensor(y3.values.transpose(), dims=("a", "b")).values, (2, 3)
43-
),
44-
dims=("a", "b"),
41+
specify_shape(y3.transpose().values, (2, 3)), dims=("a", "b")
4542
)
4643
assert equal_computations([x.type.filter_variable(y3)], [expected_y3])
4744

@@ -116,7 +113,7 @@ def test_minimum_compile():
116113
from pytensor.compile.mode import Mode
117114

118115
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
119-
y = as_xtensor(x.values.transpose(), dims=("b", "a"))
116+
y = x.transpose()
120117
minimum_mode = Mode(linker="py", optimizer="minimum_compile")
121118
result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode)
122119
np.testing.assert_array_equal(result, np.ones((3, 2)))

0 commit comments

Comments
 (0)