Adding expand_dims for xtensor #1449
base: labeled_tensors
Conversation
Now that we have this PR based on the right commit, @ricardoV94 it is ready for a first look. One question: my first draft of this was based on a later commit -- this draft goes back to an earlier commit, and it looks like …
That's the new name, it better represents the kind of rewrites it holds.
pytensor/xtensor/shape.py (Outdated)

```python
def __init__(self, dim, size=1):
    self.dims = dim
    self.size = size
```
This does not allow symbolic sizes, check UnStack for reference
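A minimal sketch of the direction this suggests, assuming the usual PyTensor pattern where the size enters `make_node` as a symbolic input instead of being stored on the Op in `__init__` (this is illustrative, not the actual UnStack code):

```python
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class ExpandDims(Op):
    # only static metadata lives on the Op; the size stays symbolic
    __props__ = ("dim",)

    def __init__(self, dim):
        self.dim = dim

    def make_node(self, x, size):
        # accept Python ints as well as symbolic scalars
        size = pt.as_tensor_variable(size)
        if size.type.ndim != 0:
            raise ValueError(f"size must be a scalar, got {size.type}")
        # output type construction elided; UnStack shows the full pattern
        out = x.type()
        return Apply(self, [x, size], [out])
```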
@ricardoV94 I think this is a step toward handling symbolic sizes, but there are a couple of places where I'm not sure what the right behavior is. See the comments in […]. Do those tests make sense? Are there more cases that should be covered?
The simplest test for symbolic expand_dims is:

```python
size_new_dim = xtensor("size_new_dim", shape=(), dtype=int)
x = xtensor("x", shape=(3,))
y = x.expand_dims(new_dim=size_new_dim)
xr_function = function([x, size_new_dim], y)
x_test = xr_arange_like(x)
size_new_dim_test = DataArray(np.array(5, dtype=int))
result = xr_function(x_test, size_new_dim_test)
expected_result = x_test.expand_dims(new_dim=size_new_dim_test)
xr_assert_allclose(result, expected_result)
```

You can parametrize the test to try default and explicit non-default axis as well. Sidenote, what is an implicit …
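A sketch of how that parametrization might look, reusing the names from the snippet above (the `axis` keyword is assumed here, mirroring xarray's expand_dims signature):

```python
import pytest


@pytest.mark.parametrize("axis", [None, 1])  # default and explicit non-default
def test_expand_dims_symbolic_size(axis):
    # xtensor / xr_function / xr_arange_like come from the test module above
    size_new_dim = xtensor("size_new_dim", shape=(), dtype=int)
    x = xtensor("x", shape=(3,))
    kwargs = {} if axis is None else {"axis": axis}
    y = x.expand_dims(new_dim=size_new_dim, **kwargs)
    ...
```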
```diff
@@ -369,3 +380,195 @@ def test_squeeze_errors():
    fn2 = xr_function([x2], y2)
    with pytest.raises(Exception):
        fn2(x2_test)


def test_expand_dims_explicit():
```
There's a lot of redundancy in these tests, nothing is gained from testing 1D, 2D, 3D...
During development, I think it's useful because if something fails, it fails on the simplest case, which makes it easier to debug. Once we have an implementation we think is correct, we could reduce the number of tests, but none of them take long to run, so I'd be inclined to leave them.
Let me push back on this argument. This test will run thousands of times in a year; small costs add up. For some people a nice test suite takes milliseconds to run, for others a few hours. Speed is subjective.

I can't conceive of a scenario where expand_dims would fail for 2d but work for 3d or vice versa. If so, why not test up to nd? Also, the 2d case is tested at least twice below when you test "multiple dims" and check implicit and explicit size 1.

My comment was also a bit broader: expand_dims is not such a complex operation that we need to test it to exhaustion. Right now I think it's the Op with the most checks after indexing (which is actually really complex).

We also know AI suggestions tend to err on the side of verbosity, not conciseness, and I suppose that translates to the tests.
tests/xtensor/test_shape.py (Outdated)

```python
# Duplicate dimension creation
y = expand_dims(x, "new")
with pytest.raises(ValueError):
```
Match the expected error messages to be sure you are triggering the branch you care about. Sometimes you are testing an earlier error and can't tell because you are only checking for ValueError/TypeError.
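For instance, a minimal sketch of the pattern being suggested, applied to the duplicate-dimension case above (the message text is illustrative, not the actual error string):

```python
import re

import pytest

# pin the failure to the duplicate-dimension branch rather than
# accepting any earlier ValueError
with pytest.raises(ValueError, match=re.escape("already exists")):
    expand_dims(y, "new")
```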
@ricardoV94 I've addressed most of your comments on the previous round, and made a first pass at adding support for multiple dimensions. Please take a look at the […]. Assuming that adding multiple dimensions is rare, what do we think of the loop option, as opposed to making a single Op that adds multiple dimensions?
```python
# Test behavior with symbolic size > 1
# NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size.
# This differs from xarray's behavior where expand_dims always adds a size-1 dimension.
```
This is not true?
I'm not sure about the general claim in the note, but at least in this case, it seems like we're getting the behavior we want from xtensor, but running the same operation with xarray does something different, causing the test to fail. Here's Cursor's summary:

> The test failure confirms that our current implementation of expand_dims broadcasts to the requested size (4 in this case), while xarray's behavior is to always add a size-1 dimension. This is evident from the test output, where the left side (our implementation) has a shape of (batch: 4, a: 2, b: 3), and the right side (xarray's behavior) has a shape of (batch: 1, a: 2, b: 3).

I'm inclined to keep this test to note the discrepancy.
I don't understand the point. The test shows xarray accepts `expand_dims({"batch": 4})`, and I guess also `expand_dims(batch=4)`. That is clearly at odds with the comment that xarray always expands with size 1.

And that's exactly the behavior we want to replicate. If the size kwarg is something that doesn't exist in xarray (I suspect it doesn't; how would you map each size to each new dim?) we shouldn't introduce it, we want to mimic their API.
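For reference, a quick check of what xarray itself does with the two call forms (standard xarray API; the dict form does create a dimension of the requested size):

```python
import numpy as np
import xarray as xr

x = xr.DataArray(np.arange(6).reshape(2, 3), dims=("a", "b"))

y1 = x.expand_dims("batch")        # bare name: new leading dim of size 1
y2 = x.expand_dims({"batch": 4})   # dict form: new dim of size 4, values broadcast
print(y1.sizes)  # Frozen({'batch': 1, 'a': 2, 'b': 3})
print(y2.sizes)  # Frozen({'batch': 4, 'a': 2, 'b': 3})
```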
That's fine. We used that for other Ops and we can revisit later if we want it to be fused.
@ricardoV94 This is ready for another look. The rewrite was a shambles, but I think I have a clearer idea now.
```python
)
xr_assert_allclose(fn(xr_arange_like(x)), expected)

# Insert new dim between existing dims
```
I suppose this test is meant to cover the axis argument? As it stands it doesn't make sense; it's testing transpose explicitly, which is tested elsewhere.
expected = x_test.expand_dims("new").transpose("a", "new", "b") | ||
xr_assert_allclose(fn(x_test), expected) | ||
|
||
# Expand with multiple dims |
I suppose this test is intended to test multiple dims in a single call instead?
You already do that in the last test though
```python
size_sym_1 = scalar("size_sym_1", dtype="int64")
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = expand_dims(x, "batch", size=size_sym_1)
fn = xr_function([x, size_sym_1], y, on_unused_input="ignore")
```
The unused input is a code smell
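A sketch of the cleaner shape this points toward, assuming the graph actually consumes the symbolic size (hypothetical, following the snippet above):

```python
# once y genuinely depends on size_sym_1, every input is used
# and the on_unused_input escape hatch can simply be dropped
fn = xr_function([x, size_sym_1], y)
```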
```python
fn = xr_function([x], y)
x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5)))
res = fn(x_test)
expected = x_test.expand_dims({"d": 3})  # 3 is the size of dimension "b"
```
Instead of the comment explanation, write the expected to mimic the PyTensor expression:

```diff
- expected = x_test.expand_dims({"d": 3})  # 3 is the size of dimension "b"
+ expected = x_test.expand_dims({"d": x_test.sizes["b"]})
```
expected = x_test.expand_dims({"d": 3}) # 3 is the size of dimension "b" | ||
xr_assert_allclose(res, expected) | ||
|
||
# Test broadcasting with symbolic size from a different tensor |
Don't see the point of this case
```python
# expected = x_test.expand_dims("batch")  # always size 1
# xr_assert_allclose(res, expected)

# Test using symbolic size from a reduction operation
```
Don't see the point of this test
expected = x_test.expand_dims({"batch": 3}) # 3 is the size of dimension "b" | ||
xr_assert_allclose(res, expected) | ||
|
||
# Test chaining expand_dims with symbolic sizes |
These tests are trying really hard to not just use a symbolic size variable, why?
expected = x_test.expand_dims({"b": 2}).expand_dims({"c": 2}) | ||
xr_assert_allclose(res, expected) | ||
|
||
# Test bidirectional broadcasting with symbolic sizes |
Don't see the point of this test
```python
size_sym_1 = scalar("size_sym_1", dtype="int64")
size_sym_2 = scalar("size_sym_2", dtype="int64")
y = expand_dims(x, {"country": size_sym_1, "state": size_sym_2})
fn = xr_function([x, size_sym_1, size_sym_2], y, on_unused_input="ignore")
```
The `on_unused_input` is suspect.
```python
    result_tensor = expand_dims(x_tensor, new_axis)
else:
    # First expand with size 1
    expanded = expand_dims(x_tensor, new_axis)
```
Small eager optimization: since the new axis is always on the left you don't need the expand_dims step and can skip directly to broadcast_to.
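A minimal sketch of what that might look like in the lowering, reusing the names from the snippet above (`pt.broadcast_to` is PyTensor's standard broadcast helper):

```python
import pytensor.tensor as pt

# the new axis is leftmost, so broadcast_to can add and fill it in one step,
# with no intermediate size-1 expand_dims
result_tensor = pt.broadcast_to(x_tensor, (size, *x_tensor.shape))
```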
I left some comments above. Rewrite looks good. As we discussed we should redo the tests to use expand_dims as a method (like xarray users would). Also, I suspect xarray allows specifying the size like …
```python
@register_lower_xtensor
@node_rewriter([ExpandDims])
def local_expand_dims_reshape(fgraph, node):
```

```diff
- def local_expand_dims_reshape(fgraph, node):
+ def lower_expand_dims(fgraph, node):
```
```python
# Check if size is a valid type before converting
if not (
    isinstance(size, int | np.integer)
    or (hasattr(size, "ndim") and getattr(size, "ndim", None) == 0)
):
    raise TypeError(
        f"size must be an int or scalar variable, got: {type(size)}"
    )

# Determine shape
try:
    static_size = get_scalar_constant_value(size)
except NotScalarConstantError:
    static_size = None

if static_size is not None:
    new_shape = (int(static_size), *x.type.shape)
else:
    new_shape = (None, *x.type.shape)  # symbolic size

# Convert size to tensor
size = as_xtensor(size, dims=())

# Check if size is a constant and validate it
if isinstance(size, Constant) and size.data < 0:
    raise ValueError(f"size must be 0 or positive, got: {size.data}")
```

A bit more clean?

```python
size = as_xtensor(size, dims=())
if not (size.dtype in integer_dtypes and size.ndim == 0):
    raise ValueError(f"size should be an integer scalar, got {size.type}")
try:
    static_size = int(get_scalar_constant_value(size))
except NotScalarConstantError:
    static_size = None
# If size is a constant, validate it
if static_size is not None and static_size < 0:
    raise ValueError(f"size must be 0 or positive, got: {static_size}")
new_shape = (static_size, *x.type.shape)
```
y = expand_dims(x, "batch", size=size_sym_1) | ||
fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") | ||
x_test = xr_arange_like(x) | ||
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you already have the function you can use it twice (makes comment before stale)
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch")) | |
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch")) | |
xr_assert_allclose(fn(x_test, 5), x_test.expand_dims({"batch": 5})) |
Add expand_dims operation for labeled tensors

This PR adds support for the `expand_dims` operation in PyTensor's labeled tensor system, allowing users to add new dimensions to labeled tensors with explicit dimension names.

Key Features

- New `ExpandDims` operation that adds a new dimension to an XTensorVariable

Implementation Details

The implementation includes:

- New `ExpandDims` class in `pytensor/xtensor/shape.py` that handles: …
- Rewriting rule in `pytensor/xtensor/rewriting/shape.py` that: …
- Comprehensive test suite in `tests/xtensor/test_shape.py` covering: …

Usage Example
Testing

The implementation includes extensive tests that verify: …

📚 Documentation preview 📚: https://pytensor--1449.org.readthedocs.build/en/1449/