Skip to content

Commit 724e071

Browse files
committed
Add multi-device test for take
1 parent 032f3bb commit 724e071

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

array_api_strict/_creation_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _check_device(device):
3737
if device is not None and not isinstance(device, Device):
3838
raise ValueError(f"Unsupported device {device!r}")
3939

40-
if device not in ALL_DEVICES:
40+
if device is not None and device not in ALL_DEVICES:
4141
raise ValueError(f"Unsupported device {device!r}")
4242

4343
def asarray(

array_api_strict/_indexing_functions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
2222
raise TypeError("Only integer dtypes are allowed in indexing")
2323
if indices.ndim != 1:
2424
raise ValueError("Only 1-dim indices array is supported")
25-
return Array._new(np.take(x._array, indices._array, axis=axis))
25+
if x.device != indices.device:
26+
raise RuntimeError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
27+
return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device)

array_api_strict/tests/test_indexing_functions.py

+20
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,23 @@ def test_take_function(x, indices, axis, expected):
2222
indices = xp.asarray(indices)
2323
out = xp.take(x, indices, axis=axis)
2424
assert xp.all(out == xp.asarray(expected))
25+
26+
27+
def test_take_device():
28+
x = xp.asarray([2, 3])
29+
indices = xp.asarray([1, 1, 0])
30+
xp.take(x, indices)
31+
32+
x = xp.asarray([2, 3])
33+
indices = xp.asarray([1, 1, 0], device=xp.Device("device1"))
34+
with pytest.raises(RuntimeError, match="Arrays from two different devices"):
35+
xp.take(x, indices)
36+
37+
x = xp.asarray([2, 3], device=xp.Device("device1"))
38+
indices = xp.asarray([1, 1, 0])
39+
with pytest.raises(RuntimeError, match="Arrays from two different devices"):
40+
xp.take(x, indices)
41+
42+
x = xp.asarray([2, 3], device=xp.Device("device1"))
43+
indices = xp.asarray([1, 1, 0], device=xp.Device("device1"))
44+
xp.take(x, indices)

0 commit comments

Comments
 (0)