Skip to content

Commit de785b3

Browse files
AllenDowneyricardoV94
authored andcommitted
Implement squeeze for XTensorVariables
1 parent 9aeb80f commit de785b3

File tree

4 files changed

+240
-3
lines changed

4 files changed

+240
-3
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
11
from pytensor.graph import node_rewriter
2-
from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
2+
from pytensor.tensor import (
3+
broadcast_to,
4+
join,
5+
moveaxis,
6+
specify_shape,
7+
squeeze,
8+
)
39
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
410
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
5-
from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack
11+
from pytensor.xtensor.shape import (
12+
Concat,
13+
Squeeze,
14+
Stack,
15+
Transpose,
16+
UnStack,
17+
)
618

719

820
@register_lower_xtensor
@@ -105,3 +117,18 @@ def lower_transpose(fgraph, node):
105117
x_tensor_transposed = x_tensor.transpose(perm)
106118
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
107119
return [new_out]
120+
121+
122+
@register_lower_xtensor
123+
@node_rewriter([Squeeze])
124+
def local_squeeze_reshape(fgraph, node):
125+
"""Rewrite Squeeze to tensor.squeeze."""
126+
[x] = node.inputs
127+
x_tensor = tensor_from_xtensor(x)
128+
x_dims = x.type.dims
129+
dims_to_remove = node.op.dims
130+
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove)
131+
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)
132+
133+
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
134+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,87 @@ def make_node(self, *inputs):
301301

302302
def concat(xtensors, dim: str):
303303
return Concat(dim=dim)(*xtensors)
304+
305+
306+
class Squeeze(XOp):
307+
"""Remove specified dimensions from an XTensorVariable.
308+
309+
Only dimensions that are known statically to be size 1 will be removed.
310+
Symbolic dimensions must be explicitly specified, and are assumed safe.
311+
312+
Parameters
313+
----------
314+
dim : tuple of str
315+
The names of the dimensions to remove.
316+
"""
317+
318+
__props__ = ("dims",)
319+
320+
def __init__(self, dims):
321+
self.dims = tuple(sorted(set(dims)))
322+
323+
def make_node(self, x):
324+
x = as_xtensor(x)
325+
326+
# Validate that dims exist and are size-1 if statically known
327+
dims_to_remove = []
328+
x_dims = x.type.dims
329+
x_shape = x.type.shape
330+
for d in self.dims:
331+
if d not in x_dims:
332+
raise ValueError(f"Dimension {d} not found in {x.type.dims}")
333+
idx = x_dims.index(d)
334+
dim_size = x_shape[idx]
335+
if dim_size is not None and dim_size != 1:
336+
raise ValueError(f"Dimension {d} has static size {dim_size}, not 1")
337+
dims_to_remove.append(idx)
338+
339+
new_dims = tuple(
340+
d for i, d in enumerate(x.type.dims) if i not in dims_to_remove
341+
)
342+
new_shape = tuple(
343+
s for i, s in enumerate(x.type.shape) if i not in dims_to_remove
344+
)
345+
346+
out = xtensor(
347+
dtype=x.type.dtype,
348+
shape=new_shape,
349+
dims=new_dims,
350+
)
351+
return Apply(self, [x], [out])
352+
353+
354+
def squeeze(x, dim=None, drop=False, axis=None):
355+
"""Remove dimensions of size 1 from an XTensorVariable."""
356+
x = as_xtensor(x)
357+
358+
# drop parameter is ignored in pytensor.xtensor
359+
if drop is not None:
360+
warnings.warn("drop parameter has no effect in pytensor.xtensor", UserWarning)
361+
362+
# dim and axis are mutually exclusive
363+
if dim is not None and axis is not None:
364+
raise ValueError("Cannot specify both `dim` and `axis`")
365+
366+
# if axis is specified, it must be a sequence of ints
367+
if axis is not None:
368+
if not isinstance(axis, Sequence):
369+
axis = [axis]
370+
if not all(isinstance(a, int) for a in axis):
371+
raise ValueError("axis must be an integer or a sequence of integers")
372+
373+
# convert axis to dims
374+
dims = tuple(x.type.dims[i] for i in axis)
375+
376+
# if dim is specified, it must be a string or a sequence of strings
377+
if dim is None:
378+
dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
379+
elif isinstance(dim, str):
380+
dims = (dim,)
381+
else:
382+
dims = tuple(dim)
383+
384+
if not dims:
385+
return x # no-op if nothing to squeeze
386+
387+
return Squeeze(dims=dims)(x)

pytensor/xtensor/type.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,32 @@ def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
470470
def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
471471
return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
472472

473+
def squeeze(
474+
self,
475+
dim: Sequence[str] | str | None = None,
476+
drop=None,
477+
axis: int | Sequence[int] | None = None,
478+
):
479+
"""Remove dimensions of size 1 from an XTensorVariable.
480+
481+
Parameters
482+
----------
483+
x : XTensorVariable
484+
The input tensor
485+
dim : str or None or iterable of str, optional
486+
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
487+
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
488+
drop : bool, optional
489+
If drop=True, drop squeezed coordinates instead of making them scalar.
490+
axis : int or iterable of int, optional
491+
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
492+
Returns
493+
-------
494+
XTensorVariable
495+
A new tensor with the specified dimension(s) removed.
496+
"""
497+
return px.shape.squeeze(self, dim, drop, axis)
498+
473499
# ndarray methods
474500
# https://docs.xarray.dev/en/latest/api.html#id7
475501
def clip(self, min, max):

tests/xtensor/test_shape.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11+
import pytest
1112
from xarray import DataArray
1213
from xarray import concat as xr_concat
1314

14-
from pytensor.xtensor.shape import concat, stack, transpose, unstack
15+
from pytensor.xtensor.shape import (
16+
concat,
17+
stack,
18+
transpose,
19+
unstack,
20+
)
1521
from pytensor.xtensor.type import xtensor
1622
from tests.xtensor.util import (
1723
xr_arange_like,
@@ -21,6 +27,9 @@
2127
)
2228

2329

30+
pytest.importorskip("xarray")
31+
32+
2433
def powerset(iterable, min_group_size=0):
2534
"Subsequences of the iterable from shortest to longest."
2635
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
@@ -253,3 +262,94 @@ def test_concat_scalar():
253262
res = fn(x1_test, x2_test)
254263
expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
255264
xr_assert_allclose(res, expected_res)
265+
266+
267+
def test_squeeze():
268+
"""Test squeeze."""
269+
270+
# Single dimension
271+
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
272+
y1 = x1.squeeze("country")
273+
fn1 = xr_function([x1], y1)
274+
x1_test = xr_arange_like(x1)
275+
xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country"))
276+
277+
# Multiple dimensions and order independence
278+
x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3))
279+
y2a = x2.squeeze(["b", "c"])
280+
y2b = x2.squeeze(["c", "b"]) # Test order independence
281+
y2c = x2.squeeze(["b", "b"]) # Test redundant dimensions
282+
y2d = x2.squeeze([]) # Test empty list (no-op)
283+
fn2a = xr_function([x2], y2a)
284+
fn2b = xr_function([x2], y2b)
285+
fn2c = xr_function([x2], y2c)
286+
fn2d = xr_function([x2], y2d)
287+
x2_test = xr_arange_like(x2)
288+
xr_assert_allclose(fn2a(x2_test), x2_test.squeeze(["b", "c"]))
289+
xr_assert_allclose(fn2b(x2_test), x2_test.squeeze(["c", "b"]))
290+
xr_assert_allclose(fn2c(x2_test), x2_test.squeeze(["b", "b"]))
291+
xr_assert_allclose(fn2d(x2_test), x2_test)
292+
293+
# Unknown shapes
294+
x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown
295+
y3 = x3.squeeze("b")
296+
x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3)))
297+
fn3 = xr_function([x3], y3)
298+
xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b"))
299+
300+
# Mixed known + unknown shapes
301+
x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
302+
y4 = x4.squeeze("b")
303+
x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
304+
fn4 = xr_function([x4], y4)
305+
xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))
306+
307+
# Test axis parameter
308+
x5 = xtensor("x5", dims=("a", "b", "c"), shape=(2, 1, 3))
309+
y5 = x5.squeeze(axis=1) # squeeze dimension at index 1 (b)
310+
fn5 = xr_function([x5], y5)
311+
x5_test = xr_arange_like(x5)
312+
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1))
313+
314+
# Test axis parameter with negative index
315+
y5 = x5.squeeze(axis=-1) # squeeze dimension at index -2 (b)
316+
fn5 = xr_function([x5], y5)
317+
x5_test = xr_arange_like(x5)
318+
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2))
319+
320+
# Test axis parameter with sequence of ints
321+
y6 = x2.squeeze(axis=[1, 2])
322+
fn6 = xr_function([x2], y6)
323+
x2_test = xr_arange_like(x2)
324+
xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2]))
325+
326+
# Test drop parameter warning
327+
x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1))
328+
with pytest.warns(
329+
UserWarning, match="drop parameter has no effect in pytensor.xtensor"
330+
):
331+
y7 = x7.squeeze("b", drop=True) # squeeze and drop coordinate
332+
fn7 = xr_function([x7], y7)
333+
x7_test = xr_arange_like(x7)
334+
xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True))
335+
336+
337+
def test_squeeze_errors():
338+
"""Test error cases for squeeze."""
339+
340+
# Non-existent dimension
341+
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
342+
with pytest.raises(ValueError, match="Dimension .* not found"):
343+
x1.squeeze("time")
344+
345+
# Dimension size > 1
346+
with pytest.raises(ValueError, match="has static size .* not 1"):
347+
x1.squeeze("city")
348+
349+
# Symbolic shape: dim is not 1 at runtime → should raise
350+
x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown
351+
y2 = x2.squeeze("b")
352+
x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3)))
353+
fn2 = xr_function([x2], y2)
354+
with pytest.raises(Exception):
355+
fn2(x2_test)

0 commit comments

Comments
 (0)