Skip to content

Commit 4e4e84e

Browse files
committed
code review
1 parent 8d888d5 commit 4e4e84e

File tree

7 files changed

+44
-47
lines changed

7 files changed

+44
-47
lines changed

array_api_compat/common/_aliases.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ class UniqueInverseResult(NamedTuple):
177177
inverse_indices: Array
178178

179179

180-
def _unique_kwargs(xp: Namespace) -> dict[str, Any]:
180+
def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
181181
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
182182
# trying to parse version numbers, just check if equal_nan is in the
183183
# signature.

array_api_compat/cupy/_aliases.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ def astype(
126126

127127
# cupy.count_nonzero does not have keepdims
128128
def count_nonzero(
129-
x: ndarray,
129+
x: Array,
130130
axis=None,
131131
keepdims=False
132-
) -> ndarray:
132+
) -> Array:
133133
result = cp.count_nonzero(x, axis)
134134
if keepdims:
135135
if axis is None:

array_api_compat/cupy/_typing.py

+18-20
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,29 @@
33
__all__ = ["Array", "DType", "Device"]
44
_all_ignore = ["cp"]
55

6-
from typing import Union
6+
from typing import Union, TYPE_CHECKING
77

88
import cupy as cp
99
from cupy import ndarray as Array
1010
from cupy.cuda.device import Device
1111

12-
try:
12+
if TYPE_CHECKING:
13+
# NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[]
1314
DType = cp.dtype[
14-
Union[
15-
cp.intp,
16-
cp.int8,
17-
cp.int16,
18-
cp.int32,
19-
cp.int64,
20-
cp.uint8,
21-
cp.uint16,
22-
cp.uint32,
23-
cp.uint64,
24-
cp.float32,
25-
cp.float64,
26-
cp.complex64,
27-
cp.complex128,
28-
cp.bool_,
29-
]
15+
cp.intp
16+
| cp.int8
17+
| cp.int16
18+
| cp.int32
19+
| cp.int64
20+
| cp.uint8
21+
| cp.uint16
22+
| cp.uint32
23+
| cp.uint64
24+
| cp.float32
25+
| cp.float64
26+
| cp.complex64
27+
| cp.complex128
28+
| cp.bool_
3029
]
31-
except TypeError:
32-
# NumPy 1.x on Python 3.9 and 3.10
30+
else:
3331
DType = cp.dtype

array_api_compat/numpy/_aliases.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ def astype(
128128
# count_nonzero returns a python int for axis=None and keepdims=False
129129
# https://github.com/numpy/numpy/issues/17562
130130
def count_nonzero(
131-
x : ndarray,
131+
x : Array,
132132
axis=None,
133133
keepdims=False
134-
) -> ndarray:
134+
) -> Array:
135135
result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
136136
if axis is None and not keepdims:
137137
return np.asarray(result)

array_api_compat/numpy/_typing.py

+18-20
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,29 @@
33
__all__ = ["Array", "DType", "Device"]
44
_all_ignore = ["np"]
55

6-
from typing import Literal, Union
6+
from typing import Literal, TYPE_CHECKING
77

88
import numpy as np
99
from numpy import ndarray as Array
1010

1111
Device = Literal["cpu"]
12-
try:
12+
if TYPE_CHECKING:
13+
# NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[]
1314
DType = np.dtype[
14-
Union[
15-
np.intp,
16-
np.int8,
17-
np.int16,
18-
np.int32,
19-
np.int64,
20-
np.uint8,
21-
np.uint16,
22-
np.uint32,
23-
np.uint64,
24-
np.float32,
25-
np.float64,
26-
np.complex64,
27-
np.complex128,
28-
np.bool_,
29-
]
15+
np.intp
16+
| np.int8
17+
| np.int16
18+
| np.int32
19+
| np.int64
20+
| np.uint8
21+
| np.uint16
22+
| np.uint32
23+
| np.uint64
24+
| np.float32
25+
| np.float64
26+
| np.complex64
27+
| np.complex128
28+
| np.bool
3029
]
31-
except TypeError:
32-
# NumPy 1.x on Python 3.9 and 3.10
30+
else:
3331
DType = np.dtype

array_api_compat/torch/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List, Optional, Sequence, Tuple, Union
66

77
import torch
8+
89
from .._internal import get_xp
910
from ..common import _aliases
1011
from ._info import __array_namespace_info__

tests/test_all.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
import pytest
1818
import typing
1919

20-
TYPING_NAMES = {
20+
TYPING_NAMES = frozenset((
2121
"Array",
2222
"Device",
2323
"DType",
2424
"Namespace",
2525
"NestedSequence",
2626
"SupportsBufferProtocol",
27-
}
27+
))
2828

2929
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
3030
def test_all(library):

0 commit comments

Comments
 (0)