Skip to content

Commit 9a87e6f

Browse files
authored
Merge pull request #301 from crusaderky/check_device
MAINT: validate device on numpy and dask
2 parents ba0401e + 37b1c47 commit 9a87e6f

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

array_api_compat/common/_helpers.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -595,11 +595,29 @@ def your_function(x, y):
595595
# backwards compatibility alias
596596
get_namespace = array_namespace
597597

598-
def _check_device(xp, device):
599-
if xp == sys.modules.get('numpy'):
600-
if device not in ["cpu", None]:
598+
599+
def _check_device(bare_xp, device):
600+
"""
601+
Validate dummy device on device-less array backends.
602+
603+
Notes
604+
-----
605+
This function is also invoked by CuPy, which does have multiple devices
606+
if there are multiple GPUs available.
607+
However, CuPy multi-device support is currently impossible
608+
without using the global device or a context manager:
609+
610+
https://github.com/data-apis/array-api-compat/pull/293
611+
"""
612+
if bare_xp is sys.modules.get('numpy'):
613+
if device not in ("cpu", None):
601614
raise ValueError(f"Unsupported device for NumPy: {device!r}")
602615

616+
elif bare_xp is sys.modules.get('dask.array'):
617+
if device not in ("cpu", _DASK_DEVICE, None):
618+
raise ValueError(f"Unsupported device for Dask: {device!r}")
619+
620+
603621
# Placeholder object to represent the dask device
604622
# when the array backend is not the CPU.
605623
# (since it is not easy to tell which device a dask array is on)

array_api_compat/dask/array/_aliases.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
import dask.array as da
2525

26-
from ...common import _aliases, array_namespace
26+
from ...common import _aliases, _helpers, array_namespace
2727
from ...common._typing import (
2828
Array,
2929
Device,
@@ -54,6 +54,7 @@ def astype(
5454
specification for more details.
5555
"""
5656
# TODO: respect device keyword?
57+
_helpers._check_device(da, device)
5758

5859
if not copy and dtype == x.dtype:
5960
return x
@@ -84,6 +85,7 @@ def arange(
8485
specification for more details.
8586
"""
8687
# TODO: respect device keyword?
88+
_helpers._check_device(da, device)
8789

8890
args = [start]
8991
if stop is not None:
@@ -155,6 +157,7 @@ def asarray(
155157
specification for more details.
156158
"""
157159
# TODO: respect device keyword?
160+
_helpers._check_device(da, device)
158161

159162
if isinstance(obj, da.Array):
160163
if dtype is not None and dtype != obj.dtype:

array_api_compat/numpy/_aliases.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional, Union
44

55
from .._internal import get_xp
6-
from ..common import _aliases
6+
from ..common import _aliases, _helpers
77
from ..common._typing import NestedSequence, SupportsBufferProtocol
88
from ._info import __array_namespace_info__
99
from ._typing import Array, Device, DType
@@ -97,8 +97,7 @@ def asarray(
9797
See the corresponding documentation in the array library and/or the array API
9898
specification for more details.
9999
"""
100-
if device not in ["cpu", None]:
101-
raise ValueError(f"Unsupported device for NumPy: {device!r}")
100+
_helpers._check_device(np, device)
102101

103102
if hasattr(np, '_CopyMode'):
104103
if copy is None:
@@ -124,6 +123,7 @@ def astype(
124123
copy: bool = True,
125124
device: Optional[Device] = None,
126125
) -> Array:
126+
_helpers._check_device(np, device)
127127
return x.astype(dtype=dtype, copy=copy)
128128

129129

0 commit comments

Comments
 (0)