Skip to content

Commit 86d58b9

Browse files
AllenDowneyricardoV94
authored andcommitted
Implement squeeze for XTensorVariables
1 parent a60d0ec commit 86d58b9

File tree

4 files changed

+240
-3
lines changed

4 files changed

+240
-3
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
11
from pytensor.graph import node_rewriter
2-
from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
2+
from pytensor.tensor import (
3+
broadcast_to,
4+
join,
5+
moveaxis,
6+
specify_shape,
7+
squeeze,
8+
)
39
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
410
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
5-
from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack
11+
from pytensor.xtensor.shape import (
12+
Concat,
13+
Squeeze,
14+
Stack,
15+
Transpose,
16+
UnStack,
17+
)
618

719

820
@register_lower_xtensor
@@ -105,3 +117,18 @@ def lower_transpose(fgraph, node):
105117
x_tensor_transposed = x_tensor.transpose(perm)
106118
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
107119
return [new_out]
120+
121+
122+
@register_lower_xtensor
123+
@node_rewriter([Squeeze])
124+
def local_squeeze_reshape(fgraph, node):
125+
"""Rewrite Squeeze to tensor.squeeze."""
126+
[x] = node.inputs
127+
x_tensor = tensor_from_xtensor(x)
128+
x_dims = x.type.dims
129+
dims_to_remove = node.op.dims
130+
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove)
131+
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)
132+
133+
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
134+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,87 @@ def make_node(self, *inputs):
297297

298298
def concat(xtensors, dim: str):
299299
return Concat(dim=dim)(*xtensors)
300+
301+
302+
class Squeeze(XOp):
303+
"""Remove specified dimensions from an XTensorVariable.
304+
305+
Only dimensions that are known statically to be size 1 will be removed.
306+
Symbolic dimensions must be explicitly specified, and are assumed safe.
307+
308+
Parameters
309+
----------
310+
dim : tuple of str
311+
The names of the dimensions to remove.
312+
"""
313+
314+
__props__ = ("dims",)
315+
316+
def __init__(self, dims):
317+
self.dims = tuple(sorted(set(dims)))
318+
319+
def make_node(self, x):
320+
x = as_xtensor(x)
321+
322+
# Validate that dims exist and are size-1 if statically known
323+
dims_to_remove = []
324+
x_dims = x.type.dims
325+
x_shape = x.type.shape
326+
for d in self.dims:
327+
if d not in x_dims:
328+
raise ValueError(f"Dimension {d} not found in {x.type.dims}")
329+
idx = x_dims.index(d)
330+
dim_size = x_shape[idx]
331+
if dim_size is not None and dim_size != 1:
332+
raise ValueError(f"Dimension {d} has static size {dim_size}, not 1")
333+
dims_to_remove.append(idx)
334+
335+
new_dims = tuple(
336+
d for i, d in enumerate(x.type.dims) if i not in dims_to_remove
337+
)
338+
new_shape = tuple(
339+
s for i, s in enumerate(x.type.shape) if i not in dims_to_remove
340+
)
341+
342+
out = xtensor(
343+
dtype=x.type.dtype,
344+
shape=new_shape,
345+
dims=new_dims,
346+
)
347+
return Apply(self, [x], [out])
348+
349+
350+
def squeeze(x, dim=None, drop=False, axis=None):
351+
"""Remove dimensions of size 1 from an XTensorVariable."""
352+
x = as_xtensor(x)
353+
354+
# drop parameter is ignored in pytensor.xtensor
355+
if drop is not None:
356+
warnings.warn("drop parameter has no effect in pytensor.xtensor", UserWarning)
357+
358+
# dim and axis are mutually exclusive
359+
if dim is not None and axis is not None:
360+
raise ValueError("Cannot specify both `dim` and `axis`")
361+
362+
# if axis is specified, it must be a sequence of ints
363+
if axis is not None:
364+
if not isinstance(axis, Sequence):
365+
axis = [axis]
366+
if not all(isinstance(a, int) for a in axis):
367+
raise ValueError("axis must be an integer or a sequence of integers")
368+
369+
# convert axis to dims
370+
dims = tuple(x.type.dims[i] for i in axis)
371+
372+
# if dim is specified, it must be a string or a sequence of strings
373+
if dim is None:
374+
dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
375+
elif isinstance(dim, str):
376+
dims = (dim,)
377+
else:
378+
dims = tuple(dim)
379+
380+
if not dims:
381+
return x # no-op if nothing to squeeze
382+
383+
return Squeeze(dims=dims)(x)

pytensor/xtensor/type.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,32 @@ def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
547547
def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
548548
return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
549549

550+
def squeeze(
551+
self,
552+
dim: Sequence[str] | str | None = None,
553+
drop=None,
554+
axis: int | Sequence[int] | None = None,
555+
):
556+
"""Remove dimensions of size 1 from an XTensorVariable.
557+
558+
Parameters
559+
----------
560+
x : XTensorVariable
561+
The input tensor
562+
dim : str or None or iterable of str, optional
563+
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
564+
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
565+
drop : bool, optional
566+
If drop=True, drop squeezed coordinates instead of making them scalar.
567+
axis : int or iterable of int, optional
568+
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
569+
Returns
570+
-------
571+
XTensorVariable
572+
A new tensor with the specified dimension(s) removed.
573+
"""
574+
return px.shape.squeeze(self, dim, drop, axis)
575+
550576
# ndarray methods
551577
# https://docs.xarray.dev/en/latest/api.html#id7
552578
def clip(self, min, max):

tests/xtensor/test_shape.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11+
import pytest
1112
from xarray import DataArray
1213
from xarray import concat as xr_concat
1314

14-
from pytensor.xtensor.shape import concat, stack, transpose, unstack
15+
from pytensor.xtensor.shape import (
16+
concat,
17+
stack,
18+
transpose,
19+
unstack,
20+
)
1521
from pytensor.xtensor.type import xtensor
1622
from tests.xtensor.util import (
1723
xr_arange_like,
@@ -21,6 +27,9 @@
2127
)
2228

2329

30+
pytest.importorskip("xarray")
31+
32+
2433
def powerset(iterable, min_group_size=0):
2534
"Subsequences of the iterable from shortest to longest."
2635
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
@@ -256,3 +265,94 @@ def test_concat_scalar():
256265
res = fn(x1_test, x2_test)
257266
expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
258267
xr_assert_allclose(res, expected_res)
268+
269+
270+
def test_squeeze():
271+
"""Test squeeze."""
272+
273+
# Single dimension
274+
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
275+
y1 = x1.squeeze("country")
276+
fn1 = xr_function([x1], y1)
277+
x1_test = xr_arange_like(x1)
278+
xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country"))
279+
280+
# Multiple dimensions and order independence
281+
x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3))
282+
y2a = x2.squeeze(["b", "c"])
283+
y2b = x2.squeeze(["c", "b"]) # Test order independence
284+
y2c = x2.squeeze(["b", "b"]) # Test redundant dimensions
285+
y2d = x2.squeeze([]) # Test empty list (no-op)
286+
fn2a = xr_function([x2], y2a)
287+
fn2b = xr_function([x2], y2b)
288+
fn2c = xr_function([x2], y2c)
289+
fn2d = xr_function([x2], y2d)
290+
x2_test = xr_arange_like(x2)
291+
xr_assert_allclose(fn2a(x2_test), x2_test.squeeze(["b", "c"]))
292+
xr_assert_allclose(fn2b(x2_test), x2_test.squeeze(["c", "b"]))
293+
xr_assert_allclose(fn2c(x2_test), x2_test.squeeze(["b", "b"]))
294+
xr_assert_allclose(fn2d(x2_test), x2_test)
295+
296+
# Unknown shapes
297+
x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown
298+
y3 = x3.squeeze("b")
299+
x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3)))
300+
fn3 = xr_function([x3], y3)
301+
xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b"))
302+
303+
# Mixed known + unknown shapes
304+
x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
305+
y4 = x4.squeeze("b")
306+
x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
307+
fn4 = xr_function([x4], y4)
308+
xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))
309+
310+
# Test axis parameter
311+
x5 = xtensor("x5", dims=("a", "b", "c"), shape=(2, 1, 3))
312+
y5 = x5.squeeze(axis=1) # squeeze dimension at index 1 (b)
313+
fn5 = xr_function([x5], y5)
314+
x5_test = xr_arange_like(x5)
315+
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1))
316+
317+
# Test axis parameter with negative index
318+
y5 = x5.squeeze(axis=-1) # squeeze dimension at index -2 (b)
319+
fn5 = xr_function([x5], y5)
320+
x5_test = xr_arange_like(x5)
321+
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2))
322+
323+
# Test axis parameter with sequence of ints
324+
y6 = x2.squeeze(axis=[1, 2])
325+
fn6 = xr_function([x2], y6)
326+
x2_test = xr_arange_like(x2)
327+
xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2]))
328+
329+
# Test drop parameter warning
330+
x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1))
331+
with pytest.warns(
332+
UserWarning, match="drop parameter has no effect in pytensor.xtensor"
333+
):
334+
y7 = x7.squeeze("b", drop=True) # squeeze and drop coordinate
335+
fn7 = xr_function([x7], y7)
336+
x7_test = xr_arange_like(x7)
337+
xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True))
338+
339+
340+
def test_squeeze_errors():
341+
"""Test error cases for squeeze."""
342+
343+
# Non-existent dimension
344+
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
345+
with pytest.raises(ValueError, match="Dimension .* not found"):
346+
x1.squeeze("time")
347+
348+
# Dimension size > 1
349+
with pytest.raises(ValueError, match="has static size .* not 1"):
350+
x1.squeeze("city")
351+
352+
# Symbolic shape: dim is not 1 at runtime → should raise
353+
x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown
354+
y2 = x2.squeeze("b")
355+
x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3)))
356+
fn2 = xr_function([x2], y2)
357+
with pytest.raises(Exception):
358+
fn2(x2_test)

0 commit comments

Comments
 (0)