Skip to content

Commit 5667b18

Browse files
committed
Finalize assigning output dimension names in index operation
1 parent ec79aa5 commit 5667b18

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

pytensor/xtensor/indexing.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,28 @@ def as_idx_variable(idx):
2323
idx = make_slice(idx)
2424
elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
2525
pass
26-
elif isinstance(idx, tuple) and len(idx) == 2 and isinstance(idx[0], str):
26+
elif (
27+
isinstance(idx, tuple)
28+
and len(idx) == 2
29+
and (
30+
isinstance(idx[0], str)
31+
or (
32+
isinstance(idx[0], tuple | list)
33+
and all(isinstance(d, str) for d in idx[0])
34+
)
35+
)
36+
):
2737
# Special case for ("x", array) that xarray supports
28-
# TODO: Check if this can be used to rename existing xarray dimensions or only for numpy
2938
dim, idx = idx
30-
idx = xtensor_from_tensor(as_tensor(idx), dims=(dim,))
39+
if isinstance(idx.type, XTensorType):
40+
raise TypeError(
41+
"Giving a dimension name to an XTensorVariable indexer is not supported"
42+
)
43+
if isinstance(dim, str):
44+
dims = (dim,)
45+
else:
46+
dims = tuple(dim)
47+
idx = xtensor_from_tensor(as_tensor(idx), dims=dims)
3148
else:
3249
# Must be integer indices, we already counted for None and slices
3350
try:

tests/xtensor/test_indexing.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_single_adv_indexing_on_existing_dim():
5252
idx_test = np.array([0, 1, 0, 2], dtype=int)
5353
xidx_test = DataArray(idx_test, dims=("a",))
5454

55-
# Three equivalent ways of indexing a->a
55+
# Equivalent ways of indexing a->a
5656
y = x[idx]
5757
fn = xr_function([x, idx], y)
5858
res = fn(x_test, idx_test)
@@ -65,6 +65,12 @@ def test_single_adv_indexing_on_existing_dim():
6565
expected_res = x_test[(("a", idx_test),)]
6666
xr_assert_allclose(res, expected_res)
6767

68+
y = x[((("a",), idx),)]
69+
fn = xr_function([x, idx], y)
70+
res = fn(x_test, idx_test)
71+
expected_res = x_test[((("a",), idx_test),)]
72+
xr_assert_allclose(res, expected_res)
73+
6874
y = x[xidx]
6975
fn = xr_function([x, xidx], y)
7076
res = fn(x_test, xidx_test)
@@ -81,13 +87,19 @@ def test_single_vector_indexing_on_new_dim():
8187
idx_test = np.array([0, 1, 0, 2], dtype=int)
8288
xidx_test = DataArray(idx_test, dims=("a",))
8389

84-
# Two equivalent ways of indexing a->new_a
90+
# Equivalent ways of indexing a->new_a
8591
y = x[(("new_a", idx),)]
8692
fn = xr_function([x, idx], y)
8793
res = fn(x_test, idx_test)
8894
expected_res = x_test[(("new_a", idx_test),)]
8995
xr_assert_allclose(res, expected_res)
9096

97+
y = x[((["new_a"], idx),)]
98+
fn = xr_function([x, idx], y)
99+
res = fn(x_test, idx_test)
100+
expected_res = x_test[((["new_a"], idx_test),)]
101+
xr_assert_allclose(res, expected_res)
102+
91103
y = x[xidx.rename(a="new_a")]
92104
fn = xr_function([x, xidx], y)
93105
res = fn(x_test, xidx_test)
@@ -176,6 +188,34 @@ def test_matrix_indexing():
176188
xr_assert_allclose(res, expected_res)
177189

178190

191+
def test_assign_multiple_out_dims():
192+
x = xtensor("x", shape=(5, 7), dims=("a", "b"))
193+
idx1 = tensor("idx1", dtype=int, shape=(4, 3))
194+
idx2 = tensor("idx2", dtype=int, shape=(3, 2))
195+
out = x[(("out1", "out2"), idx1), (["out2", "out3"], idx2)]
196+
197+
fn = xr_function([x, idx1, idx2], out)
198+
199+
rng = np.random.default_rng()
200+
x_test = xr_arange_like(x)
201+
idx1_test = rng.binomial(n=4, p=0.5, size=(4, 3))
202+
idx2_test = rng.binomial(n=4, p=0.5, size=(3, 2))
203+
res = fn(x_test, idx1_test, idx2_test)
204+
expected_res = x_test[(("out1", "out2"), idx1_test), (["out2", "out3"], idx2_test)]
205+
xr_assert_allclose(res, expected_res)
206+
207+
208+
def test_assign_dims_xtensor_fails():
209+
x = xtensor("x", shape=(5, 7), dims=("a", "b"))
210+
idx1 = xtensor("idx1", dtype=int, shape=(4,), dims=("c",))
211+
212+
with pytest.raises(
213+
TypeError,
214+
match="Giving a dimension name to an XTensorVariable indexer is not supported",
215+
):
216+
x[("d", idx1),]
217+
218+
179219
class TestVectorizedIndexingNotAllowedToBroadcast:
180220
def test_compile_time_error(self):
181221
x = xtensor(dims=("a", "b"), shape=(3, 5))

0 commit comments

Comments
 (0)