Adding expand_dims for xtensor #1449
base: labeled_tensors
Now that we have this PR based on the right commit, @ricardoV94 it is ready for a first look. One question: my first draft of this was based on a later commit -- this draft goes back to an earlier commit, and it looks like
That's the new name, it better represents the kind of rewrites it holds
pytensor/xtensor/shape.py
Outdated
def __init__(self, dim, size=1):
    self.dims = dim
    self.size = size
This does not allow symbolic sizes, check UnStack for reference
@ricardoV94 I think this is a step toward handling symbolic sizes, but there are a couple of places where I'm not sure what the right behavior is. See the comments in Do those tests make sense? Are there more cases that should be covered?
The simplest test for symbolic expand_dims is:
size_new_dim = xtensor("size_new_dim", shape=(), dtype=int)
x = xtensor("x", shape=(3,))
y = x.expand_dims(new_dim=size_new_dim)
xr_function = function([x, size_new_dim], y)
x_test = xr_arange_like(x)
size_new_dim_test = DataArray(np.array(5, dtype=int))
result = xr_function(x_test, size_new_dim_test)
expected_result = x_test.expand_dims(new_dim=size_new_dim_test)
xr_assert_allclose(result, expected_result)
You can parametrize the test to try default and explicit non-default axis as well. Sidenote, what is an implicit
tests/xtensor/test_shape.py
Outdated
# Duplicate dimension creation
y = expand_dims(x, "new")
with pytest.raises(ValueError):
Match the expected error messages to be sure you are triggering the branch you care about. Sometimes you are testing an earlier error and can't tell, because you are only checking for ValueError/TypeError.
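A minimal sketch of this point, using a hypothetical `add_dim` helper rather than the PR's `expand_dims`: two different failure modes both raise `ValueError`, and only the `match=` argument tells them apart. The error messages here are illustrative assumptions, not the ones in the PR.

```python
import pytest


def add_dim(dims, new_dim, size=1):
    """Hypothetical stand-in for expand_dims; not the PR's implementation."""
    if size < 0:
        raise ValueError(f"size must be non-negative, got: {size}")
    if new_dim in dims:
        raise ValueError(f"Dimension {new_dim} already exists in {dims}")
    return (new_dim, *dims)


def test_duplicate_dim():
    dims = add_dim(("a",), "new")
    # match= is a regular expression searched in str(exception), so this
    # only passes when the duplicate-dimension branch is the one that raised.
    with pytest.raises(ValueError, match="already exists"):
        add_dim(dims, "new")
```

Without `match=`, the same test would also pass if the size check fired first, which is the situation the comment warns about.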
@ricardoV94 I've addressed most of your comments on the previous round, and made a first pass at adding support for multiple dimensions. Please take a look at the Assuming that adding multiple dimensions is rare, what do you think of the loop option, as opposed to making a single Op that adds multiple dimensions?
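For reference, the loop option amounts to something like the following NumPy sketch. The names `expand_one` and `expand_many` are hypothetical, and a tuple of dimension names stands in for a labeled tensor; this is not the code in the PR.

```python
import numpy as np


def expand_one(values, dims, new_dim, size=1):
    """Prepend one named dimension of length `size` (hypothetical helper)."""
    if new_dim in dims:
        raise ValueError(f"Dimension {new_dim} already exists in {dims}")
    expanded = np.broadcast_to(values[np.newaxis], (size, *values.shape))
    return expanded, (new_dim, *dims)


def expand_many(values, dims, new_dims):
    """Add several dimensions by applying the single-dimension step in a loop."""
    # Iterate in reverse so the first requested name ends up outermost.
    for name, size in reversed(list(new_dims.items())):
        values, dims = expand_one(values, dims, name, size)
    return values, dims


values, dims = expand_many(np.arange(3), ("a",), {"b": 2, "c": 4})
print(dims, values.shape)
```

In a graph setting each loop iteration would correspond to one single-dimension Op, which keeps the Op simple at the cost of one node per added dimension.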
def make_node(self, x, size):
    x = as_xtensor(x)

    if not isinstance(self.dim, str):
This can already be done in __init__.
if isinstance(size, int | np.integer):
    if size <= 0:
        raise ValueError(f"size must be positive, got: {size}")
elif not (
    hasattr(size, "ndim")
    and getattr(size, "ndim", None) == 0  # symbolic scalar
):
    raise TypeError(
        f"size must be an int or scalar variable, got: {type(size)}"
    )

# Convert size to tensor
size = as_tensor(size, ndim=0)
Better to convert with as_xtensor, since both tensor and xtensor scalars will be accepted then.
Instead of checking for numpy integers specifically, check after conversion whether isinstance(size, Constant) and (size.data <= 0), and raise in that case. That will cover more cases. Or actually use the inferred static size variable below to do that check.
Are you sure xarray doesn't support a size of zero? It's valid otherwise in numpy.
x, size = node.inputs
out = node.outputs[0]
# Lower to tensor.expand_dims(x, axis=0)
from pytensor.tensor import expand_dims as tensor_expand_dims
import globally
from pytensor.tensor import broadcast_to

# Ensure size is positive
expanded = Assert(msg="size must be positive")(expanded, gt(size, 0))
No need for an Assert, which pollutes the graph; broadcast_to will fail at runtime if size is negative. Size zero is allowed, but if it's also allowed in xarray that's fine.
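As a rough illustration in plain NumPy (an analogy only; it does not exercise PyTensor's `broadcast_to`), a zero-length target is accepted while a negative one is rejected by the broadcast itself, with no separate assertion needed:

```python
import numpy as np

x = np.arange(3)

# A zero-length leading dimension is a valid broadcast target.
empty = np.broadcast_to(x[np.newaxis], (0, 3))
print(empty.shape)

# A negative length is rejected by broadcast_to itself.
try:
    np.broadcast_to(x[np.newaxis], (-1, 3))
except ValueError as err:
    negative_error = err
    print(f"ValueError: {err}")
```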
expanded = Assert(msg="size must be positive")(expanded, gt(size, 0))
# If size is not 1, broadcast
try:
    static_size = get_scalar_constant_value(size)
You already inferred static size in make_node so you can use that info from out.type.shape
# Test behavior with symbolic size > 1
# NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size.
# This differs from xarray's behavior where expand_dims always adds a size-1 dimension.
This is not true?
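A quick check with xarray itself supports the objection. In this small sketch the array and the dimension names are made up for illustration:

```python
import numpy as np
import xarray as xr

x = xr.DataArray(np.arange(3), dims=("a",))

# An integer value gives the new dimension that length.
y = x.expand_dims(new_dim=5)
print(y.dims, y.shape)

# With only a name, the new dimension has length 1.
z = x.expand_dims("new_dim")
print(z.dims, z.shape)
```

So xarray also broadcasts to the requested size when one is given; the size-1 result is only the default.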
That's fine. We used that for other Ops and we can revisit later if we want it to be fused.
Add expand_dims operation for labeled tensors
This PR adds support for the expand_dims operation in PyTensor's labeled tensor system, allowing users to add new dimensions to labeled tensors with explicit dimension names.
Key Features
ExpandDims operation that adds a new dimension to an XTensorVariable
Implementation Details
The implementation includes:
New ExpandDims class in pytensor/xtensor/shape.py that handles:
Rewriting rule in pytensor/xtensor/rewriting/shape.py that:
Comprehensive test suite in tests/xtensor/test_shape.py covering:
Usage Example
Testing
The implementation includes extensive tests that verify:
📚 Documentation preview 📚: https://pytensor--1449.org.readthedocs.build/en/1449/