Skip to content

Commit 34ce3f6

Browse files
crusaderkyjorenham
andauthoredMar 31, 2025··
TYP: type annotations (#135)
* TYP: type annotations * Python 3.9 fixes * self-review * code review * Update array_api_strict/_array_object.py Co-authored-by: Joren Hammudoglu <[email protected]> * Apply suggestions from code review Co-authored-by: Joren Hammudoglu <[email protected]> * fix * code review * fixes * normalize order * Fancy indexing in `__getitem__` signature * verbose Python scalar types --------- Co-authored-by: Joren Hammudoglu <[email protected]>
1 parent a8f567a commit 34ce3f6

24 files changed

+807
-716
lines changed
 

‎array_api_strict/__init__.py‎

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
1717
"""
1818

19+
from types import ModuleType
20+
1921
__all__ = []
2022

2123
# Warning: __array_api_version__ could change globally with
@@ -325,12 +327,16 @@
325327
ArrayAPIStrictFlags,
326328
)
327329

328-
__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags']
330+
__all__ += [
331+
'set_array_api_strict_flags',
332+
'get_array_api_strict_flags',
333+
'reset_array_api_strict_flags',
334+
'ArrayAPIStrictFlags',
335+
'__version__',
336+
]
329337

330338
try:
331-
from . import _version
332-
__version__ = _version.__version__
333-
del _version
339+
from ._version import __version__ # type: ignore[import-not-found,unused-ignore]
334340
except ImportError:
335341
__version__ = "unknown"
336342

@@ -340,7 +346,7 @@
340346
# use __getattr__. Note that linalg and fft are dynamically added and removed
341347
# from __all__ in set_array_api_strict_flags.
342348

343-
def __getattr__(name):
349+
def __getattr__(name: str) -> ModuleType:
344350
if name in ['linalg', 'fft']:
345351
if name in get_array_api_strict_flags()['enabled_extensions']:
346352
if name == 'linalg':

‎array_api_strict/_array_object.py‎

Lines changed: 163 additions & 132 deletions
Large diffs are not rendered by default.

‎array_api_strict/_constants.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
inf = np.inf
55
nan = np.nan
66
pi = np.pi
7-
newaxis = np.newaxis
7+
newaxis: None = np.newaxis

‎array_api_strict/_creation_functions.py‎

Lines changed: 101 additions & 104 deletions
Large diffs are not rendered by default.

‎array_api_strict/_data_type_functions.py‎

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,37 @@
11
from __future__ import annotations
22

3-
from ._array_object import Array
4-
from ._creation_functions import _check_device
3+
from dataclasses import dataclass
4+
5+
import numpy as np
6+
7+
from ._array_object import Array, Device
8+
from ._creation_functions import Undef, _check_device, _undef
59
from ._dtypes import (
6-
_DType,
10+
DType,
711
_all_dtypes,
812
_boolean_dtypes,
9-
_signed_integer_dtypes,
10-
_unsigned_integer_dtypes,
11-
_integer_dtypes,
12-
_real_floating_dtypes,
1313
_complex_floating_dtypes,
14+
_integer_dtypes,
1415
_numeric_dtypes,
16+
_real_floating_dtypes,
1517
_result_type,
18+
_signed_integer_dtypes,
19+
_unsigned_integer_dtypes,
1620
)
1721
from ._flags import get_array_api_strict_flags
1822

19-
from dataclasses import dataclass
20-
from typing import TYPE_CHECKING
21-
22-
if TYPE_CHECKING:
23-
from typing import List, Tuple, Union, Optional
24-
from ._typing import Dtype, Device
25-
26-
import numpy as np
27-
28-
# Use to emulate the asarray(device) argument not existing in 2022.12
29-
_default = object()
3023

3124
# Note: astype is a function, not an array method as in NumPy.
3225
def astype(
33-
x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = _default
26+
x: Array,
27+
dtype: DType,
28+
/,
29+
*,
30+
copy: bool = True,
31+
# _default is used to emulate the device argument not existing in 2022.12
32+
device: Device | Undef | None = _undef,
3433
) -> Array:
35-
if device is not _default:
34+
if device is not _undef:
3635
if get_array_api_strict_flags()['api_version'] >= '2023.12':
3736
_check_device(device)
3837
else:
@@ -52,7 +51,7 @@ def astype(
5251
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device)
5352

5453

55-
def broadcast_arrays(*arrays: Array) -> List[Array]:
54+
def broadcast_arrays(*arrays: Array) -> list[Array]:
5655
"""
5756
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
5857
@@ -65,7 +64,7 @@ def broadcast_arrays(*arrays: Array) -> List[Array]:
6564
]
6665

6766

68-
def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array:
67+
def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array:
6968
"""
7069
Array API compatible wrapper for :py:func:`np.broadcast_to <numpy.broadcast_to>`.
7170
@@ -76,7 +75,7 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array:
7675
return Array._new(np.broadcast_to(x._array, shape), device=x.device)
7776

7877

79-
def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
78+
def can_cast(from_: DType | Array, to: DType, /) -> bool:
8079
"""
8180
Array API compatible wrapper for :py:func:`np.can_cast <numpy.can_cast>`.
8281
@@ -112,26 +111,25 @@ class finfo_object:
112111
max: float
113112
min: float
114113
smallest_normal: float
115-
dtype: Dtype
114+
dtype: DType
116115

117116

118117
@dataclass
119118
class iinfo_object:
120119
bits: int
121120
max: int
122121
min: int
123-
dtype: Dtype
122+
dtype: DType
124123

125124

126-
def finfo(type: Union[Dtype, Array], /) -> finfo_object:
125+
def finfo(type: DType | Array, /) -> finfo_object:
127126
"""
128127
Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`.
129128
130129
See its docstring for more information.
131130
"""
132-
if isinstance(type, _DType):
133-
type = type._np_dtype
134-
fi = np.finfo(type)
131+
np_type = type._array if isinstance(type, Array) else type._np_dtype
132+
fi = np.finfo(np_type)
135133
# Note: The types of the float data here are float, whereas in NumPy they
136134
# are scalars of the corresponding float dtype.
137135
return finfo_object(
@@ -140,35 +138,33 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object:
140138
float(fi.max),
141139
float(fi.min),
142140
float(fi.smallest_normal),
143-
fi.dtype,
141+
DType(fi.dtype),
144142
)
145143

146144

147-
def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
145+
def iinfo(type: DType | Array, /) -> iinfo_object:
148146
"""
149147
Array API compatible wrapper for :py:func:`np.iinfo <numpy.iinfo>`.
150148
151149
See its docstring for more information.
152150
"""
153-
if isinstance(type, _DType):
154-
type = type._np_dtype
155-
ii = np.iinfo(type)
156-
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)
151+
np_type = type._array if isinstance(type, Array) else type._np_dtype
152+
ii = np.iinfo(np_type)
153+
return iinfo_object(ii.bits, ii.max, ii.min, DType(ii.dtype))
157154

158155

159156
# Note: isdtype is a new function from the 2022.12 array API specification.
160-
def isdtype(
161-
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]
162-
) -> bool:
157+
def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool:
163158
"""
164-
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
159+
Returns a boolean indicating whether a provided dtype is of a specified
160+
data type ``kind``.
165161
166162
See
167163
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
168164
for more details
169165
"""
170-
if not isinstance(dtype, _DType):
171-
raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}")
166+
if not isinstance(dtype, DType):
167+
raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}")
172168

173169
if isinstance(kind, tuple):
174170
# Disallow nested tuples
@@ -197,7 +193,10 @@ def isdtype(
197193
else:
198194
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")
199195

200-
def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype:
196+
197+
def result_type(
198+
*arrays_and_dtypes: DType | Array | bool | int | float | complex,
199+
) -> DType:
201200
"""
202201
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
203202
@@ -219,15 +218,15 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, boo
219218
A.append(a)
220219

221220
# remove python scalars
222-
A = [a for a in A if not isinstance(a, (bool, int, float, complex))]
221+
B = [a for a in A if not isinstance(a, (bool, int, float, complex))]
223222

224-
if len(A) == 0:
223+
if len(B) == 0:
225224
raise ValueError("at least one array or dtype is required")
226-
elif len(A) == 1:
227-
result = A[0]
225+
elif len(B) == 1:
226+
result = B[0]
228227
else:
229-
t = A[0]
230-
for t2 in A[1:]:
228+
t = B[0]
229+
for t2 in B[1:]:
231230
t = _result_type(t, t2)
232231
result = t
233232

‎array_api_strict/_dtypes.py‎

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
1+
from __future__ import annotations
2+
3+
import builtins
14
import warnings
5+
from typing import Any, Final
26

37
import numpy as np
8+
import numpy.typing as npt
49

510
# Note: we wrap the NumPy dtype objects in a bare class, so that none of the
611
# additional methods and behaviors of NumPy dtype objects are exposed.
712

8-
class _DType:
9-
def __init__(self, np_dtype):
10-
np_dtype = np.dtype(np_dtype)
11-
self._np_dtype = np_dtype
1213

13-
def __repr__(self):
14+
class DType:
15+
_np_dtype: Final[np.dtype[Any]]
16+
__slots__ = ("_np_dtype", "__weakref__")
17+
18+
def __init__(self, np_dtype: npt.DTypeLike):
19+
self._np_dtype = np.dtype(np_dtype)
20+
21+
def __repr__(self) -> str:
1422
return f"array_api_strict.{self._np_dtype.name}"
1523

16-
def __eq__(self, other):
24+
def __eq__(self, other: object) -> builtins.bool:
1725
# See https://github.com/numpy/numpy/pull/25370/files#r1423259515.
1826
# Avoid the user error of array_api_strict.float32 == numpy.float32,
1927
# which gives False. Making == error is probably too egregious, so
@@ -26,33 +34,38 @@ def __eq__(self, other):
2634
a NumPy native dtype object, but you probably don't want to do this. \
2735
array_api_strict dtype objects compare unequal to their NumPy equivalents. \
2836
Such cross-library comparison is not supported by the standard.""",
29-
stacklevel=2)
30-
if not isinstance(other, _DType):
37+
stacklevel=2,
38+
)
39+
if not isinstance(other, DType):
3140
return NotImplemented
3241
return self._np_dtype == other._np_dtype
3342

34-
def __hash__(self):
43+
def __hash__(self) -> int:
3544
# Note: this is not strictly required
3645
# (https://github.com/data-apis/array-api/issues/582), but makes the
3746
# dtype objects much easier to work with here and elsewhere if they
3847
# can be used as dict keys.
3948
return hash(self._np_dtype)
4049

4150

42-
int8 = _DType("int8")
43-
int16 = _DType("int16")
44-
int32 = _DType("int32")
45-
int64 = _DType("int64")
46-
uint8 = _DType("uint8")
47-
uint16 = _DType("uint16")
48-
uint32 = _DType("uint32")
49-
uint64 = _DType("uint64")
50-
float32 = _DType("float32")
51-
float64 = _DType("float64")
52-
complex64 = _DType("complex64")
53-
complex128 = _DType("complex128")
51+
def _np_dtype(dtype: DType | None) -> np.dtype[Any] | None:
52+
return dtype._np_dtype if dtype is not None else None
53+
54+
55+
int8 = DType("int8")
56+
int16 = DType("int16")
57+
int32 = DType("int32")
58+
int64 = DType("int64")
59+
uint8 = DType("uint8")
60+
uint16 = DType("uint16")
61+
uint32 = DType("uint32")
62+
uint64 = DType("uint64")
63+
float32 = DType("float32")
64+
float64 = DType("float64")
65+
complex64 = DType("complex64")
66+
complex128 = DType("complex128")
5467
# Note: This name is changed
55-
bool = _DType("bool")
68+
bool = DType("bool")
5669

5770
_all_dtypes = (
5871
int8,
@@ -212,7 +225,7 @@ def __hash__(self):
212225
}
213226

214227

215-
def _result_type(type1, type2):
228+
def _result_type(type1: DType, type2: DType) -> DType:
216229
if (type1, type2) in _promotion_table:
217230
return _promotion_table[type1, type2]
218231
raise TypeError(f"{type1} and {type2} cannot be type promoted together")

‎array_api_strict/_elementwise_functions.py‎

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,50 @@
11
from __future__ import annotations
22

3+
import numpy as np
4+
5+
from ._array_object import Array
6+
from ._creation_functions import asarray
7+
from ._data_type_functions import broadcast_to, iinfo
38
from ._dtypes import (
49
_boolean_dtypes,
5-
_floating_dtypes,
6-
_real_floating_dtypes,
710
_complex_floating_dtypes,
11+
_dtype_categories,
12+
_floating_dtypes,
813
_integer_dtypes,
914
_integer_or_boolean_dtypes,
10-
_real_numeric_dtypes,
1115
_numeric_dtypes,
16+
_real_floating_dtypes,
17+
_real_numeric_dtypes,
1218
_result_type,
13-
_dtype_categories,
1419
)
15-
from ._array_object import Array
1620
from ._flags import requires_api_version
17-
from ._creation_functions import asarray
18-
from ._data_type_functions import broadcast_to, iinfo
1921
from ._helpers import _maybe_normalize_py_scalars
2022

21-
from typing import Optional, Union
22-
23-
import numpy as np
24-
2523

2624
def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func):
2725
"""Base implementation of a binary function, `func_name`, defined for
28-
dtypes from `dtype_category`
26+
dtypes from `dtype_category`
2927
"""
3028
x1, x2 = _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name)
3129

3230
if x1.device != x2.device:
33-
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
31+
raise ValueError(
32+
f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined."
33+
)
3434
# Call result type here just to raise on disallowed type combinations
3535
_result_type(x1.dtype, x2.dtype)
3636
x1, x2 = Array._normalize_two_args(x1, x2)
3737
return Array._new(np_func(x1._array, x2._array), device=x1.device)
3838

3939

40-
_binary_docstring_template=\
41-
"""
40+
_binary_docstring_template = """
4241
Array API compatible wrapper for :py:func:`np.%s <numpy.%s>`.
4342
4443
See its docstring for more information.
4544
"""
4645

4746

48-
def create_binary_func(func_name, dtype_category, np_func):
47+
def _create_binary_func(func_name, dtype_category, np_func):
4948
def inner(x1, x2, /) -> Array:
5049
return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func)
5150
return inner
@@ -58,7 +57,7 @@ def inner(x1, x2, /) -> Array:
5857
"real numeric": "int | float | Array",
5958
"numeric": "int | float | complex | Array",
6059
"integer": "int | Array",
61-
"integer or boolean": "int | bool | Array",
60+
"integer or boolean": "bool | int | Array",
6261
"boolean": "bool | Array",
6362
"real floating-point": "float | Array",
6463
"complex floating-point": "complex | Array",
@@ -75,7 +74,7 @@ def inner(x1, x2, /) -> Array:
7574
"bitwise_xor": "integer or boolean",
7675
"_bitwise_left_shift": "integer", # leading underscore deliberate
7776
"_bitwise_right_shift": "integer",
78-
# XXX: copysign: real fp or numeric?
77+
# XXX: copysign: real fp or numeric?
7978
"copysign": "real floating-point",
8079
"divide": "floating-point",
8180
"equal": "all",
@@ -105,7 +104,7 @@ def inner(x1, x2, /) -> Array:
105104
"atan2": "arctan2",
106105
"_bitwise_left_shift": "left_shift",
107106
"_bitwise_right_shift": "right_shift",
108-
"pow": "power"
107+
"pow": "power",
109108
}
110109

111110

@@ -117,7 +116,7 @@ def inner(x1, x2, /) -> Array:
117116
numpy_name = _numpy_renames.get(func_name, func_name)
118117
np_func = getattr(np, numpy_name)
119118

120-
func = create_binary_func(func_name, dtype_category, np_func)
119+
func = _create_binary_func(func_name, dtype_category, np_func)
121120
func.__name__ = func_name
122121

123122
func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name)
@@ -153,7 +152,7 @@ def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array:
153152

154153

155154
# clean up to not pollute the namespace
156-
del func, create_binary_func
155+
del func, _create_binary_func
157156

158157

159158
def abs(x: Array, /) -> Array:
@@ -271,8 +270,8 @@ def ceil(x: Array, /) -> Array:
271270
def clip(
272271
x: Array,
273272
/,
274-
min: Optional[Union[int, float, Array]] = None,
275-
max: Optional[Union[int, float, Array]] = None,
273+
min: Array | int | float | None = None,
274+
max: Array | int | float | None = None,
276275
) -> Array:
277276
"""
278277
Array API compatible wrapper for :py:func:`np.clip <numpy.clip>`.
@@ -351,6 +350,7 @@ def clip(
351350

352351
def _isscalar(a):
353352
return isinstance(a, (int, float, type(None)))
353+
354354
min_shape = () if _isscalar(min) else min.shape
355355
max_shape = () if _isscalar(max) else max.shape
356356

@@ -584,6 +584,7 @@ def reciprocal(x: Array, /) -> Array:
584584
raise TypeError("Only floating-point dtypes are allowed in reciprocal")
585585
return Array._new(np.reciprocal(x._array), device=x.device)
586586

587+
587588
def round(x: Array, /) -> Array:
588589
"""
589590
Array API compatible wrapper for :py:func:`np.round <numpy.round>`.

‎array_api_strict/_fft.py‎

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from collections.abc import Sequence
4+
from typing import Literal
45

5-
if TYPE_CHECKING:
6-
from typing import Union, Optional, Literal
7-
from ._typing import Device, Dtype as DType
8-
from collections.abc import Sequence
6+
import numpy as np
97

8+
from ._array_object import ALL_DEVICES, Array, Device
9+
from ._data_type_functions import astype
1010
from ._dtypes import (
11+
DType,
12+
_complex_floating_dtypes,
1113
_floating_dtypes,
1214
_real_floating_dtypes,
13-
_complex_floating_dtypes,
14-
float32,
1515
complex64,
16+
float32,
1617
)
17-
from ._array_object import Array, ALL_DEVICES
18-
from ._data_type_functions import astype
1918
from ._flags import requires_extension
2019

21-
import numpy as np
2220

2321
@requires_extension('fft')
2422
def fft(
2523
x: Array,
2624
/,
2725
*,
28-
n: Optional[int] = None,
26+
n: int | None = None,
2927
axis: int = -1,
3028
norm: Literal["backward", "ortho", "forward"] = "backward",
3129
) -> Array:
@@ -48,7 +46,7 @@ def ifft(
4846
x: Array,
4947
/,
5048
*,
51-
n: Optional[int] = None,
49+
n: int | None = None,
5250
axis: int = -1,
5351
norm: Literal["backward", "ortho", "forward"] = "backward",
5452
) -> Array:
@@ -71,8 +69,8 @@ def fftn(
7169
x: Array,
7270
/,
7371
*,
74-
s: Sequence[int] = None,
75-
axes: Sequence[int] = None,
72+
s: Sequence[int] | None = None,
73+
axes: Sequence[int] | None = None,
7674
norm: Literal["backward", "ortho", "forward"] = "backward",
7775
) -> Array:
7876
"""
@@ -94,8 +92,8 @@ def ifftn(
9492
x: Array,
9593
/,
9694
*,
97-
s: Sequence[int] = None,
98-
axes: Sequence[int] = None,
95+
s: Sequence[int] | None = None,
96+
axes: Sequence[int] | None = None,
9997
norm: Literal["backward", "ortho", "forward"] = "backward",
10098
) -> Array:
10199
"""
@@ -117,7 +115,7 @@ def rfft(
117115
x: Array,
118116
/,
119117
*,
120-
n: Optional[int] = None,
118+
n: int | None = None,
121119
axis: int = -1,
122120
norm: Literal["backward", "ortho", "forward"] = "backward",
123121
) -> Array:
@@ -140,7 +138,7 @@ def irfft(
140138
x: Array,
141139
/,
142140
*,
143-
n: Optional[int] = None,
141+
n: int | None = None,
144142
axis: int = -1,
145143
norm: Literal["backward", "ortho", "forward"] = "backward",
146144
) -> Array:
@@ -163,8 +161,8 @@ def rfftn(
163161
x: Array,
164162
/,
165163
*,
166-
s: Sequence[int] = None,
167-
axes: Sequence[int] = None,
164+
s: Sequence[int] | None = None,
165+
axes: Sequence[int] | None = None,
168166
norm: Literal["backward", "ortho", "forward"] = "backward",
169167
) -> Array:
170168
"""
@@ -186,8 +184,8 @@ def irfftn(
186184
x: Array,
187185
/,
188186
*,
189-
s: Sequence[int] = None,
190-
axes: Sequence[int] = None,
187+
s: Sequence[int] | None = None,
188+
axes: Sequence[int] | None = None,
191189
norm: Literal["backward", "ortho", "forward"] = "backward",
192190
) -> Array:
193191
"""
@@ -209,7 +207,7 @@ def hfft(
209207
x: Array,
210208
/,
211209
*,
212-
n: Optional[int] = None,
210+
n: int | None = None,
213211
axis: int = -1,
214212
norm: Literal["backward", "ortho", "forward"] = "backward",
215213
) -> Array:
@@ -232,7 +230,7 @@ def ihfft(
232230
x: Array,
233231
/,
234232
*,
235-
n: Optional[int] = None,
233+
n: int | None = None,
236234
axis: int = -1,
237235
norm: Literal["backward", "ortho", "forward"] = "backward",
238236
) -> Array:
@@ -256,8 +254,8 @@ def fftfreq(
256254
/,
257255
*,
258256
d: float = 1.0,
259-
dtype: Optional[DType] = None,
260-
device: Optional[Device] = None
257+
dtype: DType | None = None,
258+
device: Device | None = None
261259
) -> Array:
262260
"""
263261
Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
@@ -280,8 +278,8 @@ def rfftfreq(
280278
/,
281279
*,
282280
d: float = 1.0,
283-
dtype: Optional[DType] = None,
284-
device: Optional[Device] = None
281+
dtype: DType | None = None,
282+
device: Device | None = None
285283
) -> Array:
286284
"""
287285
Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
@@ -299,7 +297,7 @@ def rfftfreq(
299297
return Array._new(np_result, device=device)
300298

301299
@requires_extension('fft')
302-
def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
300+
def fftshift(x: Array, /, *, axes: int | Sequence[int] | None = None) -> Array:
303301
"""
304302
Array API compatible wrapper for :py:func:`np.fft.fftshift <numpy.fft.fftshift>`.
305303
@@ -310,7 +308,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
310308
return Array._new(np.fft.fftshift(x._array, axes=axes), device=x.device)
311309

312310
@requires_extension('fft')
313-
def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
311+
def ifftshift(x: Array, /, *, axes: int | Sequence[int] | None = None) -> Array:
314312
"""
315313
Array API compatible wrapper for :py:func:`np.fft.ifftshift <numpy.fft.ifftshift>`.
316314

‎array_api_strict/_flags.py‎

Lines changed: 72 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,32 @@
1212
1313
"""
1414

15+
from __future__ import annotations
16+
1517
import functools
1618
import os
1719
import warnings
20+
from collections.abc import Callable
21+
from types import TracebackType
22+
from typing import TYPE_CHECKING, Any, Collection, TypeVar, cast
1823

1924
import array_api_strict
2025

26+
if TYPE_CHECKING:
27+
# TODO import from typing (requires Python >= 3.10)
28+
from typing_extensions import ParamSpec
29+
30+
P = ParamSpec("P")
31+
32+
T = TypeVar("T")
33+
_CallableT = TypeVar("_CallableT", bound=Callable[..., object])
34+
35+
2136
supported_versions = (
2237
"2021.12",
2338
"2022.12",
2439
"2023.12",
25-
"2024.12"
40+
"2024.12",
2641
)
2742

2843
draft_version = "2025.12"
@@ -43,19 +58,23 @@
4358
"fft": "2022.12",
4459
}
4560

46-
ENABLED_EXTENSIONS = default_extensions = (
61+
default_extensions: tuple[str, ...] = (
4762
"linalg",
4863
"fft",
4964
)
65+
ENABLED_EXTENSIONS = default_extensions
66+
67+
5068
# Public functions
5169

70+
5271
def set_array_api_strict_flags(
5372
*,
54-
api_version=None,
55-
boolean_indexing=None,
56-
data_dependent_shapes=None,
57-
enabled_extensions=None,
58-
):
73+
api_version: str | None = None,
74+
boolean_indexing: bool | None = None,
75+
data_dependent_shapes: bool | None = None,
76+
enabled_extensions: Collection[str] | None = None,
77+
) -> None:
5978
"""
6079
Set the array-api-strict flags to the specified values.
6180
@@ -178,7 +197,8 @@ def set_array_api_strict_flags(
178197
draft_version=draft_version,
179198
)
180199

181-
def get_array_api_strict_flags():
200+
201+
def get_array_api_strict_flags() -> dict[str, Any]:
182202
"""
183203
Get the current array-api-strict flags.
184204
@@ -228,7 +248,7 @@ def get_array_api_strict_flags():
228248
}
229249

230250

231-
def reset_array_api_strict_flags():
251+
def reset_array_api_strict_flags() -> None:
232252
"""
233253
Reset the array-api-strict flags to their default values.
234254
@@ -300,8 +320,19 @@ class ArrayAPIStrictFlags:
300320
reset_array_api_strict_flags: Reset the flags to their default values.
301321
302322
"""
303-
def __init__(self, *, api_version=None, boolean_indexing=None,
304-
data_dependent_shapes=None, enabled_extensions=None):
323+
324+
kwargs: dict[str, Any]
325+
old_flags: dict[str, Any]
326+
__slots__ = ("kwargs", "old_flags")
327+
328+
def __init__(
329+
self,
330+
*,
331+
api_version: str | None = None,
332+
boolean_indexing: bool | None = None,
333+
data_dependent_shapes: bool | None = None,
334+
enabled_extensions: Collection[str] | None = None,
335+
):
305336
self.kwargs = {
306337
"api_version": api_version,
307338
"boolean_indexing": boolean_indexing,
@@ -310,12 +341,19 @@ def __init__(self, *, api_version=None, boolean_indexing=None,
310341
}
311342
self.old_flags = get_array_api_strict_flags()
312343

313-
def __enter__(self):
344+
def __enter__(self) -> None:
314345
set_array_api_strict_flags(**self.kwargs)
315346

316-
def __exit__(self, exc_type, exc_value, traceback):
347+
def __exit__(
348+
self,
349+
exc_type: type[BaseException] | None,
350+
exc_value: BaseException | None,
351+
traceback: TracebackType | None,
352+
/,
353+
) -> None:
317354
set_array_api_strict_flags(**self.old_flags)
318355

356+
319357
# Private functions
320358

321359
ENVIRONMENT_VARIABLES = [
@@ -325,8 +363,9 @@ def __exit__(self, exc_type, exc_value, traceback):
325363
"ARRAY_API_STRICT_ENABLED_EXTENSIONS",
326364
]
327365

328-
def set_flags_from_environment():
329-
kwargs = {}
366+
367+
def set_flags_from_environment() -> None:
368+
kwargs: dict[str, Any] = {}
330369
if "ARRAY_API_STRICT_API_VERSION" in os.environ:
331370
kwargs["api_version"] = os.environ["ARRAY_API_STRICT_API_VERSION"]
332371

@@ -346,41 +385,49 @@ def set_flags_from_environment():
346385
# linalg and fft to __all__
347386
set_array_api_strict_flags(**kwargs)
348387

388+
349389
set_flags_from_environment()
350390

351391
# Decorators
352392

353-
def requires_api_version(version):
354-
def decorator(func):
393+
394+
def requires_api_version(version: str) -> Callable[[_CallableT], _CallableT]:
395+
def decorator(func: Callable[P, T]) -> Callable[P, T]:
355396
@functools.wraps(func)
356-
def wrapper(*args, **kwargs):
397+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
357398
if version > API_VERSION:
358399
raise RuntimeError(
359400
f"The function {func.__name__} requires API version {version} or later, "
360401
f"but the current API version for array-api-strict is {API_VERSION}"
361402
)
362403
return func(*args, **kwargs)
404+
363405
return wrapper
364-
return decorator
365406

366-
def requires_data_dependent_shapes(func):
407+
return cast(Callable[[_CallableT], _CallableT], decorator)
408+
409+
410+
def requires_data_dependent_shapes(func: Callable[P, T]) -> Callable[P, T]:
367411
@functools.wraps(func)
368-
def wrapper(*args, **kwargs):
412+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
369413
if not DATA_DEPENDENT_SHAPES:
370414
raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
371415
return func(*args, **kwargs)
372416
return wrapper
373417

374-
def requires_extension(extension):
375-
def decorator(func):
418+
419+
def requires_extension(extension: str) -> Callable[[_CallableT], _CallableT]:
420+
def decorator(func: Callable[P, T]) -> Callable[P, T]:
376421
@functools.wraps(func)
377-
def wrapper(*args, **kwargs):
422+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
378423
if extension not in ENABLED_EXTENSIONS:
379424
if extension == 'linalg' \
380425
and func.__name__ in ['matmul', 'tensordot',
381426
'matrix_transpose', 'vecdot']:
382427
raise RuntimeError(f"The linalg extension has been disabled for array-api-strict. However, {func.__name__} is also present in the main array_api_strict namespace and may be used from there.")
383428
raise RuntimeError(f"The function {func.__name__} requires the {extension} extension, but it has been disabled for array-api-strict")
384429
return func(*args, **kwargs)
430+
385431
return wrapper
386-
return decorator
432+
433+
return cast(Callable[[_CallableT], _CallableT], decorator)

‎array_api_strict/_helpers.py‎

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
1-
"""Private helper routines.
2-
"""
1+
"""Private helper routines."""
32

4-
from ._flags import get_array_api_strict_flags
3+
from __future__ import annotations
4+
5+
from ._array_object import Array
56
from ._dtypes import _dtype_categories
7+
from ._flags import get_array_api_strict_flags
68

79
_py_scalars = (bool, int, float, complex)
810

911

10-
def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name):
11-
12+
def _maybe_normalize_py_scalars(
13+
x1: Array | bool | int | float | complex,
14+
x2: Array | bool | int | float | complex,
15+
dtype_category: str,
16+
func_name: str,
17+
) -> tuple[Array, Array]:
1218
flags = get_array_api_strict_flags()
1319
if flags["api_version"] < "2024.12":
1420
# scalars will fail at the call site
15-
return x1, x2
21+
return x1, x2 # type: ignore[return-value]
1622

1723
_allowed_dtypes = _dtype_categories[dtype_category]
1824

@@ -34,4 +40,3 @@ def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name):
3440
raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). "
3541
f"Got {x1.dtype} and {x2.dtype}.")
3642
return x1, x2
37-

‎array_api_strict/_indexing_functions.py‎

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
from __future__ import annotations
22

3+
import numpy as np
4+
35
from ._array_object import Array
46
from ._dtypes import _integer_dtypes
57
from ._flags import requires_api_version
68

7-
from typing import TYPE_CHECKING
8-
9-
if TYPE_CHECKING:
10-
from typing import Optional
119

12-
import numpy as np
13-
14-
def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
10+
def take(x: Array, indices: Array, /, *, axis: int | None = None) -> Array:
1511
"""
1612
Array API compatible wrapper for :py:func:`np.take <numpy.take>`.
1713
@@ -27,6 +23,7 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
2723
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
2824
return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device)
2925

26+
3027
@requires_api_version('2024.12')
3128
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
3229
"""

‎array_api_strict/_info.py‎

Lines changed: 64 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
4-
53
import numpy as np
64

7-
if TYPE_CHECKING:
8-
from typing import Optional, Union, Tuple, List
9-
from ._typing import device, DefaultDataTypes, DataTypes, Capabilities
10-
11-
from ._array_object import ALL_DEVICES, CPU_DEVICE
5+
from . import _dtypes as dt
6+
from ._array_object import ALL_DEVICES, CPU_DEVICE, Device
127
from ._flags import get_array_api_strict_flags, requires_api_version
13-
from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128
8+
from ._typing import Capabilities, DataTypes, DefaultDataTypes
9+
1410

1511
@requires_api_version('2023.12')
1612
class __array_namespace_info__:
1713
@requires_api_version('2023.12')
1814
def capabilities(self) -> Capabilities:
1915
flags = get_array_api_strict_flags()
20-
res = {"boolean indexing": flags['boolean_indexing'],
21-
"data-dependent shapes": flags['data_dependent_shapes'],
22-
}
16+
res: Capabilities = { # type: ignore[typeddict-item]
17+
"boolean indexing": flags['boolean_indexing'],
18+
"data-dependent shapes": flags['data_dependent_shapes'],
19+
}
2320
if flags['api_version'] >= '2024.12':
2421
# maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will
2522
# drop support for NumPy 1 but for now, just compute the number
@@ -36,104 +33,104 @@ def capabilities(self) -> Capabilities:
3633
return res
3734

3835
@requires_api_version('2023.12')
39-
def default_device(self) -> device:
36+
def default_device(self) -> Device:
4037
return CPU_DEVICE
4138

4239
@requires_api_version('2023.12')
4340
def default_dtypes(
4441
self,
4542
*,
46-
device: Optional[device] = None,
43+
device: Device | None = None,
4744
) -> DefaultDataTypes:
4845
return {
49-
"real floating": float64,
50-
"complex floating": complex128,
51-
"integral": int64,
52-
"indexing": int64,
46+
"real floating": dt.float64,
47+
"complex floating": dt.complex128,
48+
"integral": dt.int64,
49+
"indexing": dt.int64,
5350
}
5451

5552
@requires_api_version('2023.12')
5653
def dtypes(
5754
self,
5855
*,
59-
device: Optional[device] = None,
60-
kind: Optional[Union[str, Tuple[str, ...]]] = None,
56+
device: Device | None = None,
57+
kind: str | tuple[str, ...] | None = None,
6158
) -> DataTypes:
6259
if kind is None:
6360
return {
64-
"bool": bool,
65-
"int8": int8,
66-
"int16": int16,
67-
"int32": int32,
68-
"int64": int64,
69-
"uint8": uint8,
70-
"uint16": uint16,
71-
"uint32": uint32,
72-
"uint64": uint64,
73-
"float32": float32,
74-
"float64": float64,
75-
"complex64": complex64,
76-
"complex128": complex128,
61+
"bool": dt.bool,
62+
"int8": dt.int8,
63+
"int16": dt.int16,
64+
"int32": dt.int32,
65+
"int64": dt.int64,
66+
"uint8": dt.uint8,
67+
"uint16": dt.uint16,
68+
"uint32": dt.uint32,
69+
"uint64": dt.uint64,
70+
"float32": dt.float32,
71+
"float64": dt.float64,
72+
"complex64": dt.complex64,
73+
"complex128": dt.complex128,
7774
}
7875
if kind == "bool":
79-
return {"bool": bool}
76+
return {"bool": dt.bool}
8077
if kind == "signed integer":
8178
return {
82-
"int8": int8,
83-
"int16": int16,
84-
"int32": int32,
85-
"int64": int64,
79+
"int8": dt.int8,
80+
"int16": dt.int16,
81+
"int32": dt.int32,
82+
"int64": dt.int64,
8683
}
8784
if kind == "unsigned integer":
8885
return {
89-
"uint8": uint8,
90-
"uint16": uint16,
91-
"uint32": uint32,
92-
"uint64": uint64,
86+
"uint8": dt.uint8,
87+
"uint16": dt.uint16,
88+
"uint32": dt.uint32,
89+
"uint64": dt.uint64,
9390
}
9491
if kind == "integral":
9592
return {
96-
"int8": int8,
97-
"int16": int16,
98-
"int32": int32,
99-
"int64": int64,
100-
"uint8": uint8,
101-
"uint16": uint16,
102-
"uint32": uint32,
103-
"uint64": uint64,
93+
"int8": dt.int8,
94+
"int16": dt.int16,
95+
"int32": dt.int32,
96+
"int64": dt.int64,
97+
"uint8": dt.uint8,
98+
"uint16": dt.uint16,
99+
"uint32": dt.uint32,
100+
"uint64": dt.uint64,
104101
}
105102
if kind == "real floating":
106103
return {
107-
"float32": float32,
108-
"float64": float64,
104+
"float32": dt.float32,
105+
"float64": dt.float64,
109106
}
110107
if kind == "complex floating":
111108
return {
112-
"complex64": complex64,
113-
"complex128": complex128,
109+
"complex64": dt.complex64,
110+
"complex128": dt.complex128,
114111
}
115112
if kind == "numeric":
116113
return {
117-
"int8": int8,
118-
"int16": int16,
119-
"int32": int32,
120-
"int64": int64,
121-
"uint8": uint8,
122-
"uint16": uint16,
123-
"uint32": uint32,
124-
"uint64": uint64,
125-
"float32": float32,
126-
"float64": float64,
127-
"complex64": complex64,
128-
"complex128": complex128,
114+
"int8": dt.int8,
115+
"int16": dt.int16,
116+
"int32": dt.int32,
117+
"int64": dt.int64,
118+
"uint8": dt.uint8,
119+
"uint16": dt.uint16,
120+
"uint32": dt.uint32,
121+
"uint64": dt.uint64,
122+
"float32": dt.float32,
123+
"float64": dt.float64,
124+
"complex64": dt.complex64,
125+
"complex128": dt.complex128,
129126
}
130127
if isinstance(kind, tuple):
131-
res = {}
128+
res: DataTypes = {}
132129
for k in kind:
133130
res.update(self.dtypes(kind=k))
134131
return res
135132
raise ValueError(f"unsupported kind: {kind!r}")
136133

137134
@requires_api_version('2023.12')
138-
def devices(self) -> List[device]:
135+
def devices(self) -> list[Device]:
139136
return list(ALL_DEVICES)

‎array_api_strict/_linalg.py‎

Lines changed: 57 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,25 @@
11
from __future__ import annotations
22

3+
from collections.abc import Sequence
34
from functools import partial
5+
from typing import Literal, NamedTuple
46

5-
from ._dtypes import (
6-
_floating_dtypes,
7-
_numeric_dtypes,
8-
float32,
9-
complex64,
10-
complex128,
11-
)
7+
import numpy as np
8+
import numpy.linalg
9+
10+
from ._array_object import Array
1211
from ._data_type_functions import finfo
13-
from ._manipulation_functions import reshape
12+
from ._dtypes import DType, _floating_dtypes, _numeric_dtypes, complex64, complex128
1413
from ._elementwise_functions import conj
15-
from ._array_object import Array
16-
from ._flags import requires_extension, get_array_api_strict_flags
14+
from ._flags import get_array_api_strict_flags, requires_extension
15+
from ._manipulation_functions import reshape
16+
from ._statistical_functions import _np_dtype_sumprod
1717

1818
try:
19-
from numpy._core.numeric import normalize_axis_tuple
19+
from numpy._core.numeric import normalize_axis_tuple # type: ignore[attr-defined]
2020
except ImportError:
21-
from numpy.core.numeric import normalize_axis_tuple
21+
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
2222

23-
from typing import TYPE_CHECKING
24-
if TYPE_CHECKING:
25-
from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype
26-
27-
from typing import NamedTuple
28-
29-
import numpy.linalg
30-
import numpy as np
3123

3224
class EighResult(NamedTuple):
3325
eigenvalues: Array
@@ -175,7 +167,13 @@ def inv(x: Array, /) -> Array:
175167
# -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point
176168
# literals.
177169
@requires_extension('linalg')
178-
def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: # noqa: F821
170+
def matrix_norm(
171+
x: Array,
172+
/,
173+
*,
174+
keepdims: bool = False,
175+
ord: float | Literal["fro", "nuc"] | None = "fro",
176+
) -> Array: # noqa: F821
179177
"""
180178
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
181179
@@ -186,7 +184,10 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int,
186184
if x.dtype not in _floating_dtypes:
187185
raise TypeError('Only floating-point dtypes are allowed in matrix_norm')
188186

189-
return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord), device=x.device)
187+
return Array._new(
188+
np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord),
189+
device=x.device,
190+
)
190191

191192

192193
@requires_extension('linalg')
@@ -206,7 +207,7 @@ def matrix_power(x: Array, n: int, /) -> Array:
206207

207208
# Note: the keyword argument name rtol is different from np.linalg.matrix_rank
208209
@requires_extension('linalg')
209-
def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
210+
def matrix_rank(x: Array, /, *, rtol: float | Array | None = None) -> Array:
210211
"""
211212
Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
212213
@@ -218,13 +219,12 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A
218219
raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
219220
S = np.linalg.svd(x._array, compute_uv=False)
220221
if rtol is None:
221-
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * finfo(S.dtype).eps
222+
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps
222223
else:
223-
if isinstance(rtol, Array):
224-
rtol = rtol._array
224+
rtol_np = rtol._array if isinstance(rtol, Array) else np.asarray(rtol)
225225
# Note: this is different from np.linalg.matrix_rank, which does not multiply
226226
# the tolerance by the largest singular value.
227-
tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis]
227+
tol = S.max(axis=-1, keepdims=True) * rtol_np[..., np.newaxis]
228228
return Array._new(np.count_nonzero(S > tol, axis=-1), device=x.device)
229229

230230

@@ -252,7 +252,7 @@ def outer(x1: Array, x2: Array, /) -> Array:
252252

253253
# Note: the keyword argument name rtol is different from np.linalg.pinv
254254
@requires_extension('linalg')
255-
def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
255+
def pinv(x: Array, /, *, rtol: float | Array | None = None) -> Array:
256256
"""
257257
Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.
258258
@@ -267,9 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
267267
# default tolerance by max(M, N).
268268
if rtol is None:
269269
rtol = max(x.shape[-2:]) * finfo(x.dtype).eps
270-
if isinstance(rtol, Array):
271-
rtol = rtol._array
272-
return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device)
270+
rtol_np = rtol._array if isinstance(rtol, Array) else rtol
271+
return Array._new(np.linalg.pinv(x._array, rcond=rtol_np), device=x.device)
273272

274273
@requires_extension('linalg')
275274
def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: # noqa: F821
@@ -312,14 +311,14 @@ def slogdet(x: Array, /) -> SlogdetResult:
312311

313312
# To workaround this, the below is the code from np.linalg.solve except
314313
# only calling solve1 in the exactly 1D case.
315-
def _solve(a, b):
314+
def _solve(a: np.ndarray, b: np.ndarray) -> np.ndarray:
316315
try:
317-
from numpy.linalg._linalg import (
316+
from numpy.linalg._linalg import ( # type: ignore[attr-defined]
318317
_makearray, _assert_stacked_2d, _assert_stacked_square,
319318
_commonType, isComplexType, _raise_linalgerror_singular
320319
)
321320
except ImportError:
322-
from numpy.linalg.linalg import (
321+
from numpy.linalg.linalg import ( # type: ignore[attr-defined]
323322
_makearray, _assert_stacked_2d, _assert_stacked_square,
324323
_commonType, isComplexType, _raise_linalgerror_singular
325324
)
@@ -382,14 +381,14 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
382381
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
383382
# np.linalg.svd(compute_uv=False).
384383
@requires_extension('linalg')
385-
def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
384+
def svdvals(x: Array, /) -> Array:
386385
if x.dtype not in _floating_dtypes:
387386
raise TypeError('Only floating-point dtypes are allowed in svdvals')
388387
return Array._new(np.linalg.svd(x._array, compute_uv=False), device=x.device)
389388

390389
# Note: trace is the numpy top-level namespace, not np.linalg
391390
@requires_extension('linalg')
392-
def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array:
391+
def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array:
393392
"""
394393
Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
395394
@@ -398,27 +397,28 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr
398397
if x.dtype not in _numeric_dtypes:
399398
raise TypeError('Only numeric dtypes are allowed in trace')
400399

401-
# Note: trace() works the same as sum() and prod() (see
402-
# _statistical_functions.py)
403-
if dtype is None:
404-
if get_array_api_strict_flags()['api_version'] < '2023.12':
405-
if x.dtype == float32:
406-
dtype = np.float64
407-
elif x.dtype == complex64:
408-
dtype = np.complex128
409-
else:
410-
dtype = dtype._np_dtype
400+
# Note: trace() works the same as sum() and prod() (see _statistical_functions.py)
401+
np_dtype = _np_dtype_sumprod(x, dtype)
402+
411403
# Note: trace always operates on the last two axes, whereas np.trace
412404
# operates on the first two axes by default
413-
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)), device=x.device)
405+
res = np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=np_dtype)
406+
return Array._new(np.asarray(res), device=x.device)
414407

415408
# Note: the name here is different from norm(). The array API norm is split
416409
# into matrix_norm and vector_norm().
417410

418411
# The type for ord should be Optional[Union[int, float, Literal[np.inf,
419412
# -np.inf]]] but Literal does not support floating-point literals.
420413
@requires_extension('linalg')
421-
def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
414+
def vector_norm(
415+
x: Array,
416+
/,
417+
*,
418+
axis: int | tuple[int, ...] | None = None,
419+
keepdims: bool = False,
420+
ord: int | float = 2,
421+
) -> Array:
422422
"""
423423
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
424424
@@ -456,8 +456,8 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No
456456
# We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
457457
# above to avoid matrix norm logic.
458458
shape = list(x.shape)
459-
_axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
460-
for i in _axis:
459+
axis_tup = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
460+
for i in axis_tup:
461461
shape[i] = 1
462462
res = reshape(res, tuple(shape))
463463

@@ -480,7 +480,13 @@ def matmul(x1: Array, x2: Array, /) -> Array:
480480

481481
# Note: tensordot is the numpy top-level namespace but not in np.linalg
482482
@requires_extension('linalg')
483-
def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
483+
def tensordot(
484+
x1: Array,
485+
x2: Array,
486+
/,
487+
*,
488+
axes: int | tuple[Sequence[int], Sequence[int]] = 2,
489+
) -> Array:
484490
from ._linear_algebra_functions import tensordot
485491
return tensordot(x1, x2, axes=axes)
486492

‎array_api_strict/_linear_algebra_functions.py‎

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@
77

88
from __future__ import annotations
99

10-
from ._dtypes import _numeric_dtypes
10+
from collections.abc import Sequence
11+
12+
import numpy as np
13+
import numpy.linalg
14+
1115
from ._array_object import Array
16+
from ._dtypes import _numeric_dtypes
1217
from ._flags import get_array_api_strict_flags
1318

14-
from typing import TYPE_CHECKING
15-
if TYPE_CHECKING:
16-
from ._typing import Sequence, Tuple, Union
17-
18-
import numpy.linalg
19-
import numpy as np
2019

2120
# Note: matmul is the numpy top-level namespace but not in np.linalg
2221
def matmul(x1: Array, x2: Array, /) -> Array:
@@ -38,7 +37,13 @@ def matmul(x1: Array, x2: Array, /) -> Array:
3837
# Note: tensordot is the numpy top-level namespace but not in np.linalg
3938

4039
# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
41-
def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
40+
def tensordot(
41+
x1: Array,
42+
x2: Array,
43+
/,
44+
*,
45+
axes: int | tuple[Sequence[int], Sequence[int]] = 2,
46+
) -> Array:
4247
# Note: the restriction to numeric dtypes only is different from
4348
# np.tensordot.
4449
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:

‎array_api_strict/_manipulation_functions.py‎

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
from __future__ import annotations
22

3+
import numpy as np
4+
35
from ._array_object import Array
46
from ._creation_functions import asarray
57
from ._data_type_functions import astype, result_type
68
from ._dtypes import _integer_dtypes, int64, uint64
7-
from ._flags import requires_api_version, get_array_api_strict_flags
8-
9-
from typing import TYPE_CHECKING
9+
from ._flags import get_array_api_strict_flags, requires_api_version
1010

11-
if TYPE_CHECKING:
12-
from typing import List, Optional, Tuple, Union
13-
14-
import numpy as np
1511

1612
# Note: the function name is different here
1713
def concat(
18-
arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
14+
arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0
1915
) -> Array:
2016
"""
2117
Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
@@ -29,8 +25,11 @@ def concat(
2925
raise ValueError("concat inputs must all be on the same device")
3026
result_device = arrays[0].device
3127

32-
arrays = tuple(a._array for a in arrays)
33-
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=result_device)
28+
np_arrays = tuple(a._array for a in arrays)
29+
return Array._new(
30+
np.concatenate(np_arrays, axis=axis, dtype=dtype._np_dtype),
31+
device=result_device,
32+
)
3433

3534

3635
def expand_dims(x: Array, /, *, axis: int) -> Array:
@@ -42,7 +41,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array:
4241
return Array._new(np.expand_dims(x._array, axis), device=x.device)
4342

4443

45-
def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
44+
def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array:
4645
"""
4746
Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
4847
@@ -53,8 +52,8 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
5352
@requires_api_version('2023.12')
5453
def moveaxis(
5554
x: Array,
56-
source: Union[int, Tuple[int, ...]],
57-
destination: Union[int, Tuple[int, ...]],
55+
source: int | tuple[int, ...],
56+
destination: int | tuple[int, ...],
5857
/,
5958
) -> Array:
6059
"""
@@ -66,7 +65,7 @@ def moveaxis(
6665

6766
# Note: The function name is different here (see also matrix_transpose).
6867
# Unlike transpose(), the axes argument is required.
69-
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
68+
def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
7069
"""
7170
Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
7271
@@ -77,10 +76,10 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
7776
@requires_api_version('2023.12')
7877
def repeat(
7978
x: Array,
80-
repeats: Union[int, Array],
79+
repeats: int | Array,
8180
/,
8281
*,
83-
axis: Optional[int] = None,
82+
axis: int | None = None,
8483
) -> Array:
8584
"""
8685
Array API compatible wrapper for :py:func:`np.repeat <numpy.repeat>`.
@@ -108,12 +107,9 @@ def repeat(
108107
repeats = astype(repeats, int64)
109108
return Array._new(np.repeat(x._array, repeats._array, axis=axis), device=x.device)
110109

110+
111111
# Note: the optional argument is called 'shape', not 'newshape'
112-
def reshape(x: Array,
113-
/,
114-
shape: Tuple[int, ...],
115-
*,
116-
copy: Optional[bool] = None) -> Array:
112+
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
117113
"""
118114
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
119115
@@ -135,9 +131,9 @@ def reshape(x: Array,
135131
def roll(
136132
x: Array,
137133
/,
138-
shift: Union[int, Tuple[int, ...]],
134+
shift: int | tuple[int, ...],
139135
*,
140-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
136+
axis: int | tuple[int, ...] | None = None,
141137
) -> Array:
142138
"""
143139
Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
@@ -147,7 +143,7 @@ def roll(
147143
return Array._new(np.roll(x._array, shift, axis=axis), device=x.device)
148144

149145

150-
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
146+
def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array:
151147
"""
152148
Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
153149
@@ -161,7 +157,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
161157
return Array._new(np.squeeze(x._array, axis=axis), device=x.device)
162158

163159

164-
def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
160+
def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array:
165161
"""
166162
Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
167163
@@ -172,12 +168,12 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
172168
if len({a.device for a in arrays}) > 1:
173169
raise ValueError("concat inputs must all be on the same device")
174170
result_device = arrays[0].device
175-
arrays = tuple(a._array for a in arrays)
176-
return Array._new(np.stack(arrays, axis=axis), device=result_device)
171+
np_arrays = tuple(a._array for a in arrays)
172+
return Array._new(np.stack(np_arrays, axis=axis), device=result_device)
177173

178174

179175
@requires_api_version('2023.12')
180-
def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array:
176+
def tile(x: Array, repetitions: tuple[int, ...], /) -> Array:
181177
"""
182178
Array API compatible wrapper for :py:func:`np.tile <numpy.tile>`.
183179
@@ -190,7 +186,7 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array:
190186

191187
# Note: this function is new
192188
@requires_api_version('2023.12')
193-
def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]:
189+
def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
194190
if not (-x.ndim <= axis < x.ndim):
195191
raise ValueError("axis out of range")
196192

‎array_api_strict/_searching_functions.py‎

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
from __future__ import annotations
22

3-
from ._array_object import Array
4-
from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool
5-
from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags
6-
from ._helpers import _maybe_normalize_py_scalars
7-
8-
from typing import TYPE_CHECKING
9-
if TYPE_CHECKING:
10-
from typing import Literal, Optional, Tuple, Union
3+
from typing import Literal
114

125
import numpy as np
136

7+
from ._array_object import Array
8+
from ._dtypes import _real_numeric_dtypes, _result_type
9+
from ._dtypes import bool as _bool
10+
from ._flags import requires_api_version, requires_data_dependent_shapes
11+
from ._helpers import _maybe_normalize_py_scalars
12+
1413

15-
def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
14+
def argmax(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array:
1615
"""
1716
Array API compatible wrapper for :py:func:`np.argmax <numpy.argmax>`.
1817
@@ -23,7 +22,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
2322
return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)), device=x.device)
2423

2524

26-
def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
25+
def argmin(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array:
2726
"""
2827
Array API compatible wrapper for :py:func:`np.argmin <numpy.argmin>`.
2928
@@ -35,7 +34,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
3534

3635

3736
@requires_data_dependent_shapes
38-
def nonzero(x: Array, /) -> Tuple[Array, ...]:
37+
def nonzero(x: Array, /) -> tuple[Array, ...]:
3938
"""
4039
Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`.
4140
@@ -52,7 +51,7 @@ def count_nonzero(
5251
x: Array,
5352
/,
5453
*,
55-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
54+
axis: int | tuple[int, ...] | None = None,
5655
keepdims: bool = False,
5756
) -> Array:
5857
"""
@@ -71,7 +70,7 @@ def searchsorted(
7170
/,
7271
*,
7372
side: Literal["left", "right"] = "left",
74-
sorter: Optional[Array] = None,
73+
sorter: Array | None = None,
7574
) -> Array:
7675
"""
7776
Array API compatible wrapper for :py:func:`np.searchsorted <numpy.searchsorted>`.
@@ -84,25 +83,29 @@ def searchsorted(
8483
if x1.device != x2.device:
8584
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
8685

87-
sorter = sorter._array if sorter is not None else None
86+
np_sorter = sorter._array if sorter is not None else None
8887
# TODO: The sort order of nans and signed zeros is implementation
8988
# dependent. Should we error/warn if they are present?
9089

9190
# x1 must be 1-D, but NumPy already requires this.
92-
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)
91+
return Array._new(
92+
np.searchsorted(x1._array, x2._array, side=side, sorter=np_sorter),
93+
device=x1.device,
94+
)
95+
9396

9497
def where(
9598
condition: Array,
96-
x1: bool | int | float | complex | Array,
97-
x2: bool | int | float | complex | Array, /
99+
x1: Array | bool | int | float | complex,
100+
x2: Array | bool | int | float | complex,
101+
/,
98102
) -> Array:
99103
"""
100104
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
101105
102106
See its docstring for more information.
103107
"""
104-
if get_array_api_strict_flags()['api_version'] > '2023.12':
105-
x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where")
108+
x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where")
106109

107110
# Call result type here just to raise on disallowed type combinations
108111
_result_type(x1.dtype, x2.dtype)

‎array_api_strict/_set_functions.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from __future__ import annotations
22

3-
from ._array_object import Array
4-
5-
from ._flags import requires_data_dependent_shapes
6-
73
from typing import NamedTuple
84

95
import numpy as np
106

7+
from ._array_object import Array
8+
from ._flags import requires_data_dependent_shapes
9+
1110
# Note: np.unique() is split into four functions in the array API:
1211
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
1312
# to remove polymorphic return types).
@@ -20,6 +19,7 @@
2019
# Note: The functions here return a namedtuple (np.unique() returns a normal
2120
# tuple).
2221

22+
2323
class UniqueAllResult(NamedTuple):
2424
values: Array
2525
indices: Array

‎array_api_strict/_sorting_functions.py‎

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3-
from ._array_object import Array
4-
from ._dtypes import _real_numeric_dtypes
3+
from typing import Literal
54

65
import numpy as np
76

7+
from ._array_object import Array
8+
from ._dtypes import _real_numeric_dtypes
9+
810

911
# Note: the descending keyword argument is new in this function
1012
def argsort(
@@ -18,7 +20,7 @@ def argsort(
1820
if x.dtype not in _real_numeric_dtypes:
1921
raise TypeError("Only real numeric dtypes are allowed in argsort")
2022
# Note: this keyword argument is different, and the default is different.
21-
kind = "stable" if stable else "quicksort"
23+
kind: Literal["stable", "quicksort"] = "stable" if stable else "quicksort"
2224
if not descending:
2325
res = np.argsort(x._array, axis=axis, kind=kind)
2426
else:
@@ -35,6 +37,7 @@ def argsort(
3537
res = max_i - res
3638
return Array._new(res, device=x.device)
3739

40+
3841
# Note: the descending keyword argument is new in this function
3942
def sort(
4043
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
@@ -47,8 +50,7 @@ def sort(
4750
if x.dtype not in _real_numeric_dtypes:
4851
raise TypeError("Only real numeric dtypes are allowed in sort")
4952
# Note: this keyword argument is different, and the default is different.
50-
kind = "stable" if stable else "quicksort"
51-
res = np.sort(x._array, axis=axis, kind=kind)
53+
res = np.sort(x._array, axis=axis, kind="stable" if stable else "quicksort")
5254
if descending:
5355
res = np.flip(res, axis=axis)
5456
return Array._new(res, device=x.device)

‎array_api_strict/_statistical_functions.py‎

Lines changed: 58 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,36 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
5+
import numpy as np
6+
7+
from ._array_object import Array
8+
from ._creation_functions import ones, zeros
39
from ._dtypes import (
4-
_real_floating_dtypes,
5-
_real_numeric_dtypes,
10+
DType,
611
_floating_dtypes,
12+
_np_dtype,
713
_numeric_dtypes,
14+
_real_floating_dtypes,
15+
_real_numeric_dtypes,
16+
complex64,
17+
float32,
818
)
9-
from ._array_object import Array
10-
from ._dtypes import float32, complex64
11-
from ._flags import requires_api_version, get_array_api_strict_flags
12-
from ._creation_functions import zeros, ones
19+
from ._flags import get_array_api_strict_flags, requires_api_version
1320
from ._manipulation_functions import concat
1421

15-
from typing import TYPE_CHECKING
16-
17-
if TYPE_CHECKING:
18-
from typing import Optional, Tuple, Union
19-
from ._typing import Dtype
20-
21-
import numpy as np
2222

2323
@requires_api_version('2023.12')
2424
def cumulative_sum(
2525
x: Array,
2626
/,
2727
*,
28-
axis: Optional[int] = None,
29-
dtype: Optional[Dtype] = None,
28+
axis: int | None = None,
29+
dtype: DType | None = None,
3030
include_initial: bool = False,
3131
) -> Array:
3232
if x.dtype not in _numeric_dtypes:
3333
raise TypeError("Only numeric dtypes are allowed in cumulative_sum")
34-
if dtype is not None:
35-
dtype = dtype._np_dtype
3634

3735
# TODO: The standard is not clear about what should happen when x.ndim == 0.
3836
if axis is None:
@@ -44,26 +42,23 @@ def cumulative_sum(
4442
if axis < 0:
4543
axis += x.ndim
4644
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
47-
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device)
45+
return Array._new(np.cumsum(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device)
4846

4947

5048
@requires_api_version('2024.12')
5149
def cumulative_prod(
5250
x: Array,
5351
/,
5452
*,
55-
axis: Optional[int] = None,
56-
dtype: Optional[Dtype] = None,
53+
axis: int | None = None,
54+
dtype: DType | None = None,
5755
include_initial: bool = False,
5856
) -> Array:
5957
if x.dtype not in _numeric_dtypes:
6058
raise TypeError("Only numeric dtypes are allowed in cumulative_prod")
6159
if x.ndim == 0:
6260
raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod")
6361

64-
if dtype is not None:
65-
dtype = dtype._np_dtype
66-
6762
if axis is None:
6863
if x.ndim > 1:
6964
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
@@ -74,14 +69,14 @@ def cumulative_prod(
7469
if axis < 0:
7570
axis += x.ndim
7671
x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
77-
return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device)
72+
return Array._new(np.cumprod(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device)
7873

7974

8075
def max(
8176
x: Array,
8277
/,
8378
*,
84-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
79+
axis: int | tuple[int, ...] | None = None,
8580
keepdims: bool = False,
8681
) -> Array:
8782
if x.dtype not in _real_numeric_dtypes:
@@ -93,14 +88,15 @@ def mean(
9388
x: Array,
9489
/,
9590
*,
96-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
91+
axis: int | tuple[int, ...] | None = None,
9792
keepdims: bool = False,
9893
) -> Array:
9994

100-
if get_array_api_strict_flags()['api_version'] > '2023.12':
101-
allowed_dtypes = _floating_dtypes
102-
else:
103-
allowed_dtypes = _real_floating_dtypes
95+
allowed_dtypes = (
96+
_floating_dtypes
97+
if get_array_api_strict_flags()['api_version'] > '2023.12'
98+
else _real_floating_dtypes
99+
)
104100

105101
if x.dtype not in allowed_dtypes:
106102
raise TypeError("Only floating-point dtypes are allowed in mean")
@@ -111,45 +107,51 @@ def min(
111107
x: Array,
112108
/,
113109
*,
114-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
110+
axis: int | tuple[int, ...] | None = None,
115111
keepdims: bool = False,
116112
) -> Array:
117113
if x.dtype not in _real_numeric_dtypes:
118114
raise TypeError("Only real numeric dtypes are allowed in min")
119115
return Array._new(np.min(x._array, axis=axis, keepdims=keepdims), device=x.device)
120116

121117

118+
def _np_dtype_sumprod(x: Array, dtype: DType | None) -> np.dtype[Any] | None:
119+
"""In versions prior to 2023.12, sum() and prod() upcast for all
120+
dtypes when dtype=None. For 2023.12, the behavior is the same as in
121+
NumPy (only upcast for integral dtypes).
122+
"""
123+
if dtype is None and get_array_api_strict_flags()['api_version'] < '2023.12':
124+
if x.dtype == float32:
125+
return np.float64 # type: ignore[return-value]
126+
elif x.dtype == complex64:
127+
return np.complex128 # type: ignore[return-value]
128+
return _np_dtype(dtype)
129+
130+
122131
def prod(
123132
x: Array,
124133
/,
125134
*,
126-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
127-
dtype: Optional[Dtype] = None,
135+
axis: int | tuple[int, ...] | None = None,
136+
dtype: DType | None = None,
128137
keepdims: bool = False,
129138
) -> Array:
130139
if x.dtype not in _numeric_dtypes:
131140
raise TypeError("Only numeric dtypes are allowed in prod")
132141

133-
if dtype is None:
134-
# Note: In versions prior to 2023.12, sum() and prod() upcast for all
135-
# dtypes when dtype=None. For 2023.12, the behavior is the same as in
136-
# NumPy (only upcast for integral dtypes).
137-
if get_array_api_strict_flags()['api_version'] < '2023.12':
138-
if x.dtype == float32:
139-
dtype = np.float64
140-
elif x.dtype == complex64:
141-
dtype = np.complex128
142-
else:
143-
dtype = dtype._np_dtype
144-
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims), device=x.device)
142+
np_dtype = _np_dtype_sumprod(x, dtype)
143+
return Array._new(
144+
np.prod(x._array, dtype=np_dtype, axis=axis, keepdims=keepdims),
145+
device=x.device,
146+
)
145147

146148

147149
def std(
148150
x: Array,
149151
/,
150152
*,
151-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
152-
correction: Union[int, float] = 0.0,
153+
axis: int | tuple[int, ...] | None = None,
154+
correction: int | float = 0.0,
153155
keepdims: bool = False,
154156
) -> Array:
155157
# Note: the keyword argument correction is different here
@@ -162,33 +164,26 @@ def sum(
162164
x: Array,
163165
/,
164166
*,
165-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
166-
dtype: Optional[Dtype] = None,
167+
axis: int | tuple[int, ...] | None = None,
168+
dtype: DType | None = None,
167169
keepdims: bool = False,
168170
) -> Array:
169171
if x.dtype not in _numeric_dtypes:
170172
raise TypeError("Only numeric dtypes are allowed in sum")
171173

172-
if dtype is None:
173-
# Note: In versions prior to 2023.12, sum() and prod() upcast for all
174-
# dtypes when dtype=None. For 2023.12, the behavior is the same as in
175-
# NumPy (only upcast for integral dtypes).
176-
if get_array_api_strict_flags()['api_version'] < '2023.12':
177-
if x.dtype == float32:
178-
dtype = np.float64
179-
elif x.dtype == complex64:
180-
dtype = np.complex128
181-
else:
182-
dtype = dtype._np_dtype
183-
return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims), device=x.device)
174+
np_dtype = _np_dtype_sumprod(x, dtype)
175+
return Array._new(
176+
np.sum(x._array, axis=axis, dtype=np_dtype, keepdims=keepdims),
177+
device=x.device,
178+
)
184179

185180

186181
def var(
187182
x: Array,
188183
/,
189184
*,
190-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
191-
correction: Union[int, float] = 0.0,
185+
axis: int | tuple[int, ...] | None = None,
186+
correction: int | float = 0.0,
192187
keepdims: bool = False,
193188
) -> Array:
194189
# Note: the keyword argument correction is different here

‎array_api_strict/_typing.py‎

Lines changed: 30 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8,82 +8,62 @@
88

99
from __future__ import annotations
1010

11-
__all__ = [
12-
"Array",
13-
"Device",
14-
"Dtype",
15-
"SupportsDLPack",
16-
"SupportsBufferProtocol",
17-
"PyCapsule",
18-
]
19-
2011
import sys
12+
from typing import Any, Protocol, TypedDict, TypeVar
2113

22-
from typing import (
23-
Any,
24-
TypedDict,
25-
TypeVar,
26-
Protocol,
27-
)
28-
29-
from ._array_object import Array, _device
30-
from ._dtypes import _DType
31-
from ._info import __array_namespace_info__
14+
from ._dtypes import DType
3215

3316
_T_co = TypeVar("_T_co", covariant=True)
3417

18+
3519
class NestedSequence(Protocol[_T_co]):
3620
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
3721
def __len__(self, /) -> int: ...
3822

3923

40-
Device = _device
41-
42-
Dtype = _DType
43-
44-
Info = __array_namespace_info__
45-
4624
if sys.version_info >= (3, 12):
4725
from collections.abc import Buffer as SupportsBufferProtocol
4826
else:
4927
SupportsBufferProtocol = Any
5028

5129
PyCapsule = Any
5230

31+
5332
class SupportsDLPack(Protocol):
5433
def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
5534

35+
5636
Capabilities = TypedDict(
57-
"Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool,
58-
"max dimensions": int}
37+
"Capabilities",
38+
{
39+
"boolean indexing": bool,
40+
"data-dependent shapes": bool,
41+
"max dimensions": int,
42+
},
5943
)
6044

6145
DefaultDataTypes = TypedDict(
6246
"DefaultDataTypes",
6347
{
64-
"real floating": Dtype,
65-
"complex floating": Dtype,
66-
"integral": Dtype,
67-
"indexing": Dtype,
48+
"real floating": DType,
49+
"complex floating": DType,
50+
"integral": DType,
51+
"indexing": DType,
6852
},
6953
)
7054

71-
DataTypes = TypedDict(
72-
"DataTypes",
73-
{
74-
"bool": Dtype,
75-
"float32": Dtype,
76-
"float64": Dtype,
77-
"complex64": Dtype,
78-
"complex128": Dtype,
79-
"int8": Dtype,
80-
"int16": Dtype,
81-
"int32": Dtype,
82-
"int64": Dtype,
83-
"uint8": Dtype,
84-
"uint16": Dtype,
85-
"uint32": Dtype,
86-
"uint64": Dtype,
87-
},
88-
total=False,
89-
)
55+
56+
class DataTypes(TypedDict, total=False):
57+
bool: DType
58+
float32: DType
59+
float64: DType
60+
complex64: DType
61+
complex128: DType
62+
int8: DType
63+
int16: DType
64+
int32: DType
65+
int64: DType
66+
uint8: DType
67+
uint16: DType
68+
uint32: DType
69+
uint64: DType

‎array_api_strict/_utility_functions.py‎

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
from __future__ import annotations
22

3-
from ._array_object import Array
4-
from ._flags import requires_api_version
5-
from ._dtypes import _numeric_dtypes
6-
7-
from typing import TYPE_CHECKING
8-
if TYPE_CHECKING:
9-
from typing import Optional, Tuple, Union
3+
from typing import Any
104

115
import numpy as np
6+
import numpy.typing as npt
7+
8+
from ._array_object import Array
9+
from ._dtypes import _numeric_dtypes
10+
from ._flags import requires_api_version
1211

1312

1413
def all(
1514
x: Array,
1615
/,
1716
*,
18-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
17+
axis: int | tuple[int, ...] | None = None,
1918
keepdims: bool = False,
2019
) -> Array:
2120
"""
@@ -30,7 +29,7 @@ def any(
3029
x: Array,
3130
/,
3231
*,
33-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
32+
axis: int | tuple[int, ...] | None = None,
3433
keepdims: bool = False,
3534
) -> Array:
3635
"""
@@ -40,15 +39,16 @@ def any(
4039
"""
4140
return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device)
4241

42+
4343
@requires_api_version('2024.12')
4444
def diff(
4545
x: Array,
4646
/,
4747
*,
4848
axis: int = -1,
4949
n: int = 1,
50-
prepend: Optional[Array] = None,
51-
append: Optional[Array] = None,
50+
prepend: Array | None = None,
51+
append: Array | None = None,
5252
) -> Array:
5353
if x.dtype not in _numeric_dtypes:
5454
raise TypeError("Only numeric dtypes are allowed in diff")
@@ -57,7 +57,7 @@ def diff(
5757
# currently specified.
5858

5959
# NumPy does not support prepend=None or append=None
60-
kwargs = dict(axis=axis, n=n)
60+
kwargs: dict[str, int | npt.NDArray[Any]] = {"axis": axis, "n": n}
6161
if prepend is not None:
6262
if prepend.device != x.device:
6363
raise ValueError(f"Arrays from two different devices ({prepend.device} and {x.device}) can not be combined.")

‎array_api_strict/py.typed‎

Whitespace-only changes.

‎array_api_strict/tests/test_validation.py‎

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from typing import Callable
2-
31
import pytest
42

53
import array_api_strict as xp
64

75

8-
def p(func: Callable, *args, **kwargs):
6+
def p(func, *args, **kwargs):
97
f_sig = ", ".join(
108
[str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()]
119
)

‎pyproject.toml‎

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,18 @@ Repository = "https://github.com/data-apis/array-api-strict"
3131
[tool.setuptools_scm]
3232
version_file = "array_api_strict/_version.py"
3333

34+
[tool.mypy]
35+
disallow_incomplete_defs = true
36+
disallow_untyped_decorators = true
37+
disallow_untyped_defs = true
38+
no_implicit_optional = true
39+
show_error_codes = true
40+
warn_redundant_casts = true
41+
warn_unused_ignores = true
42+
warn_unreachable = true
43+
strict_bytes = true
44+
local_partial_types = true
45+
46+
[[tool.mypy.overrides]]
47+
module = ["*.tests.*"]
48+
disallow_untyped_defs = false

0 commit comments

Comments
 (0)
Please sign in to comment.