@@ -52,7 +52,7 @@ def test_single_adv_indexing_on_existing_dim():
52
52
idx_test = np .array ([0 , 1 , 0 , 2 ], dtype = int )
53
53
xidx_test = DataArray (idx_test , dims = ("a" ,))
54
54
55
- # Three equivalent ways of indexing a->a
55
+ # Equivalent ways of indexing a->a
56
56
y = x [idx ]
57
57
fn = xr_function ([x , idx ], y )
58
58
res = fn (x_test , idx_test )
@@ -65,6 +65,12 @@ def test_single_adv_indexing_on_existing_dim():
65
65
expected_res = x_test [(("a" , idx_test ),)]
66
66
xr_assert_allclose (res , expected_res )
67
67
68
+ y = x [((("a" ,), idx ),)]
69
+ fn = xr_function ([x , idx ], y )
70
+ res = fn (x_test , idx_test )
71
+ expected_res = x_test [((("a" ,), idx_test ),)]
72
+ xr_assert_allclose (res , expected_res )
73
+
68
74
y = x [xidx ]
69
75
fn = xr_function ([x , xidx ], y )
70
76
res = fn (x_test , xidx_test )
@@ -81,13 +87,19 @@ def test_single_vector_indexing_on_new_dim():
81
87
idx_test = np .array ([0 , 1 , 0 , 2 ], dtype = int )
82
88
xidx_test = DataArray (idx_test , dims = ("a" ,))
83
89
84
- # Two equivalent ways of indexing a->new_a
90
+ # Equivalent ways of indexing a->new_a
85
91
y = x [(("new_a" , idx ),)]
86
92
fn = xr_function ([x , idx ], y )
87
93
res = fn (x_test , idx_test )
88
94
expected_res = x_test [(("new_a" , idx_test ),)]
89
95
xr_assert_allclose (res , expected_res )
90
96
97
+ y = x [((["new_a" ], idx ),)]
98
+ fn = xr_function ([x , idx ], y )
99
+ res = fn (x_test , idx_test )
100
+ expected_res = x_test [((["new_a" ], idx_test ),)]
101
+ xr_assert_allclose (res , expected_res )
102
+
91
103
y = x [xidx .rename (a = "new_a" )]
92
104
fn = xr_function ([x , xidx ], y )
93
105
res = fn (x_test , xidx_test )
@@ -176,6 +188,34 @@ def test_matrix_indexing():
176
188
xr_assert_allclose (res , expected_res )
177
189
178
190
191
+ def test_assign_multiple_out_dims ():
192
+ x = xtensor ("x" , shape = (5 , 7 ), dims = ("a" , "b" ))
193
+ idx1 = tensor ("idx1" , dtype = int , shape = (4 , 3 ))
194
+ idx2 = tensor ("idx2" , dtype = int , shape = (3 , 2 ))
195
+ out = x [(("out1" , "out2" ), idx1 ), (["out2" , "out3" ], idx2 )]
196
+
197
+ fn = xr_function ([x , idx1 , idx2 ], out )
198
+
199
+ rng = np .random .default_rng ()
200
+ x_test = xr_arange_like (x )
201
+ idx1_test = rng .binomial (n = 4 , p = 0.5 , size = (4 , 3 ))
202
+ idx2_test = rng .binomial (n = 4 , p = 0.5 , size = (3 , 2 ))
203
+ res = fn (x_test , idx1_test , idx2_test )
204
+ expected_res = x_test [(("out1" , "out2" ), idx1_test ), (["out2" , "out3" ], idx2_test )]
205
+ xr_assert_allclose (res , expected_res )
206
+
207
+
208
+ def test_assign_dims_xtensor_fails ():
209
+ x = xtensor ("x" , shape = (5 , 7 ), dims = ("a" , "b" ))
210
+ idx1 = xtensor ("idx1" , dtype = int , shape = (4 ,), dims = ("c" ,))
211
+
212
+ with pytest .raises (
213
+ TypeError ,
214
+ match = "Giving a dimension name to an XTensorVariable indexer is not supported" ,
215
+ ):
216
+ x [("d" , idx1 ),]
217
+
218
+
179
219
class TestVectorizedIndexingNotAllowedToBroadcast :
180
220
def test_compile_time_error (self ):
181
221
x = xtensor (dims = ("a" , "b" ), shape = (3 , 5 ))
0 commit comments