Adding expand_dims for xtensor #1449
base: labeled_tensors
Now that we have this PR based on the right commit, @ricardoV94 it is ready for a first look. One question: my first draft of this was based on a later commit -- this draft goes back to an earlier commit, and it looks like |
That's the new name, it better represents the kind of rewrites it holds |
@ricardoV94 I think this is a step toward handling symbolic sizes, but there are a couple of places where I'm not sure what the right behavior is. See the comments in Do those tests make sense? Are there more cases that should be covered? |
The simplest test for symbolic expand_dims is:
size_new_dim = xtensor("size_new_dim", shape=(), dtype=int)
x = xtensor("x", shape=(3,))
y = x.expand_dims(new_dim=size_new_dim)
xr_function = function([x, size_new_dim], y)
x_test = xr_arange_like(x)
size_new_dim_test = DataArray(np.array(5, dtype=int))
result = xr_function(x_test, size_new_dim_test)
expected_result = x_test.expand_dims(new_dim=size_new_dim_test)
xr_assert_allclose(result, expected_result)
You can parametrize the test to try the default and an explicit non-default axis as well. Sidenote, what is an implicit |
@ricardoV94 I've addressed most of your comments on the previous round, and made a first pass at adding support for multiple dimensions. Please take a look at the Assuming that adding multiple dimensions is rare, what do you think of the loop option, as opposed to making a single Op that adds multiple dimensions? |
That's fine. We used that for other Ops and we can revisit later if we want it to be fused |
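The loop option discussed here can be pictured with plain Python tuples standing in for a labeled tensor's dims and shape. This is only a sketch under stated assumptions: `expand_one` and `expand_many` are hypothetical names, not PyTensor functions, and new dimensions are assumed to be prepended so that the requested order is preserved.

```python
def expand_one(dims, shape, name, size=1):
    """Prepend a single new dimension (hypothetical stand-in for a one-dimension Op)."""
    if name in dims:
        raise ValueError(f"Dimension {name!r} already exists")
    return (name, *dims), (size, *shape)


def expand_many(dims, shape, new_dims):
    """Add several dimensions by applying expand_one once per new dimension."""
    # Iterate in reverse so the first requested name ends up outermost.
    for name, size in reversed(list(new_dims.items())):
        dims, shape = expand_one(dims, shape, name, size)
    return dims, shape


# Start from a 1-d "tensor" with dims ("x",) and shape (3,).
dims, shape = expand_many(("x",), (3,), {"a": 1, "b": 5})
print(dims, shape)
```

A fused variant would build the final dims and shape in one step; the loop keeps each step simple at the cost of one operation per added dimension.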
@ricardoV94 This is ready for another look. The rewrite was a shambles, but I think I have a clearer idea now. |
I left some comments above. Rewrite looks good. As we discussed we should redo the tests to use expand_dims as a method (like xarray users would). Also, I suspected xarray allows specifying the size like |
@ricardoV94 I cleaned up the code as suggested and took a first cut at handling the |
This looks great. Just the question about the size kwarg and small notes.
@ricardoV94 This is ready for another look |
I think we're nearly there. This proved to be more complex than I anticipated.
Small changes to the error checks and one question. Also, I don't think there's any test for passing sequences as the length of the dimensions?
if not create_index_for_new_dim:
    warnings.warn(
        "create_index_for_new_dim=False has no effect in pytensor.xtensor",
        UserWarning,
        stacklevel=2,
    )
Neither option has an effect to be fair, just don't mention?
)

# Extract size from dim_kwargs if present
size = dim_kwargs.pop("size", 1) if dim_kwargs else 1
Still here, did you not push yet?
# xarray compatibility: error if a sequence (list/tuple) of dims and size are given
if (isinstance(dim, list | tuple)) and ("size" in locals() and size != 1):
    raise ValueError("cannot specify both keyword and positional arguments")
also not a thing?

# Normalize to a dimension-size mapping
if isinstance(dim, str):
    dims_dict = {dim: size}
- dims_dict = {dim: size}
+ dims_dict = {dim: 1}
elif isinstance(dim, dict):
    dims_dict = {}
    for name, val in dim.items():
        if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
How does xarray treat expand_dims(new_dim=np.array(5))? Does it treat it as coordinates or as the size? I suppose the latter?
In that case the check here should be isinstance(val, np.ndarray) and val.ndim > 0. We could also consider symbolic variables with (isinstance(val, np.ndarray) or (isinstance(val, Variable) and isinstance(val.type, HasShape))) and val.ndim > 0
elif isinstance(val, int):
    dims_dict[name] = val
else:
    dims_dict[name] = val  # symbolic/int scalar allowed
This ends up being accepted anyway so merge?
- elif isinstance(val, int):
-     dims_dict[name] = val
- else:
-     dims_dict[name] = val  # symbolic/int scalar allowed
+ else:
+     dims_dict[name] = val  # symbolic/int scalar allowed
elif isinstance(dim, dict):
    dims_dict = {}
    for name, val in dim.items():
        if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
- if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
+ if isinstance(val, str):
+     raise ValueError(f"The size of a dimension cannot be a string, got {val}")
+ if isinstance(val, Sequence | np.ndarray):
PS it's so annoying there's no type for non-string sequences |
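The normalization discussed in this thread can be sketched in plain Python, leaving out NumPy arrays and symbolic sizes. `is_nonstring_sequence` and `normalize_dims` are hypothetical helpers written for illustration, not the code in this PR. The sketch assumes that a sequence value is read as coordinates, so only its length matters, and that a bare name gets size 1.

```python
from collections.abc import Sequence


def is_nonstring_sequence(val):
    """True for lists, tuples, and the like, but not for str or bytes."""
    return isinstance(val, Sequence) and not isinstance(val, (str, bytes))


def normalize_dims(dim):
    """Return a {name: size} dict for a str, a sequence of str, or a dict."""
    if isinstance(dim, str):
        return {dim: 1}
    if is_nonstring_sequence(dim):
        if len(set(dim)) != len(dim):
            raise ValueError("Dimension names must be unique")
        return {name: 1 for name in dim}
    if isinstance(dim, dict):
        sizes = {}
        for name, val in dim.items():
            if isinstance(val, str):
                raise ValueError(
                    f"The size of a dimension cannot be a string, got {val}"
                )
            # Assumption: a sequence is treated as coordinates, so keep its length.
            sizes[name] = len(val) if is_nonstring_sequence(val) else val
        return sizes
    raise TypeError(f"Unsupported dim argument: {type(dim).__name__}")


print(normalize_dims({"a": 3, "b": [10, 20]}))
```

Checking for strings first, as suggested in the review, keeps the sequence test itself free of the `not isinstance(val, str)` clause.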
Add expand_dims operation for labeled tensors
This PR adds support for the expand_dims operation in PyTensor's labeled tensor system, allowing users to add new dimensions to labeled tensors with explicit dimension names.
Key Features
ExpandDims operation that adds a new dimension to an XTensorVariable
Implementation Details
The implementation includes:
New ExpandDims class in pytensor/xtensor/shape.py that handles:
Rewriting rule in pytensor/xtensor/rewriting/shape.py that:
Comprehensive test suite in tests/xtensor/test_shape.py covering:
Usage Example
Testing
The implementation includes extensive tests that verify:
📚 Documentation preview 📚: https://pytensor--1449.org.readthedocs.build/en/1449/