Adding expand_dims for xtensor #1449


Open · wants to merge 4 commits into base: labeled_tensors

Conversation

AllenDowney

@AllenDowney AllenDowney commented Jun 6, 2025

Add expand_dims operation for labeled tensors

This PR adds support for the expand_dims operation in PyTensor's labeled tensor system, allowing users to add new dimensions to labeled tensors with explicit dimension names.

Key Features

  • New ExpandDims operation that adds a new dimension to an XTensorVariable
  • Support for both static and symbolic dimension sizes
  • Automatic broadcasting when size > 1
  • Integration with existing tensor operations
  • Full compatibility with xarray's expand_dims behavior

Implementation Details

The implementation includes:

  1. New ExpandDims class in pytensor/xtensor/shape.py that handles:

    • Adding new dimensions with specified names
    • Support for both static and symbolic sizes
    • Shape inference and validation
  2. Rewriting rule in pytensor/xtensor/rewriting/shape.py that:

    • Converts labeled tensor operations to standard tensor operations
    • Handles broadcasting when needed
    • Validates symbolic sizes
  3. Comprehensive test suite in tests/xtensor/test_shape.py covering:

    • Basic dimension expansion
    • Static and symbolic sizes
    • Error cases and edge cases
    • Compatibility with xarray operations
    • Integration with other labeled tensor operations
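At the lowered level, the rewriting rule in step 2 amounts to inserting a new leading axis and broadcasting it to the requested size. A minimal NumPy sketch of that semantics (illustration only; the actual rewrite operates on PyTensor graph nodes, and `lower_expand_dims_sketch` is a hypothetical helper, not part of this PR):

```python
import numpy as np

def lower_expand_dims_sketch(x, size):
    # Insert the new axis on the left, then broadcast it to the requested size.
    expanded = np.expand_dims(x, axis=0)              # shape (1, *x.shape)
    return np.broadcast_to(expanded, (size, *x.shape))

x = np.arange(3)
print(lower_expand_dims_sketch(x, 1).shape)  # (1, 3)
print(lower_expand_dims_sketch(x, 4).shape)  # (4, 3)
```

When `size == 1` the broadcast is a no-op, which is why the implicit (size-1) and explicit cases share one code path.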

Usage Example

from pytensor.xtensor import xtensor
from pytensor.xtensor.shape import expand_dims

# Create a labeled tensor
x = xtensor("x", dims=("city",), shape=(3,))

# Add a new dimension
y = expand_dims(x, "country")  # Adds a new dimension of size 1
z = expand_dims(x, "country", size=4)  # Adds a new dimension of size 4

Testing

The implementation includes extensive tests that verify:

  • Correct behavior with various input shapes
  • Proper handling of symbolic sizes
  • Error cases (invalid dimensions, sizes, etc.)
  • Compatibility with xarray's expand_dims
  • Integration with other labeled tensor operations

📚 Documentation preview 📚: https://pytensor--1449.org.readthedocs.build/en/1449/

@AllenDowney
Author

Now that we have this PR based on the right commit, @ricardoV94 it is ready for a first look.

One question: my first draft of this was based on a later commit -- this draft goes back to an earlier commit, and it looks like @register_xcanonicalize doesn't exist yet, so I've replaced it with @register_lower_xtensor, which seems to be its predecessor. Is that the right thing to do for now?

@ricardoV94
Member

That's the new name; it better represents the kind of rewrites it holds.


def __init__(self, dim, size=1):
    self.dims = dim
    self.size = size
Member

This does not allow symbolic sizes, check UnStack for reference

@AllenDowney
Author

@ricardoV94 I think this is a step toward handling symbolic sizes, but there are a couple of places where I'm not sure what the right behavior is. See the comments in test_shape.py, test_expand_dims_implicit.

Do those tests make sense? Are there more cases that should be covered?

@ricardoV94
Member

The simplest test for symbolic expand_dims is:

size_new_dim = xtensor("size_new_dim", shape=(), dtype=int)
x = xtensor("x", shape=(3,))
y = x.expand_dims(new_dim=size_new_dim)
xr_function = function([x, size_new_dim], y)

x_test = xr_arange_like(x)
size_new_dim_test = DataArray(np.array(5, dtype=int))
result = xr_function(x_test, size_new_dim_test)
expected_result = x_test.expand_dims(new_dim=size_new_dim_test)
xr_assert_allclose(result, expected_result)

You can parametrize the test to try the default and an explicit non-default axis as well.

Sidenote, what is an implicit expand_dims? I don't think that's a thing.

@@ -369,3 +380,195 @@ def test_squeeze_errors():
    fn2 = xr_function([x2], y2)
    with pytest.raises(Exception):
        fn2(x2_test)


def test_expand_dims_explicit():
Member

There's a lot of redundancy in these tests, nothing is gained from testing 1D, 2D, 3D...

Author

During development, I think it's useful because if something fails, it fails on the simplest case, which makes it easier to debug. Once we have an implementation we think is correct, we could reduce the number of tests, but none of them take long to run, so I'd be inclined to leave them.

Member

@ricardoV94 ricardoV94 Jun 10, 2025

Let me push back on this argument. This test will run thousands of times in a year; small costs add up. For some people a nice test suite takes milliseconds to run, for others a few hours. Speed is subjective.

I can't conceive of a scenario where expand_dims would fail for 2d but work for 3d or vice versa. If so, why not test up to nd? Also, the 2d case is tested at least twice below, when you test "multiple dims" and check implicit and explicit size 1.

My comment was also a bit broader: expand_dims is not such a complex operation that we need to test it to exhaustion. Right now I think it's the Op with the most checks after indexing (which is actually really complex).

We also know AI suggestions tend to err on the side of verbosity, not conciseness, and I suppose that translates to the tests.


# Duplicate dimension creation
y = expand_dims(x, "new")
with pytest.raises(ValueError):
Member

Match the expected error messages to be sure you are triggering the branch you care about. Sometimes you are testing an earlier error and can't tell, because you only checked for ValueError/TypeError.

@AllenDowney
Author

@ricardoV94 I've addressed most of your comments on the previous round, and made a first pass at adding support for multiple dimensions. Please take a look at the expand_dims wrapper function, which canonicalizes the inputs and loops through them to make a series of Ops.

Assuming that adding multiple dimensions is rare, what do you think of the loop option, as opposed to making a single Op that adds multiple dimensions?
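For intuition, the loop option can be sketched at the NumPy level (`expand_dims_many` is a hypothetical name for this sketch; the real wrapper canonicalizes the inputs and builds one ExpandDims Op per dim on symbolic tensors):

```python
import numpy as np

def expand_dims_many(x, dims_with_sizes):
    # One expansion per new dim; the first dict key becomes the outermost axis.
    for size in reversed(list(dims_with_sizes.values())):
        x = np.broadcast_to(x, (size, *x.shape))
    return x

x = np.arange(3)
print(expand_dims_many(x, {"country": 4, "state": 2}).shape)  # (4, 2, 3)
```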


# Test behavior with symbolic size > 1
# NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size.
# This differs from xarray's behavior where expand_dims always adds a size-1 dimension.
Member

This is not true?

Author

I'm not sure about the general claim in the note, but at least in this case, it seems like we're getting the behavior we want from xtensor, but running the same operation with xarray does something different, causing the test to fail. Here's Cursor's summary

The test failure confirms that our current implementation of expand_dims broadcasts to the requested size (4 in this case), while xarray's behavior is to always add a size-1 dimension. This is evident from the test output, where the left side (our implementation) has a shape of (batch: 4, a: 2, b: 3), and the right side (xarray's behavior) has a shape of (batch: 1, a: 2, b: 3).

I'm inclined to keep this test to note the discrepancy.

Member

@ricardoV94 ricardoV94 Jun 10, 2025

I don't understand the point. The test shows xarray accepts expand_dims({"batch": 4}) and I guess also expand_dims(batch=4).

That is clearly at odds with the comment that xarray always expands with size of 1.

And that's exactly the behavior we want to replicate. If the size kwarg is something that doesn't exist in xarray (I suspect it doesn't; how would you map each size to each new dim?), we shouldn't introduce it. We want to mimic their API.
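For reference, the shape semantics under discussion can be sketched in plain NumPy (illustration only; xarray's `expand_dims({"batch": 4})` additionally attaches the dim name):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)          # dims ("a", "b")
batch = np.broadcast_to(x, (4, 2, 3))   # like x.expand_dims({"batch": 4})
print(batch.shape)  # (4, 2, 3)
```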

@ricardoV94
Member

Assuming that adding multiple dimensions is rare, what do with think of the loop option, as opposed to making a single Op that adds multiple dimensions?

That's fine. We used that for other Ops, and we can revisit later if we want it to be fused.

@AllenDowney
Author

@ricardoV94 This is ready for another look.

The rewrite was a shambles, but I think I have a clearer idea now.

)
xr_assert_allclose(fn(xr_arange_like(x)), expected)

# Insert new dim between existing dims
Member

I suppose this test is meant to cover the axis argument? As it stands it doesn't make sense; it's testing transpose explicitly, which is tested elsewhere.

expected = x_test.expand_dims("new").transpose("a", "new", "b")
xr_assert_allclose(fn(x_test), expected)

# Expand with multiple dims
Member

@ricardoV94 ricardoV94 Jun 10, 2025

I suppose this test is intended to test multiple dims in a single call instead?

You already do that in the last test though

size_sym_1 = scalar("size_sym_1", dtype="int64")
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = expand_dims(x, "batch", size=size_sym_1)
fn = xr_function([x, size_sym_1], y, on_unused_input="ignore")
Member

@ricardoV94 ricardoV94 Jun 10, 2025

The unused input is a code smell

fn = xr_function([x], y)
x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5)))
res = fn(x_test)
expected = x_test.expand_dims({"d": 3}) # 3 is the size of dimension "b"
Member

Instead of the comment explanation, write the expected value to mimic the PyTensor expression:

Suggested change
- expected = x_test.expand_dims({"d": 3})  # 3 is the size of dimension "b"
+ expected = x_test.expand_dims({"d": x_test.sizes["b"]})

expected = x_test.expand_dims({"d": 3}) # 3 is the size of dimension "b"
xr_assert_allclose(res, expected)

# Test broadcasting with symbolic size from a different tensor
Member

Don't see the point of this case

# expected = x_test.expand_dims("batch") # always size 1
# xr_assert_allclose(res, expected)

# Test using symbolic size from a reduction operation
Member

Don't see the point of this test

expected = x_test.expand_dims({"batch": 3}) # 3 is the size of dimension "b"
xr_assert_allclose(res, expected)

# Test chaining expand_dims with symbolic sizes
Member

These tests are trying really hard not to just use a symbolic size variable. Why?

expected = x_test.expand_dims({"b": 2}).expand_dims({"c": 2})
xr_assert_allclose(res, expected)

# Test bidirectional broadcasting with symbolic sizes
Member

Don't see the point of this test

size_sym_1 = scalar("size_sym_1", dtype="int64")
size_sym_2 = scalar("size_sym_2", dtype="int64")
y = expand_dims(x, {"country": size_sym_1, "state": size_sym_2})
fn = xr_function([x, size_sym_1, size_sym_2], y, on_unused_input="ignore")
Member

On unused input is suspect

result_tensor = expand_dims(x_tensor, new_axis)
else:
# First expand with size 1
expanded = expand_dims(x_tensor, new_axis)
Member

Small eager optimization: since the new axis is always on the left, you don't need the expand_dims step and can skip directly to broadcast_to.
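A quick NumPy check of that equivalence (illustration only; the rewrite itself emits PyTensor ops):

```python
import numpy as np

x = np.ones((2, 3))
# With the new axis on the left, broadcast_to alone reproduces
# expand_dims followed by broadcast_to.
a = np.broadcast_to(np.expand_dims(x, 0), (4, 2, 3))
b = np.broadcast_to(x, (4, 2, 3))
assert (a == b).all() and a.shape == (4, 2, 3)
```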

@ricardoV94
Member

ricardoV94 commented Jun 10, 2025

I left some comments above.

Rewrite looks good. As we discussed we should redo the tests to use expand_dims as a method (like xarray users would).

Also, I suspect xarray allows specifying the size like x.expand_dims(dim_a=1, dim_b=2), which is equivalent to x.expand_dims({"dim_a": 1, "dim_b": 2}). At least that was a pattern I noticed in other xarray methods. I saw you had a test for multiple dims with a dict, but I didn't see one with kwargs.


@register_lower_xtensor
@node_rewriter([ExpandDims])
def local_expand_dims_reshape(fgraph, node):
Member

Suggested change
- def local_expand_dims_reshape(fgraph, node):
+ def lower_expand_dims(fgraph, node):

Comment on lines +406 to +431
# Check if size is a valid type before converting
if not (
    isinstance(size, int | np.integer)
    or (hasattr(size, "ndim") and getattr(size, "ndim", None) == 0)
):
    raise TypeError(
        f"size must be an int or scalar variable, got: {type(size)}"
    )

# Determine shape
try:
    static_size = get_scalar_constant_value(size)
except NotScalarConstantError:
    static_size = None

if static_size is not None:
    new_shape = (int(static_size), *x.type.shape)
else:
    new_shape = (None, *x.type.shape)  # symbolic size

# Convert size to tensor
size = as_xtensor(size, dims=())

# Check if size is a constant and validate it
if isinstance(size, Constant) and size.data < 0:
    raise ValueError(f"size must be 0 or positive, got: {size.data}")
Member

@ricardoV94 ricardoV94 Jun 10, 2025

A bit cleaner?

Suggested change
- # Check if size is a valid type before converting
- if not (
-     isinstance(size, int | np.integer)
-     or (hasattr(size, "ndim") and getattr(size, "ndim", None) == 0)
- ):
-     raise TypeError(
-         f"size must be an int or scalar variable, got: {type(size)}"
-     )
- # Determine shape
- try:
-     static_size = get_scalar_constant_value(size)
- except NotScalarConstantError:
-     static_size = None
- if static_size is not None:
-     new_shape = (int(static_size), *x.type.shape)
- else:
-     new_shape = (None, *x.type.shape)  # symbolic size
- # Convert size to tensor
- size = as_xtensor(size, dims=())
- # Check if size is a constant and validate it
- if isinstance(size, Constant) and size.data < 0:
-     raise ValueError(f"size must be 0 or positive, got: {size.data}")
+ size = as_xtensor(size, dims=())
+ if not (size.dtype in integer_dtypes and size.ndim == 0):
+     raise ValueError(f"size should be an integer scalar, got {size.type}")
+ try:
+     static_size = int(get_scalar_constant_value(size))
+ except NotScalarConstantError:
+     static_size = None
+ # If size is a constant, validate it
+ if static_size is not None and static_size < 0:
+     raise ValueError(f"size must be 0 or positive, got: {static_size}")
+ new_shape = (static_size, *x.type.shape)

y = expand_dims(x, "batch", size=size_sym_1)
fn = xr_function([x, size_sym_1], y, on_unused_input="ignore")
x_test = xr_arange_like(x)
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch"))
Member

Since you already have the function, you can use it twice (which makes the earlier comment stale):

Suggested change
- xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch"))
+ xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch"))
+ xr_assert_allclose(fn(x_test, 5), x_test.expand_dims({"batch": 5}))
