@@ -85,6 +85,11 @@ def test_5D(self):
85
85
y = atleast_nd (x , ndim = 9 , xp = xp )
86
86
assert_array_equal (y , xp .ones ((1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 )))
87
87
88
+ def test_device (self ):
89
+ device = xp .Device ("device1" )
90
+ x = xp .asarray ([1 , 2 , 3 ], device = device )
91
+ assert atleast_nd (x , ndim = 2 , xp = xp ).device == device
92
+
88
93
89
94
class TestCov :
90
95
def test_basic (self ):
@@ -120,6 +125,11 @@ def test_combination(self):
120
125
assert_allclose (cov (x , xp = xp ), xp .asarray (11.71 ))
121
126
assert_allclose (cov (y , xp = xp ), xp .asarray (2.144133 ), rtol = 1e-6 )
122
127
128
+ def test_device (self ):
129
+ device = xp .Device ("device1" )
130
+ x = xp .asarray ([1 , 2 , 3 ], device = device )
131
+ assert cov (x , xp = xp ).device == device
132
+
123
133
124
134
class TestCreateDiagonal :
125
135
def test_1d (self ):
@@ -156,6 +166,11 @@ def test_2d(self):
156
166
with pytest .raises (ValueError , match = "1-dimensional" ):
157
167
create_diagonal (xp .asarray ([[1 ]]), xp = xp )
158
168
169
+ def test_device (self ):
170
+ device = xp .Device ("device1" )
171
+ x = xp .asarray ([1 , 2 , 3 ], device = device )
172
+ assert create_diagonal (x , xp = xp ).device == device
173
+
159
174
160
175
class TestExpandDims :
161
176
def test_functionality (self ):
@@ -205,6 +220,11 @@ def test_positive_negative_repeated(self):
205
220
with pytest .raises (ValueError , match = "Duplicate dimensions" ):
206
221
expand_dims (a , axis = (3 , - 3 ), xp = xp )
207
222
223
+ def test_device (self ):
224
+ device = xp .Device ("device1" )
225
+ x = xp .asarray ([1 , 2 , 3 ], device = device )
226
+ assert expand_dims (x , axis = 0 , xp = xp ).device == device
227
+
208
228
209
229
class TestKron :
210
230
def test_basic (self ):
@@ -270,6 +290,12 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
270
290
k = kron (a , b , xp = xp )
271
291
assert_equal (k .shape , expected_shape , err_msg = "Unexpected shape from kron" )
272
292
293
+ def test_device (self ):
294
+ device = xp .Device ("device1" )
295
+ x1 = xp .asarray ([1 , 2 , 3 ], device = device )
296
+ x2 = xp .asarray ([4 , 5 ], device = device )
297
+ assert kron (x1 , x2 , xp = xp ).device == device
298
+
273
299
274
300
class TestSetDiff1D :
275
301
def test_setdiff1d (self ):
@@ -298,6 +324,12 @@ def test_assume_unique(self):
298
324
actual = setdiff1d (x1 , x2 , assume_unique = True , xp = xp )
299
325
assert_array_equal (actual , expected )
300
326
327
+ def test_device (self ):
328
+ device = xp .Device ("device1" )
329
+ x1 = xp .asarray ([3 , 8 , 20 ], device = device )
330
+ x2 = xp .asarray ([2 , 3 , 4 ], device = device )
331
+ assert setdiff1d (x1 , x2 , xp = xp ).device == device
332
+
301
333
302
334
class TestSinc :
303
335
def test_simple (self ):
@@ -316,3 +348,8 @@ def test_3d(self):
316
348
expected = xp .zeros ((3 , 3 , 2 ))
317
349
expected [0 , 0 , 0 ] = 1.0
318
350
assert_allclose (sinc (x , xp = xp ), expected , atol = 1e-15 )
351
+
352
+ def test_device (self ):
353
+ device = xp .Device ("device1" )
354
+ x = xp .asarray (0.0 , device = device )
355
+ assert sinc (x , xp = xp ).device == device
0 commit comments