Skip to content

Commit 38105ba

Browse files
committed
TST: add device tests
1 parent ff5c6c2 commit 38105ba

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

src/array_api_extra/_funcs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
193193
err_msg = "`x` must be 1-dimensional."
194194
raise ValueError(err_msg)
195195
n = x.shape[0] + abs(offset)
196-
diag = xp.zeros(n**2, dtype=x.dtype)
196+
diag = xp.zeros(n**2, dtype=x.dtype, device=x.device)
197197
i = offset if offset >= 0 else abs(offset) * n
198198
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
199199
return xp.reshape(diag, (n, n))
@@ -516,6 +516,6 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
516516
raise ValueError(err_msg)
517517
# no scalars in `where` - array-api#807
518518
y = xp.pi * xp.where(
519-
x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)
519+
x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
520520
)
521521
return xp.sin(y) / y

tests/test_funcs.py

+37
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ def test_5D(self):
8585
y = atleast_nd(x, ndim=9, xp=xp)
8686
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))
8787

88+
def test_device(self):
89+
device = xp.Device("device1")
90+
x = xp.asarray([1, 2, 3], device=device)
91+
assert atleast_nd(x, ndim=2, xp=xp).device == device
92+
8893

8994
class TestCov:
9095
def test_basic(self):
@@ -120,6 +125,11 @@ def test_combination(self):
120125
assert_allclose(cov(x, xp=xp), xp.asarray(11.71))
121126
assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)
122127

128+
def test_device(self):
129+
device = xp.Device("device1")
130+
x = xp.asarray([1, 2, 3], device=device)
131+
assert cov(x, xp=xp).device == device
132+
123133

124134
class TestCreateDiagonal:
125135
def test_1d(self):
@@ -156,6 +166,11 @@ def test_2d(self):
156166
with pytest.raises(ValueError, match="1-dimensional"):
157167
create_diagonal(xp.asarray([[1]]), xp=xp)
158168

169+
def test_device(self):
170+
device = xp.Device("device1")
171+
x = xp.asarray([1, 2, 3], device=device)
172+
assert create_diagonal(x, xp=xp).device == device
173+
159174

160175
class TestExpandDims:
161176
def test_functionality(self):
@@ -205,6 +220,11 @@ def test_positive_negative_repeated(self):
205220
with pytest.raises(ValueError, match="Duplicate dimensions"):
206221
expand_dims(a, axis=(3, -3), xp=xp)
207222

223+
def test_device(self):
224+
device = xp.Device("device1")
225+
x = xp.asarray([1, 2, 3], device=device)
226+
assert expand_dims(x, axis=0, xp=xp).device == device
227+
208228

209229
class TestKron:
210230
def test_basic(self):
@@ -270,6 +290,12 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
270290
k = kron(a, b, xp=xp)
271291
assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron")
272292

293+
def test_device(self):
294+
device = xp.Device("device1")
295+
x1 = xp.asarray([1, 2, 3], device=device)
296+
x2 = xp.asarray([4, 5], device=device)
297+
assert kron(x1, x2, xp=xp).device == device
298+
273299

274300
class TestSetDiff1D:
275301
def test_setdiff1d(self):
@@ -298,6 +324,12 @@ def test_assume_unique(self):
298324
actual = setdiff1d(x1, x2, assume_unique=True, xp=xp)
299325
assert_array_equal(actual, expected)
300326

327+
def test_device(self):
328+
device = xp.Device("device1")
329+
x1 = xp.asarray([3, 8, 20], device=device)
330+
x2 = xp.asarray([2, 3, 4], device=device)
331+
assert setdiff1d(x1, x2, xp=xp).device == device
332+
301333

302334
class TestSinc:
303335
def test_simple(self):
@@ -316,3 +348,8 @@ def test_3d(self):
316348
expected = xp.zeros((3, 3, 2))
317349
expected[0, 0, 0] = 1.0
318350
assert_allclose(sinc(x, xp=xp), expected, atol=1e-15)
351+
352+
def test_device(self):
353+
device = xp.Device("device1")
354+
x = xp.asarray(0.0, device=device)
355+
assert sinc(x, xp=xp).device == device

tests/test_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,9 @@ def test_no_invert_assume_unique(self, x2: Array):
2222
expected = xp.asarray([True, True, False])
2323
actual = in1d(x1, x2, xp=xp)
2424
assert_array_equal(actual, expected)
25+
26+
def test_device(self):
27+
device = xp.Device("device1")
28+
x1 = xp.asarray([3, 8, 20], device=device)
29+
x2 = xp.asarray([2, 3, 4], device=device)
30+
assert in1d(x1, x2, xp=xp).device == device

0 commit comments

Comments
 (0)