Skip to content

Commit 5de9f26

Browse files
committed
Merge branch 'main' into cupy_device
2 parents 96d8f5e + 3e5fdc0 commit 5de9f26

20 files changed

+127
-82
lines changed

array_api_compat/common/_aliases.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
import inspect
8-
from typing import NamedTuple, Optional, Sequence, Tuple, Union
8+
from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union
99

1010
from ._typing import Array, Device, DType, Namespace
1111
from ._helpers import (
@@ -609,13 +609,30 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
609609
out[xp.isnan(x)] = xp.nan
610610
return out[()]
611611

612+
613+
def finfo(type_: DType | Array, /, xp: Namespace) -> Any:
614+
# It is surprisingly difficult to recognize a dtype apart from an array.
615+
# np.int64 is not the same as np.asarray(1).dtype!
616+
try:
617+
return xp.finfo(type_)
618+
except (ValueError, TypeError):
619+
return xp.finfo(type_.dtype)
620+
621+
622+
def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
623+
try:
624+
return xp.iinfo(type_)
625+
except (ValueError, TypeError):
626+
return xp.iinfo(type_.dtype)
627+
628+
612629
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
613630
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
614631
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
615632
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
616633
'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims',
617634
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
618635
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
619-
'unstack', 'sign']
636+
'unstack', 'sign', 'finfo', 'iinfo']
620637

621638
_all_ignore = ['inspect', 'array_namespace', 'NamedTuple']

array_api_compat/common/_helpers.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,44 @@ def your_function(x, y):
598598
get_namespace = array_namespace
599599

600600

601+
def _device_ctx(
602+
bare_xp: Namespace, device: Device, like: Array | None = None
603+
) -> Generator[None]:
604+
"""Context manager which changes the current device in CuPy.
605+
606+
Used internally by array creation functions in common._aliases.
607+
"""
608+
if device is None:
609+
if like is None:
610+
return contextlib.nullcontext()
611+
device = _device(like)
612+
613+
if bare_xp is sys.modules.get('numpy'):
614+
if device != "cpu":
615+
raise ValueError(f"Unsupported device for NumPy: {device!r}")
616+
return contextlib.nullcontext()
617+
618+
if bare_xp is sys.modules.get('dask.array'):
619+
if device not in ("cpu", _DASK_DEVICE):
620+
raise ValueError(f"Unsupported device for Dask: {device!r}")
621+
return contextlib.nullcontext()
622+
623+
if bare_xp is sys.modules.get('cupy'):
624+
if not isinstance(device, bare_xp.cuda.Device):
625+
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
626+
return device
627+
628+
# PyTorch doesn't have a "current device" context manager and you
629+
# can't use array creation functions from common._aliases.
630+
raise AssertionError("unreachable") # pragma: nocover
631+
632+
633+
def _check_device(bare_xp: Namespace, device: Device) -> None:
634+
"""Validate dummy device on device-less array backends."""
635+
with _device_ctx(bare_xp, device):
636+
pass
637+
638+
601639
# Placeholder object to represent the dask device
602640
# when the array backend is not the CPU.
603641
# (since it is not easy to tell which device a dask array is on)
@@ -607,7 +645,6 @@ def __repr__(self):
607645

608646
_DASK_DEVICE = _dask_device()
609647

610-
611648
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
612649
# or cupy.ndarray. They are not included in array objects of this library
613650
# because this library just reuses the respective ndarray classes without
@@ -799,43 +836,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
799836
return x.to_device(device, stream=stream)
800837

801838

802-
def _device_ctx(
803-
bare_xp: Namespace, device: Device, like: Array | None = None
804-
) -> Generator[None]:
805-
"""Context manager which changes the current device in CuPy.
806-
807-
Used internally by array creation functions in common._aliases.
808-
"""
809-
if device is None:
810-
if like is None:
811-
return contextlib.nullcontext()
812-
device = _device(like)
813-
814-
if bare_xp is sys.modules.get('numpy'):
815-
if device != "cpu":
816-
raise ValueError(f"Unsupported device for NumPy: {device!r}")
817-
return contextlib.nullcontext()
818-
819-
if bare_xp is sys.modules.get('dask.array'):
820-
if device not in ("cpu", _DASK_DEVICE):
821-
raise ValueError(f"Unsupported device for Dask: {device!r}")
822-
return contextlib.nullcontext()
823-
824-
if bare_xp is sys.modules.get('cupy'):
825-
if not isinstance(device, bare_xp.cuda.Device):
826-
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
827-
return device
828-
829-
# PyTorch doesn't have a "current device" context manager and you
830-
# can't use array creation functions from common._aliases.
831-
raise AssertionError("unreachable") # pragma: nocover
832-
833-
834-
def _check_device(bare_xp: Namespace, device: Device) -> None:
835-
with _device_ctx(bare_xp, device):
836-
pass
837-
838-
839839
def size(x: Array) -> int | None:
840840
"""
841841
Return the total number of elements of x.

array_api_compat/common/_linalg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,5 @@ def trace(
174174
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
175175
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
176176
'trace']
177+
178+
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']

array_api_compat/cupy/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88

99
# See the comment in the numpy __init__.py
1010
__import__(__package__ + '.linalg')
11-
1211
__import__(__package__ + '.fft')
1312

14-
from ..common._helpers import * # noqa: F401,F403
15-
1613
__array_api_version__ = '2024.12'

array_api_compat/cupy/_aliases.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6262
tensordot = get_xp(cp)(_aliases.tensordot)
6363
sign = get_xp(cp)(_aliases.sign)
64+
finfo = get_xp(cp)(_aliases.finfo)
65+
iinfo = get_xp(cp)(_aliases.iinfo)
6466

6567

6668
# asarray also adds the copy keyword, which is not present in numpy 1.0.
@@ -87,7 +89,7 @@ def asarray(
8789
if copy is False:
8890
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
8991

90-
like = obj if _helpers.is_cupy_array(obj) else None
92+
like = obj if isinstance(obj, cp.ndarray) else None
9193
with _helpers._device_ctx(cp, device, like=like):
9294
if copy is None:
9395
return cp.asarray(obj, dtype=dtype, **kwargs)

array_api_compat/dask/array/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55

66
__array_api_version__ = '2024.12'
77

8+
# See the comment in the numpy __init__.py
89
__import__(__package__ + '.linalg')
910
__import__(__package__ + '.fft')

array_api_compat/dask/array/_aliases.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import numpy as np
66
from numpy import (
77
# dtypes
8-
iinfo,
9-
finfo,
108
bool_ as bool,
119
float32,
1210
float64,
@@ -133,6 +131,8 @@ def arange(
133131
matmul = get_xp(np)(_aliases.matmul)
134132
tensordot = get_xp(np)(_aliases.tensordot)
135133
sign = get_xp(np)(_aliases.sign)
134+
finfo = get_xp(np)(_aliases.finfo)
135+
iinfo = get_xp(np)(_aliases.iinfo)
136136

137137

138138
# asarray also adds the copy keyword, which is not present in numpy 1.0.
@@ -346,10 +346,9 @@ def count_nonzero(
346346
'__array_namespace_info__', 'asarray', 'astype', 'acos',
347347
'acosh', 'asin', 'asinh', 'atan', 'atan2',
348348
'atanh', 'bitwise_left_shift', 'bitwise_invert',
349-
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
349+
'bitwise_right_shift', 'concat', 'pow', 'can_cast',
350350
'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
351-
'uint8', 'uint16', 'uint32', 'uint64',
352-
'complex64', 'complex128', 'iinfo', 'finfo',
351+
'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128',
353352
'can_cast', 'count_nonzero', 'result_type']
354353

355354
_all_ignore = ["array_namespace", "get_xp", "da", "np"]

array_api_compat/numpy/__init__.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,8 @@
1414
# It doesn't overwrite np.linalg from above. The import is generated
1515
# dynamically so that the library can be vendored.
1616
__import__(__package__ + '.linalg')
17-
1817
__import__(__package__ + '.fft')
1918

2019
from .linalg import matrix_transpose, vecdot # noqa: F401
2120

22-
from ..common._helpers import * # noqa: F403
23-
24-
try:
25-
# Used in asarray(). Not present in older versions.
26-
from numpy import _CopyMode # noqa: F401
27-
except ImportError:
28-
pass
29-
3021
__array_api_version__ = '2024.12'

array_api_compat/numpy/_aliases.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6262
tensordot = get_xp(np)(_aliases.tensordot)
6363
sign = get_xp(np)(_aliases.sign)
64+
finfo = get_xp(np)(_aliases.finfo)
65+
iinfo = get_xp(np)(_aliases.iinfo)
6466

6567

6668
def _supports_buffer_protocol(obj):
@@ -86,7 +88,7 @@ def asarray(
8688
*,
8789
dtype: Optional[DType] = None,
8890
device: Optional[Device] = None,
89-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
91+
copy: Optional[Union[bool, np._CopyMode]] = None,
9092
**kwargs,
9193
) -> Array:
9294
"""

array_api_compat/torch/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,14 @@
99
or 'cpu' in n
1010
or 'backward' in n):
1111
continue
12-
exec(n + ' = torch.' + n)
12+
exec(f"{n} = torch.{n}")
13+
del n
1314

1415
# These imports may overwrite names from the import * above.
1516
from ._aliases import * # noqa: F403
1617

1718
# See the comment in the numpy __init__.py
1819
__import__(__package__ + '.linalg')
19-
2020
__import__(__package__ + '.fft')
2121

22-
from ..common._helpers import * # noqa: F403
23-
2422
__array_api_version__ = '2024.12'

0 commit comments

Comments
 (0)