@@ -157,6 +157,55 @@ def test_2d(self):
157
157
create_diagonal (xp .asarray ([[1 ]]), xp = xp )
158
158
159
159
160
+ class TestExpandDims :
161
+ def test_functionality (self ):
162
+ def _squeeze_all (b : Array ) -> Array :
163
+ """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
164
+ for axis in range (b .ndim ):
165
+ with contextlib .suppress (ValueError ):
166
+ b = xp .squeeze (b , axis = axis )
167
+ return b
168
+
169
+ s = (2 , 3 , 4 , 5 )
170
+ a = xp .empty (s )
171
+ for axis in range (- 5 , 4 ):
172
+ b = expand_dims (a , axis = axis , xp = xp )
173
+ assert b .shape [axis ] == 1
174
+ assert _squeeze_all (b ).shape == s
175
+
176
+ def test_axis_tuple (self ):
177
+ a = xp .empty ((3 , 3 , 3 ))
178
+ assert expand_dims (a , axis = (0 , 1 , 2 ), xp = xp ).shape == (1 , 1 , 1 , 3 , 3 , 3 )
179
+ assert expand_dims (a , axis = (0 , - 1 , - 2 ), xp = xp ).shape == (1 , 3 , 3 , 3 , 1 , 1 )
180
+ assert expand_dims (a , axis = (0 , 3 , 5 ), xp = xp ).shape == (1 , 3 , 3 , 1 , 3 , 1 )
181
+ assert expand_dims (a , axis = (0 , - 3 , - 5 ), xp = xp ).shape == (1 , 1 , 3 , 1 , 3 , 3 )
182
+
183
+ def test_axis_out_of_range (self ):
184
+ s = (2 , 3 , 4 , 5 )
185
+ a = xp .empty (s )
186
+ with pytest .raises (IndexError , match = "out of bounds" ):
187
+ expand_dims (a , axis = - 6 , xp = xp )
188
+ with pytest .raises (IndexError , match = "out of bounds" ):
189
+ expand_dims (a , axis = 5 , xp = xp )
190
+
191
+ a = xp .empty ((3 , 3 , 3 ))
192
+ with pytest .raises (IndexError , match = "out of bounds" ):
193
+ expand_dims (a , axis = (0 , - 6 ), xp = xp )
194
+ with pytest .raises (IndexError , match = "out of bounds" ):
195
+ expand_dims (a , axis = (0 , 5 ), xp = xp )
196
+
197
+ def test_repeated_axis (self ):
198
+ a = xp .empty ((3 , 3 , 3 ))
199
+ with pytest .raises (ValueError , match = "Duplicate dimensions" ):
200
+ expand_dims (a , axis = (1 , 1 ), xp = xp )
201
+
202
+ def test_positive_negative_repeated (self ):
203
+ # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
204
+ a = xp .empty ((2 , 3 , 4 , 5 ))
205
+ with pytest .raises (ValueError , match = "Duplicate dimensions" ):
206
+ expand_dims (a , axis = (3 , - 3 ), xp = xp )
207
+
208
+
160
209
class TestKron :
161
210
def test_basic (self ):
162
211
# Using 0-dimensional array
@@ -222,55 +271,6 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
222
271
assert_equal (k .shape , expected_shape , err_msg = "Unexpected shape from kron" )
223
272
224
273
225
- class TestExpandDims :
226
- def test_functionality (self ):
227
- def _squeeze_all (b : Array ) -> Array :
228
- """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
229
- for axis in range (b .ndim ):
230
- with contextlib .suppress (ValueError ):
231
- b = xp .squeeze (b , axis = axis )
232
- return b
233
-
234
- s = (2 , 3 , 4 , 5 )
235
- a = xp .empty (s )
236
- for axis in range (- 5 , 4 ):
237
- b = expand_dims (a , axis = axis , xp = xp )
238
- assert b .shape [axis ] == 1
239
- assert _squeeze_all (b ).shape == s
240
-
241
- def test_axis_tuple (self ):
242
- a = xp .empty ((3 , 3 , 3 ))
243
- assert expand_dims (a , axis = (0 , 1 , 2 ), xp = xp ).shape == (1 , 1 , 1 , 3 , 3 , 3 )
244
- assert expand_dims (a , axis = (0 , - 1 , - 2 ), xp = xp ).shape == (1 , 3 , 3 , 3 , 1 , 1 )
245
- assert expand_dims (a , axis = (0 , 3 , 5 ), xp = xp ).shape == (1 , 3 , 3 , 1 , 3 , 1 )
246
- assert expand_dims (a , axis = (0 , - 3 , - 5 ), xp = xp ).shape == (1 , 1 , 3 , 1 , 3 , 3 )
247
-
248
- def test_axis_out_of_range (self ):
249
- s = (2 , 3 , 4 , 5 )
250
- a = xp .empty (s )
251
- with pytest .raises (IndexError , match = "out of bounds" ):
252
- expand_dims (a , axis = - 6 , xp = xp )
253
- with pytest .raises (IndexError , match = "out of bounds" ):
254
- expand_dims (a , axis = 5 , xp = xp )
255
-
256
- a = xp .empty ((3 , 3 , 3 ))
257
- with pytest .raises (IndexError , match = "out of bounds" ):
258
- expand_dims (a , axis = (0 , - 6 ), xp = xp )
259
- with pytest .raises (IndexError , match = "out of bounds" ):
260
- expand_dims (a , axis = (0 , 5 ), xp = xp )
261
-
262
- def test_repeated_axis (self ):
263
- a = xp .empty ((3 , 3 , 3 ))
264
- with pytest .raises (ValueError , match = "Duplicate dimensions" ):
265
- expand_dims (a , axis = (1 , 1 ), xp = xp )
266
-
267
- def test_positive_negative_repeated (self ):
268
- # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
269
- a = xp .empty ((2 , 3 , 4 , 5 ))
270
- with pytest .raises (ValueError , match = "Duplicate dimensions" ):
271
- expand_dims (a , axis = (3 , - 3 ), xp = xp )
272
-
273
-
274
274
class TestSetDiff1D :
275
275
def test_setdiff1d (self ):
276
276
x1 = xp .asarray ([6 , 5 , 4 , 7 , 1 , 2 , 7 , 4 ])
0 commit comments