1
1
from __future__ import annotations
2
2
3
+ from collections .abc import Sequence
3
4
from functools import partial
5
+ from typing import Literal , NamedTuple
4
6
5
- from ._dtypes import (
6
- _floating_dtypes ,
7
- _numeric_dtypes ,
8
- float32 ,
9
- complex64 ,
10
- complex128 ,
11
- )
7
+ import numpy as np
8
+ import numpy .linalg
9
+
10
+ from ._array_object import Array
12
11
from ._data_type_functions import finfo
13
- from ._manipulation_functions import reshape
12
+ from ._dtypes import DType , _floating_dtypes , _numeric_dtypes , complex64 , complex128
14
13
from ._elementwise_functions import conj
15
- from ._array_object import Array
16
- from ._flags import requires_extension , get_array_api_strict_flags
14
+ from ._flags import get_array_api_strict_flags , requires_extension
15
+ from ._manipulation_functions import reshape
16
+ from ._statistical_functions import _np_dtype_sumprod
17
17
18
18
try :
19
- from numpy ._core .numeric import normalize_axis_tuple
19
+ from numpy ._core .numeric import normalize_axis_tuple # type: ignore[attr-defined]
20
20
except ImportError :
21
- from numpy .core .numeric import normalize_axis_tuple
21
+ from numpy .core .numeric import normalize_axis_tuple # type: ignore[no-redef]
22
22
23
- from typing import TYPE_CHECKING
24
- if TYPE_CHECKING :
25
- from ._typing import Literal , Optional , Sequence , Tuple , Union , Dtype
26
-
27
- from typing import NamedTuple
28
-
29
- import numpy .linalg
30
- import numpy as np
31
23
32
24
class EighResult (NamedTuple ):
33
25
eigenvalues : Array
@@ -175,7 +167,13 @@ def inv(x: Array, /) -> Array:
175
167
# -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point
176
168
# literals.
177
169
@requires_extension ('linalg' )
178
- def matrix_norm (x : Array , / , * , keepdims : bool = False , ord : Optional [Union [int , float , Literal ['fro' , 'nuc' ]]] = 'fro' ) -> Array : # noqa: F821
170
+ def matrix_norm (
171
+ x : Array ,
172
+ / ,
173
+ * ,
174
+ keepdims : bool = False ,
175
+ ord : float | Literal ["fro" , "nuc" ] | None = "fro" ,
176
+ ) -> Array : # noqa: F821
179
177
"""
180
178
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
181
179
@@ -186,7 +184,10 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int,
186
184
if x .dtype not in _floating_dtypes :
187
185
raise TypeError ('Only floating-point dtypes are allowed in matrix_norm' )
188
186
189
- return Array ._new (np .linalg .norm (x ._array , axis = (- 2 , - 1 ), keepdims = keepdims , ord = ord ), device = x .device )
187
+ return Array ._new (
188
+ np .linalg .norm (x ._array , axis = (- 2 , - 1 ), keepdims = keepdims , ord = ord ),
189
+ device = x .device ,
190
+ )
190
191
191
192
192
193
@requires_extension ('linalg' )
@@ -206,7 +207,7 @@ def matrix_power(x: Array, n: int, /) -> Array:
206
207
207
208
# Note: the keyword argument name rtol is different from np.linalg.matrix_rank
208
209
@requires_extension ('linalg' )
209
- def matrix_rank (x : Array , / , * , rtol : Optional [ Union [ float , Array ]] = None ) -> Array :
210
+ def matrix_rank (x : Array , / , * , rtol : float | Array | None = None ) -> Array :
210
211
"""
211
212
Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
212
213
@@ -218,13 +219,12 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A
218
219
raise np .linalg .LinAlgError ("1-dimensional array given. Array must be at least two-dimensional" )
219
220
S = np .linalg .svd (x ._array , compute_uv = False )
220
221
if rtol is None :
221
- tol = S .max (axis = - 1 , keepdims = True ) * max (x .shape [- 2 :]) * finfo (S .dtype ).eps
222
+ tol = S .max (axis = - 1 , keepdims = True ) * max (x .shape [- 2 :]) * np . finfo (S .dtype ).eps
222
223
else :
223
- if isinstance (rtol , Array ):
224
- rtol = rtol ._array
224
+ rtol_np = rtol ._array if isinstance (rtol , Array ) else np .asarray (rtol )
225
225
# Note: this is different from np.linalg.matrix_rank, which does not multiply
226
226
# the tolerance by the largest singular value.
227
- tol = S .max (axis = - 1 , keepdims = True )* np . asarray ( rtol ) [..., np .newaxis ]
227
+ tol = S .max (axis = - 1 , keepdims = True ) * rtol_np [..., np .newaxis ]
228
228
return Array ._new (np .count_nonzero (S > tol , axis = - 1 ), device = x .device )
229
229
230
230
@@ -252,7 +252,7 @@ def outer(x1: Array, x2: Array, /) -> Array:
252
252
253
253
# Note: the keyword argument name rtol is different from np.linalg.pinv
254
254
@requires_extension ('linalg' )
255
- def pinv (x : Array , / , * , rtol : Optional [ Union [ float , Array ]] = None ) -> Array :
255
+ def pinv (x : Array , / , * , rtol : float | Array | None = None ) -> Array :
256
256
"""
257
257
Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.
258
258
@@ -267,9 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
267
267
# default tolerance by max(M, N).
268
268
if rtol is None :
269
269
rtol = max (x .shape [- 2 :]) * finfo (x .dtype ).eps
270
- if isinstance (rtol , Array ):
271
- rtol = rtol ._array
272
- return Array ._new (np .linalg .pinv (x ._array , rcond = rtol ), device = x .device )
270
+ rtol_np = rtol ._array if isinstance (rtol , Array ) else rtol
271
+ return Array ._new (np .linalg .pinv (x ._array , rcond = rtol_np ), device = x .device )
273
272
274
273
@requires_extension ('linalg' )
275
274
def qr (x : Array , / , * , mode : Literal ['reduced' , 'complete' ] = 'reduced' ) -> QRResult : # noqa: F821
@@ -312,14 +311,14 @@ def slogdet(x: Array, /) -> SlogdetResult:
312
311
313
312
# To workaround this, the below is the code from np.linalg.solve except
314
313
# only calling solve1 in the exactly 1D case.
315
- def _solve (a , b ) :
314
+ def _solve (a : np . ndarray , b : np . ndarray ) -> np . ndarray :
316
315
try :
317
- from numpy .linalg ._linalg import (
316
+ from numpy .linalg ._linalg import ( # type: ignore[attr-defined]
318
317
_makearray , _assert_stacked_2d , _assert_stacked_square ,
319
318
_commonType , isComplexType , _raise_linalgerror_singular
320
319
)
321
320
except ImportError :
322
- from numpy .linalg .linalg import (
321
+ from numpy .linalg .linalg import ( # type: ignore[attr-defined]
323
322
_makearray , _assert_stacked_2d , _assert_stacked_square ,
324
323
_commonType , isComplexType , _raise_linalgerror_singular
325
324
)
@@ -382,14 +381,14 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
382
381
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
383
382
# np.linalg.svd(compute_uv=False).
384
383
@requires_extension ('linalg' )
385
- def svdvals (x : Array , / ) -> Union [ Array , Tuple [ Array , ...]] :
384
+ def svdvals (x : Array , / ) -> Array :
386
385
if x .dtype not in _floating_dtypes :
387
386
raise TypeError ('Only floating-point dtypes are allowed in svdvals' )
388
387
return Array ._new (np .linalg .svd (x ._array , compute_uv = False ), device = x .device )
389
388
390
389
# Note: trace is the numpy top-level namespace, not np.linalg
391
390
@requires_extension ('linalg' )
392
- def trace (x : Array , / , * , offset : int = 0 , dtype : Optional [ Dtype ] = None ) -> Array :
391
+ def trace (x : Array , / , * , offset : int = 0 , dtype : DType | None = None ) -> Array :
393
392
"""
394
393
Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
395
394
@@ -398,27 +397,28 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr
398
397
if x .dtype not in _numeric_dtypes :
399
398
raise TypeError ('Only numeric dtypes are allowed in trace' )
400
399
401
- # Note: trace() works the same as sum() and prod() (see
402
- # _statistical_functions.py)
403
- if dtype is None :
404
- if get_array_api_strict_flags ()['api_version' ] < '2023.12' :
405
- if x .dtype == float32 :
406
- dtype = np .float64
407
- elif x .dtype == complex64 :
408
- dtype = np .complex128
409
- else :
410
- dtype = dtype ._np_dtype
400
+ # Note: trace() works the same as sum() and prod() (see _statistical_functions.py)
401
+ np_dtype = _np_dtype_sumprod (x , dtype )
402
+
411
403
# Note: trace always operates on the last two axes, whereas np.trace
412
404
# operates on the first two axes by default
413
- return Array ._new (np .asarray (np .trace (x ._array , offset = offset , axis1 = - 2 , axis2 = - 1 , dtype = dtype )), device = x .device )
405
+ res = np .trace (x ._array , offset = offset , axis1 = - 2 , axis2 = - 1 , dtype = np_dtype )
406
+ return Array ._new (np .asarray (res ), device = x .device )
414
407
415
408
# Note: the name here is different from norm(). The array API norm is split
416
409
# into matrix_norm and vector_norm().
417
410
418
411
# The type for ord should be Optional[Union[int, float, Literal[np.inf,
419
412
# -np.inf]]] but Literal does not support floating-point literals.
420
413
@requires_extension ('linalg' )
421
- def vector_norm (x : Array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None , keepdims : bool = False , ord : Optional [Union [int , float ]] = 2 ) -> Array :
414
+ def vector_norm (
415
+ x : Array ,
416
+ / ,
417
+ * ,
418
+ axis : int | tuple [int , ...] | None = None ,
419
+ keepdims : bool = False ,
420
+ ord : int | float = 2 ,
421
+ ) -> Array :
422
422
"""
423
423
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
424
424
@@ -456,8 +456,8 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No
456
456
# We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
457
457
# above to avoid matrix norm logic.
458
458
shape = list (x .shape )
459
- _axis = normalize_axis_tuple (range (x .ndim ) if axis is None else axis , x .ndim )
460
- for i in _axis :
459
+ axis_tup = normalize_axis_tuple (range (x .ndim ) if axis is None else axis , x .ndim )
460
+ for i in axis_tup :
461
461
shape [i ] = 1
462
462
res = reshape (res , tuple (shape ))
463
463
@@ -480,7 +480,13 @@ def matmul(x1: Array, x2: Array, /) -> Array:
480
480
481
481
# Note: tensordot is the numpy top-level namespace but not in np.linalg
482
482
@requires_extension ('linalg' )
483
- def tensordot (x1 : Array , x2 : Array , / , * , axes : Union [int , Tuple [Sequence [int ], Sequence [int ]]] = 2 ) -> Array :
483
+ def tensordot (
484
+ x1 : Array ,
485
+ x2 : Array ,
486
+ / ,
487
+ * ,
488
+ axes : int | tuple [Sequence [int ], Sequence [int ]] = 2 ,
489
+ ) -> Array :
484
490
from ._linear_algebra_functions import tensordot
485
491
return tensordot (x1 , x2 , axes = axes )
486
492
0 commit comments