@@ -27,6 +27,8 @@ from typing import (
27
27
)
28
28
from typing_extensions import Buffer, CapsuleType, LiteralString, Never, Protocol, Self, TypeVar, Unpack, deprecated, override
29
29
30
+ import numpy as np
31
+
30
32
from . import (
31
33
__config__ as __config__,
32
34
_array_api_info as _array_api_info,
@@ -611,6 +613,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
611
613
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
612
614
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)
613
615
616
+ _Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]]
617
+
614
618
###
615
619
# Type Aliases (for internal use only)
616
620
@@ -2531,8 +2535,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2531
2535
@overload
2532
2536
def __imul__(self: NDArray[object_], rhs: object, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ...
2533
2537
2534
- # TODO(jorenham): Support the "1d @ 1d -> scalar" case
2535
- # https://github.com/numpy/numtype/issues/197
2538
+ @overload
2539
+ def __matmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
2536
2540
@overload
2537
2541
def __matmul__(self: NDArray[_NumberT], rhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
2538
2542
@overload
@@ -2566,12 +2570,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2566
2570
@overload
2567
2571
def __matmul__(self: NDArray[bool_ | number], rhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
2568
2572
@overload
2569
- def __matmul__(self: NDArray[object_], rhs: object , /) -> NDArray[object_]: ...
2573
+ def __matmul__(self: NDArray[object_], rhs: _ArrayLikeObject_co , /) -> NDArray[object_]: ...
2570
2574
@overload
2571
2575
def __matmul__(self, rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
2572
2576
2573
2577
# keep in sync with __matmul__
2574
2578
@overload
2579
+ def __rmatmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
2580
+ @overload
2575
2581
def __rmatmul__(self: NDArray[_NumberT], lhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
2576
2582
@overload
2577
2583
def __rmatmul__(self: NDArray[bool_], lhs: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ...
@@ -2604,7 +2610,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2604
2610
@overload
2605
2611
def __rmatmul__(self: NDArray[bool_ | number], lhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
2606
2612
@overload
2607
- def __rmatmul__(self: NDArray[object_], lhs: object , /) -> NDArray[object_]: ...
2613
+ def __rmatmul__(self: NDArray[object_], lhs: _ArrayLikeObject_co , /) -> NDArray[object_]: ...
2608
2614
@overload
2609
2615
def __rmatmul__(self, lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
2610
2616
0 commit comments