Skip to content

Commit f160396

Browse files
AllenDowney authored and ricardoV94 committed
Implement expand_dims for XTensorVariables (#1449)
1 parent 2ab117e commit f160396

File tree

4 files changed

+307
-3
lines changed

4 files changed

+307
-3
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pytensor.graph import node_rewriter
22
from pytensor.tensor import (
33
broadcast_to,
4+
expand_dims,
45
join,
56
moveaxis,
67
specify_shape,
@@ -10,6 +11,7 @@
1011
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
1112
from pytensor.xtensor.shape import (
1213
Concat,
14+
ExpandDims,
1315
Squeeze,
1416
Stack,
1517
Transpose,
@@ -121,7 +123,7 @@ def lower_transpose(fgraph, node):
121123

122124
@register_lower_xtensor
123125
@node_rewriter([Squeeze])
124-
def local_squeeze_reshape(fgraph, node):
126+
def lower_squeeze(fgraph, node):
125127
"""Rewrite Squeeze to tensor.squeeze."""
126128
[x] = node.inputs
127129
x_tensor = tensor_from_xtensor(x)
@@ -132,3 +134,33 @@ def local_squeeze_reshape(fgraph, node):
132134

133135
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
134136
return [new_out]
137+
138+
139+
@register_lower_xtensor
@node_rewriter([ExpandDims])
def lower_expand_dims(fgraph, node):
    """Lower an ``ExpandDims`` node to plain tensor operations.

    The new dimension is always the leading axis of the output. When the
    output's static leading length is 1 a plain ``expand_dims`` is used;
    otherwise the input is broadcast to ``(size, *x.shape)``.
    """
    x, size = node.inputs
    out_type = node.outputs[0].type

    # Drop the dimension labels; the remaining work happens on plain tensors.
    x_tensor = tensor_from_xtensor(x)
    size_tensor = tensor_from_xtensor(size)

    if out_type.shape[0] == 1:
        # Statically known length 1: inserting an axis at the front suffices.
        lowered = expand_dims(x_tensor, 0)
    else:
        # Unknown or larger length: broadcast along the new leading axis.
        lowered = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape))

    # Re-attach the static shape recorded on the output type.
    lowered = specify_shape(lowered, out_type.shape)

    # Restore the dimension labels of the original output.
    return [xtensor_from_tensor(lowered, dims=out_type.dims)]

pytensor/xtensor/shape.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import warnings
2-
from collections.abc import Sequence
2+
from collections.abc import Hashable, Sequence
33
from types import EllipsisType
44
from typing import Literal
55

6+
import numpy as np
7+
68
from pytensor.graph import Apply
79
from pytensor.scalar import discrete_dtypes, upcast
810
from pytensor.tensor import as_tensor, get_scalar_constant_value
911
from pytensor.tensor.exceptions import NotScalarConstantError
12+
from pytensor.tensor.type import integer_dtypes
1013
from pytensor.xtensor.basic import XOp
1114
from pytensor.xtensor.type import as_xtensor, xtensor
1215

@@ -385,3 +388,121 @@ def squeeze(x, dim=None, drop=False, axis=None):
385388
return x # no-op if nothing to squeeze
386389

387390
return Squeeze(dims=dims)(x)
391+
392+
393+
class ExpandDims(XOp):
    """Add a new dimension to an XTensorVariable.

    The new dimension is always inserted as the leading dimension of the
    output; its length is given by the ``size`` input of the node.

    Parameters
    ----------
    dim : str
        Name of the dimension to add.
    """

    __props__ = ("dim",)

    def __init__(self, dim):
        if not isinstance(dim, str):
            # `self.dim` has not been assigned yet, so the message has to be
            # built from the argument; reading `self.dim` here would raise
            # AttributeError instead of the intended TypeError.
            raise TypeError(f"`dim` must be a string, got: {type(dim)}")

        self.dim = dim

    def make_node(self, x, size):
        """Build the Apply node that prepends ``self.dim`` to ``x``.

        Parameters
        ----------
        x
            Input; converted with ``as_xtensor``.
        size
            Length of the new dimension; converted with
            ``as_xtensor(size, dims=())`` and required to be an integer scalar.

        Returns
        -------
        Apply
            Node with inputs ``[x, size]`` and one output whose dims are
            ``(self.dim, *x.type.dims)``.

        Raises
        ------
        ValueError
            If ``self.dim`` is already a dimension of ``x``, if ``size`` is not
            an integer scalar, or if ``size`` is a constant smaller than 0.
        """
        x = as_xtensor(x)

        if self.dim in x.type.dims:
            raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")

        size = as_xtensor(size, dims=())
        if not (size.dtype in integer_dtypes and size.ndim == 0):
            raise ValueError(f"size should be an integer scalar, got {size.type}")
        try:
            static_size = int(get_scalar_constant_value(size))
        except NotScalarConstantError:
            # Symbolic size: the static length of the new dim stays unknown.
            static_size = None

        # If size is a constant, validate it
        if static_size is not None and static_size < 0:
            raise ValueError(f"size must be 0 or positive, got: {static_size}")
        new_shape = (static_size, *x.type.shape)

        # Insert new dim at front
        new_dims = (self.dim, *x.type.dims)

        out = xtensor(
            dtype=x.type.dtype,
            shape=new_shape,
            dims=new_dims,
        )
        return Apply(self, [x, size], [out])
432+
433+
434+
def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs):
    """Add one or more new dimensions to an XTensorVariable.

    Parameters
    ----------
    x
        Input; converted with ``as_xtensor``.
    dim : str | Sequence[str] | dict | None
        A single name or a sequence of names adds dimensions of size 1.
        A dict maps each new name to its size: an int or symbolic scalar, or a
        sequence / numpy array, in which case only its length is used.
    create_index_for_new_dim : bool | None
        Accepted for xarray API compatibility. No index is ever created here,
        so a warning is emitted only when it is explicitly set to a false value.
    axis : int | Sequence[int] | None
        Position(s) of the new dimension(s) relative to the original dims.
        Must have one entry per new dimension. Negative values count from the
        end. By default the new dimensions are placed at the front.
    **dim_kwargs
        Alternative spelling of a ``dim`` dict; only allowed when ``dim`` is None.

    Returns
    -------
    XTensorVariable
        ``x`` with the requested dimensions added.

    Raises
    ------
    ValueError
        If both ``dim`` and ``**dim_kwargs`` are given, or if ``axis`` and the
        new dimensions differ in length.
    TypeError
        If ``dim`` (or one of its entries) is unhashable or of an unsupported
        type, or if a dimension size is a string.
    """
    x = as_xtensor(x)

    # Store original dimensions for axis handling
    original_dims = x.type.dims

    # The message below is specifically about `False`, so it must not fire for
    # `True`. The XTensorVariable method forwards its default of `True`, which
    # previously made every method call emit this warning.
    if create_index_for_new_dim is not None and not create_index_for_new_dim:
        warnings.warn(
            "create_index_for_new_dim=False has no effect in pytensor.xtensor",
            UserWarning,
            stacklevel=2,
        )

    if dim is None:
        dim = dim_kwargs
    elif dim_kwargs:
        raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")

    # Check that dim is Hashable or a sequence of Hashable or dict
    if not isinstance(dim, Hashable):
        if not isinstance(dim, Sequence | dict):
            raise TypeError(f"unhashable type: {type(dim).__name__}")
        if not all(isinstance(d, Hashable) for d in dim):
            raise TypeError(f"unhashable type in {type(dim).__name__}")

    # Normalize to a dimension-size mapping
    if isinstance(dim, str):
        dims_dict = {dim: 1}
    elif isinstance(dim, Sequence) and not isinstance(dim, dict):
        dims_dict = {d: 1 for d in dim}
    elif isinstance(dim, dict):
        dims_dict = {}
        for name, val in dim.items():
            if isinstance(val, str):
                raise TypeError(f"Dimension size cannot be a string: {val}")
            if isinstance(val, Sequence | np.ndarray):
                warnings.warn(
                    "When a sequence is provided as a dimension size, only its length is used. "
                    "The actual values (which would be coordinates in xarray) are ignored.",
                    UserWarning,
                    stacklevel=2,
                )
                dims_dict[name] = len(val)
            else:
                # should be int or symbolic scalar
                dims_dict[name] = val
    else:
        raise TypeError(f"Invalid type for `dim`: {type(dim)}")

    # Insert each new dim at the front (reverse order preserves user intent)
    for name, size in reversed(dims_dict.items()):
        x = ExpandDims(dim=name)(x, size)

    # If axis is specified, transpose to put new dimensions in the right place
    if axis is not None:
        # Wrap non-sequence axis in a list
        if not isinstance(axis, Sequence):
            axis = [axis]

        # require len(axis) == len(dims_dict)
        if len(axis) != len(dims_dict):
            raise ValueError("lengths of dim and axis should be identical.")

        # Insert new dimensions at their specified positions
        target_dims = list(original_dims)
        for name, pos in zip(dims_dict, axis):
            # Convert negative axis to positive position relative to current dims
            if pos < 0:
                pos = len(target_dims) + pos + 1
            target_dims.insert(pos, name)
        x = Transpose(dims=tuple(target_dims))(x)

    return x

pytensor/xtensor/type.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,47 @@ def squeeze(
496496
"""
497497
return px.shape.squeeze(self, dim, drop, axis)
498498

499+
def expand_dims(
    self,
    dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
    create_index_for_new_dim: bool = True,
    axis: int | Sequence[int] | None = None,
    **dim_kwargs,
):
    """Add one or more new dimensions to the tensor.

    Thin wrapper that forwards all arguments to ``px.shape.expand_dims``.

    Parameters
    ----------
    dim : str | Sequence[str] | dict[str, int | Sequence] | None
        If str or sequence of str, new dimensions with size 1.
        If dict, keys are dimension names and values are either:
        - int: the new size
        - sequence: coordinates (only the length is used to determine the size)
    create_index_for_new_dim : bool, default: True
        Accepted for xarray API compatibility and forwarded unchanged.
        In xarray this controls whether a coordinate index is created for the
        new dimension; it has no effect on the result in pytensor.xtensor.
    axis : int | Sequence[int] | None, default: None
        Position(s) at which to insert the new dimension(s); one entry per
        new dimension is required.
        By default (None), new dimensions are inserted at the beginning (axis=0).
        Symbolic axis is not supported yet.
        Negative values count from the end.
    **dim_kwargs : int | Sequence
        Alternative to `dim` dict. Only used if `dim` is None.

    Returns
    -------
    XTensorVariable
        A tensor with the additional dimensions inserted at the front, or at
        the positions requested through `axis`.
    """
    return px.shape.expand_dims(
        self,
        dim,
        create_index_for_new_dim=create_index_for_new_dim,
        axis=axis,
        **dim_kwargs,
    )
539+
499540
# ndarray methods
500541
# https://docs.xarray.dev/en/latest/api.html#id7
501542
def clip(self, min, max):

tests/xtensor/test_shape.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11-
import pytest
1211
from xarray import DataArray
1312
from xarray import concat as xr_concat
1413

14+
from pytensor.tensor import scalar
1515
from pytensor.xtensor.shape import (
1616
concat,
1717
stack,
@@ -353,3 +353,113 @@ def test_squeeze_errors():
353353
fn2 = xr_function([x2], y2)
354354
with pytest.raises(Exception):
355355
fn2(x2_test)
356+
357+
358+
def test_expand_dims():
    """Compare ``expand_dims`` with xarray for static, symbolic and ``axis`` cases."""
    x = xtensor("x", dims=("city", "year"), shape=(2, 2))
    x_test = xr_arange_like(x)

    def check_static(*args, **kwargs):
        # The same arguments are valid for both the symbolic variable and the
        # xarray reference, so build, compile and compare in one place.
        compiled = xr_function([x], x.expand_dims(*args, **kwargs))
        xr_assert_allclose(compiled(x_test), x_test.expand_dims(*args, **kwargs))

    # Implicit size 1
    check_static("country")

    # Several new dimensions at once
    check_static(["country", "state"])

    # Dict of name-size pairs
    check_static({"country": 2, "state": 3})

    # Keyword arguments (equivalent to the dict form)
    check_static(country=2, state=3)

    # Dict of name-coordinate array pairs
    check_static({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})

    # Symbolic size that evaluates to 1
    size_sym_1 = scalar("size_sym_1", dtype="int64")
    compiled = xr_function([x, size_sym_1], x.expand_dims({"country": size_sym_1}))
    xr_assert_allclose(compiled(x_test, 1), x_test.expand_dims({"country": 1}))

    # Symbolic sizes passed through a dict
    size_sym_2 = scalar("size_sym_2", dtype="int64")
    expected_2_3 = {"country": 2, "state": 3}
    compiled = xr_function(
        [x, size_sym_1, size_sym_2],
        x.expand_dims({"country": size_sym_1, "state": size_sym_2}),
    )
    xr_assert_allclose(compiled(x_test, 2, 3), x_test.expand_dims(expected_2_3))

    # Symbolic sizes passed through kwargs
    compiled = xr_function(
        [x, size_sym_1, size_sym_2],
        x.expand_dims(country=size_sym_1, state=size_sym_2),
    )
    xr_assert_allclose(compiled(x_test, 2, 3), x_test.expand_dims(expected_2_3))

    # Single new dimension at a positive axis
    check_static("country", axis=1)

    # Single new dimension at a negative axis
    check_static("country", axis=-1)

    # Two new dimensions with positive axes
    check_static(["country", "state"], axis=[1, 2])

    # Two new dimensions with negative axes
    check_static(["country", "state"], axis=[-1, -2])

    # Two new dimensions mixing negative and positive axes
    check_static(["country", "state"], axis=[-2, 1])
438+
439+
440+
def test_expand_dims_errors():
    """Check the exceptions raised by ``expand_dims`` for invalid requests."""
    x = xtensor("x", dims=("city",), shape=(3,))

    # Re-adding a dimension that the input already carries
    with_country = x.expand_dims("country")
    with pytest.raises(ValueError, match="already exists"):
        with_country.expand_dims("city")

    # A hashable that is neither a name, a sequence nor a dict
    with pytest.raises(TypeError, match="Invalid type for `dim`"):
        x.expand_dims(123)

    # Adding the same new dimension twice
    with_new = x.expand_dims("new")
    with pytest.raises(ValueError, match="already exists"):
        with_new.expand_dims("new")

    # A numpy array as `dim` is not supported. xarray rejects it as well:
    #   xr_arange_like(x).expand_dims(np.array([1, 2]))
    #   TypeError: unhashable type: 'numpy.ndarray'
    with pytest.raises(TypeError, match="unhashable type"):
        with_new.expand_dims(np.array([1, 2]))

0 commit comments

Comments
 (0)