Skip to content

Commit fd14b8d

Browse files
AllenDowneyricardoV94
authored andcommitted
Implement transpose for XTensorVariables
1 parent de5aead commit fd14b8d

File tree

4 files changed

+238
-3
lines changed

4 files changed

+238
-3
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.tensor import broadcast_to, join, moveaxis
33
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
44
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
5-
from pytensor.xtensor.shape import Concat, Stack
5+
from pytensor.xtensor.shape import Concat, Stack, Transpose
66

77

88
@register_lower_xtensor
@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
7070
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
7171
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
7272
return [new_out]
73+
74+
75+
@register_lower_xtensor
76+
@node_rewriter(tracks=[Transpose])
77+
def lower_transpose(fgraph, node):
78+
[x] = node.inputs
79+
# Use the final dimensions that were already computed in make_node
80+
out_dims = node.outputs[0].type.dims
81+
in_dims = x.type.dims
82+
83+
# Compute the permutation based on the final dimensions
84+
perm = tuple(in_dims.index(d) for d in out_dims)
85+
x_tensor = tensor_from_xtensor(x)
86+
x_tensor_transposed = x_tensor.transpose(perm)
87+
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
88+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import warnings
12
from collections.abc import Sequence
3+
from types import EllipsisType
4+
from typing import Literal
25

36
from pytensor.graph import Apply
47
from pytensor.scalar import upcast
@@ -72,6 +75,97 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
7275
return y
7376

7477

78+
class Transpose(XOp):
79+
__props__ = ("dims",)
80+
81+
def __init__(
82+
self,
83+
dims: tuple[str | EllipsisType, ...],
84+
):
85+
super().__init__()
86+
if dims.count(...) > 1:
87+
raise ValueError("an index can only have a single ellipsis ('...')")
88+
self.dims = dims
89+
90+
def make_node(self, x):
91+
x = as_xtensor(x)
92+
93+
transpose_dims = self.dims
94+
x_dims = x.type.dims
95+
96+
if transpose_dims == () or transpose_dims == (...,):
97+
out_dims = tuple(reversed(x_dims))
98+
elif ... in transpose_dims:
99+
# Handle ellipsis expansion
100+
ellipsis_idx = transpose_dims.index(...)
101+
pre = transpose_dims[:ellipsis_idx]
102+
post = transpose_dims[ellipsis_idx + 1 :]
103+
middle = [d for d in x_dims if d not in pre + post]
104+
out_dims = (*pre, *middle, *post)
105+
if set(out_dims) != set(x_dims):
106+
raise ValueError(f"{out_dims} must be a permuted list of {x_dims}")
107+
else:
108+
out_dims = transpose_dims
109+
if set(out_dims) != set(x_dims):
110+
raise ValueError(
111+
f"{out_dims} must be a permuted list of {x_dims}, unless `...` is included"
112+
)
113+
114+
output = xtensor(
115+
dtype=x.type.dtype,
116+
shape=tuple(x.type.shape[x.type.dims.index(d)] for d in out_dims),
117+
dims=out_dims,
118+
)
119+
return Apply(self, [x], [output])
120+
121+
122+
def transpose(
123+
x,
124+
*dims: str | EllipsisType,
125+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
126+
):
127+
"""Transpose dimensions of the tensor.
128+
129+
Parameters
130+
----------
131+
x : XTensorVariable
132+
Input tensor to transpose.
133+
*dims : str
134+
Dimensions to transpose to. Can include ellipsis (...) to represent
135+
remaining dimensions in their original order.
136+
missing_dims : {"raise", "warn", "ignore"}, optional
137+
How to handle dimensions that don't exist in the input tensor:
138+
- "raise": Raise an error if any dimensions don't exist (default)
139+
- "warn": Warn if any dimensions don't exist
140+
- "ignore": Silently ignore any dimensions that don't exist
141+
142+
Returns
143+
-------
144+
XTensorVariable
145+
Transposed tensor with reordered dimensions.
146+
147+
Raises
148+
------
149+
ValueError
150+
If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
151+
"""
152+
# Validate dimensions
153+
x = as_xtensor(x)
154+
all_dims = x.type.dims
155+
invalid_dims = set(dims) - {..., *all_dims}
156+
if invalid_dims:
157+
if missing_dims != "ignore":
158+
msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {all_dims}"
159+
if missing_dims == "raise":
160+
raise ValueError(msg)
161+
else:
162+
warnings.warn(msg)
163+
# Handle missing dimensions if not raising
164+
dims = tuple(d for d in dims if d in all_dims or d is ...)
165+
166+
return Transpose(dims)(x)
167+
168+
75169
class Concat(XOp):
76170
__props__ = ("dim",)
77171

pytensor/xtensor/type.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import typing
2+
from types import EllipsisType
23

34
from pytensor.compile import (
45
DeepCopyOp,
@@ -18,7 +19,7 @@
1819
XARRAY_AVAILABLE = False
1920

2021
from collections.abc import Sequence
21-
from typing import TypeVar
22+
from typing import Literal, TypeVar
2223

2324
import numpy as np
2425

@@ -362,6 +363,19 @@ def imag(self):
362363
def real(self):
363364
return px.math.real(self)
364365

366+
@property
367+
def T(self):
368+
"""Return the full transpose of the tensor.
369+
370+
This is equivalent to calling transpose() with no arguments.
371+
372+
Returns
373+
-------
374+
XTensorVariable
375+
Fully transposed tensor.
376+
"""
377+
return self.transpose()
378+
365379
# Aggregation
366380
# https://docs.xarray.dev/en/latest/api.html#id6
367381
def all(self, dim):
@@ -399,6 +413,37 @@ def cumprod(self, dim):
399413

400414
# Reshaping and reorganizing
401415
# https://docs.xarray.dev/en/latest/api.html#id8
416+
def transpose(
417+
self,
418+
*dims: str | EllipsisType,
419+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
420+
):
421+
"""Transpose dimensions of the tensor.
422+
423+
Parameters
424+
----------
425+
*dims : str | Ellipsis
426+
Dimensions to transpose. If empty, performs a full transpose.
427+
Can use ellipsis (...) to represent remaining dimensions.
428+
missing_dims : {"raise", "warn", "ignore"}, default="raise"
429+
How to handle dimensions that don't exist in the tensor:
430+
- "raise": Raise an error if any dimensions don't exist
431+
- "warn": Warn if any dimensions don't exist
432+
- "ignore": Silently ignore any dimensions that don't exist
433+
434+
Returns
435+
-------
436+
XTensorVariable
437+
Transposed tensor with reordered dimensions.
438+
439+
Raises
440+
------
441+
ValueError
442+
If missing_dims="raise" and any dimensions don't exist.
443+
If multiple ellipsis are provided.
444+
"""
445+
return px.shape.transpose(self, *dims, missing_dims=missing_dims)
446+
402447
def stack(self, dim, **dims):
403448
return px.shape.stack(self, dim, **dims)
404449

tests/xtensor/test_shape.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
pytest.importorskip("xarray")
66

7+
import re
78
from itertools import chain, combinations
89

910
import numpy as np
1011
from xarray import concat as xr_concat
1112

12-
from pytensor.xtensor.shape import concat, stack
13+
from pytensor.xtensor.shape import concat, stack, transpose
1314
from pytensor.xtensor.type import xtensor
1415
from tests.xtensor.util import (
1516
xr_arange_like,
@@ -28,6 +29,85 @@ def powerset(iterable, min_group_size=0):
2829
)
2930

3031

32+
def test_transpose():
33+
a, b, c, d, e = "abcde"
34+
35+
x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))
36+
permutations = [
37+
(a, b, c, d, e), # identity
38+
(e, d, c, b, a), # full tranpose
39+
(), # eqivalent to full transpose
40+
(a, b, c, e, d), # swap last two dims
41+
(..., d, c), # equivalent to (a, b, e, d, c)
42+
(b, a, ..., e, d), # equivalent to (b, a, c, d, e)
43+
(c, a, ...), # equivalent to (c, a, b, d, e)
44+
]
45+
outs = [transpose(x, *perm) for perm in permutations]
46+
47+
fn = xr_function([x], outs)
48+
x_test = xr_arange_like(x)
49+
res = fn(x_test)
50+
expected_res = [x_test.transpose(*perm) for perm in permutations]
51+
for outs_i, res_i, expected_res_i in zip(outs, res, expected_res):
52+
xr_assert_allclose(res_i, expected_res_i)
53+
54+
55+
def test_xtensor_variable_transpose():
56+
"""Test the transpose() method of XTensorVariable."""
57+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
58+
59+
# Test basic transpose
60+
out = x.transpose()
61+
fn = xr_function([x], out)
62+
x_test = xr_arange_like(x)
63+
xr_assert_allclose(fn(x_test), x_test.transpose())
64+
65+
# Test transpose with specific dimensions
66+
out = x.transpose("c", "a", "b")
67+
fn = xr_function([x], out)
68+
xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))
69+
70+
# Test transpose with ellipsis
71+
out = x.transpose("c", ...)
72+
fn = xr_function([x], out)
73+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
74+
75+
# Test error cases
76+
with pytest.raises(
77+
ValueError,
78+
match=re.escape(
79+
"Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
80+
),
81+
):
82+
x.transpose("d")
83+
84+
with pytest.raises(ValueError, match="an index can only have a single ellipsis"):
85+
x.transpose("a", ..., "b", ...)
86+
87+
# Test missing_dims parameter
88+
# Test ignore
89+
out = x.transpose("c", ..., "d", missing_dims="ignore")
90+
fn = xr_function([x], out)
91+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
92+
93+
# Test warn
94+
with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
95+
out = x.transpose("c", ..., "d", missing_dims="warn")
96+
fn = xr_function([x], out)
97+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
98+
99+
100+
def test_xtensor_variable_T():
101+
"""Test the T property of XTensorVariable."""
102+
# Test T property with 3D tensor
103+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
104+
out = x.T
105+
106+
fn = xr_function([x], out)
107+
x_test = xr_arange_like(x)
108+
xr_assert_allclose(fn(x_test), x_test.T)
109+
110+
31111
def test_stack():
32112
dims = ("a", "b", "c", "d")
33113
x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))

0 commit comments

Comments
 (0)