Skip to content

Commit 98a5905

Browse files
committed
🚚 port ma.arg{min,max} and MaskedArray.arg{min,max}
1 parent b37e58b commit 98a5905

File tree

1 file changed

+156
-19
lines changed

1 file changed

+156
-19
lines changed

‎src/numpy-stubs/ma/core.pyi

Lines changed: 156 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ from typing_extensions import Never, Self, TypeVar, deprecated, overload, overri
44

55
import numpy as np
66
from _numtype import Array, ToGeneric_0d, ToGeneric_1nd, ToGeneric_nd
7-
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims # noqa: ICN003
8-
from numpy._typing import _BoolCodes
7+
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims, intp # noqa: ICN003
8+
from numpy._globals import _NoValueType
9+
from numpy._typing import _BoolCodes, _ScalarLike_co
910

1011
__all__ = [
1112
"MAError",
@@ -188,6 +189,12 @@ __all__ = [
188189
"zeros_like",
189190
]
190191

192+
###
193+
194+
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
195+
196+
###
197+
191198
_UFuncT_co = TypeVar("_UFuncT_co", bound=np.ufunc, default=np.ufunc, covariant=True)
192199
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
193200
_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], default=tuple[int, ...], covariant=True)
@@ -650,15 +657,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
650657
fill_value: Incomplete = ...,
651658
keepdims: Incomplete = ...,
652659
) -> Incomplete: ...
653-
@override
654-
def argmin( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
660+
661+
# Keep in-sync with np.ma.argmin
662+
@overload # type: ignore[override]
663+
def argmin(
655664
self,
656-
axis: Incomplete = ...,
657-
fill_value: Incomplete = ...,
658-
out: Incomplete = ...,
665+
axis: None = None,
666+
fill_value: _ScalarLike_co | None = None,
667+
out: None = None,
659668
*,
660-
keepdims: Incomplete = ...,
661-
) -> Incomplete: ...
669+
keepdims: L[False] | _NoValueType = ...,
670+
) -> intp: ...
671+
@overload
672+
def argmin(
673+
self,
674+
axis: CanIndex | None = None,
675+
fill_value: _ScalarLike_co | None = None,
676+
out: None = None,
677+
*,
678+
keepdims: bool | _NoValueType = ...,
679+
) -> Any: ...
680+
@overload
681+
def argmin(
682+
self,
683+
axis: CanIndex | None = None,
684+
fill_value: _ScalarLike_co | None = None,
685+
*,
686+
out: _ArrayT,
687+
keepdims: bool | _NoValueType = ...,
688+
) -> _ArrayT: ...
689+
@overload
690+
def argmin( # pyright: ignore[reportIncompatibleMethodOverride]
691+
self,
692+
axis: CanIndex | None,
693+
fill_value: _ScalarLike_co | None,
694+
out: _ArrayT,
695+
*,
696+
keepdims: bool | _NoValueType = ...,
697+
) -> _ArrayT: ...
662698

663699
#
664700
@override
@@ -669,15 +705,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
669705
fill_value: Incomplete = ...,
670706
keepdims: Incomplete = ...,
671707
) -> Incomplete: ...
672-
@override
673-
def argmax( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
708+
709+
# Keep in-sync with np.ma.argmax
710+
@overload # type: ignore[override]
711+
def argmax(
674712
self,
675-
axis: Incomplete = ...,
676-
fill_value: Incomplete = ...,
677-
out: Incomplete = ...,
713+
axis: None = None,
714+
fill_value: _ScalarLike_co | None = None,
715+
out: None = None,
678716
*,
679-
keepdims: Incomplete = ...,
680-
) -> Incomplete: ...
717+
keepdims: L[False] | _NoValueType = ...,
718+
) -> intp: ...
719+
@overload
720+
def argmax(
721+
self,
722+
axis: CanIndex | None = None,
723+
fill_value: _ScalarLike_co | None = None,
724+
out: None = None,
725+
*,
726+
keepdims: bool | _NoValueType = ...,
727+
) -> Any: ...
728+
@overload
729+
def argmax(
730+
self,
731+
axis: CanIndex | None = None,
732+
fill_value: _ScalarLike_co | None = None,
733+
*,
734+
out: _ArrayT,
735+
keepdims: bool | _NoValueType = ...,
736+
) -> _ArrayT: ...
737+
@overload
738+
def argmax( # pyright: ignore[reportIncompatibleMethodOverride]
739+
self,
740+
axis: CanIndex | None,
741+
fill_value: _ScalarLike_co | None,
742+
out: _ArrayT,
743+
*,
744+
keepdims: bool | _NoValueType = ...,
745+
) -> _ArrayT: ...
681746

682747
#
683748
@override
@@ -1066,8 +1131,80 @@ swapaxes: _frommethod
10661131
trace: _frommethod
10671132
var: _frommethod
10681133
count: _frommethod
1069-
argmin: _frommethod
1070-
argmax: _frommethod
1071-
10721134
minimum: _extrema_operation
10731135
maximum: _extrema_operation
1136+
1137+
@overload
1138+
def argmin(
1139+
self: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1140+
axis: None = None,
1141+
fill_value: _ScalarLike_co | None = None,
1142+
out: None = None,
1143+
*,
1144+
keepdims: L[False] | _NoValueType = ...,
1145+
) -> intp: ...
1146+
@overload
1147+
def argmin(
1148+
self: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1149+
axis: CanIndex | None = None,
1150+
fill_value: _ScalarLike_co | None = None,
1151+
out: None = None,
1152+
*,
1153+
keepdims: bool | _NoValueType = ...,
1154+
) -> Any: ...
1155+
@overload
1156+
def argmin(
1157+
self: _ArrayT,
1158+
axis: CanIndex | None = None,
1159+
fill_value: _ScalarLike_co | None = None,
1160+
*,
1161+
out: _ArrayT,
1162+
keepdims: bool | _NoValueType = ...,
1163+
) -> _ArrayT: ...
1164+
@overload
1165+
def argmin(
1166+
self: _ArrayT,
1167+
axis: CanIndex | None,
1168+
fill_value: _ScalarLike_co | None,
1169+
out: _ArrayT,
1170+
*,
1171+
keepdims: bool | _NoValueType = ...,
1172+
) -> _ArrayT: ...
1173+
1174+
#
1175+
@overload
1176+
def argmax(
1177+
self: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1178+
axis: None = None,
1179+
fill_value: _ScalarLike_co | None = None,
1180+
out: None = None,
1181+
*,
1182+
keepdims: L[False] | _NoValueType = ...,
1183+
) -> intp: ...
1184+
@overload
1185+
def argmax(
1186+
self: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1187+
axis: CanIndex | None = None,
1188+
fill_value: _ScalarLike_co | None = None,
1189+
out: None = None,
1190+
*,
1191+
keepdims: bool | _NoValueType = ...,
1192+
) -> Any: ...
1193+
@overload
1194+
def argmax(
1195+
self: _ArrayT,
1196+
axis: CanIndex | None = None,
1197+
fill_value: _ScalarLike_co | None = None,
1198+
*,
1199+
out: _ArrayT,
1200+
keepdims: bool | _NoValueType = ...,
1201+
) -> _ArrayT: ...
1202+
@overload
1203+
def argmax(
1204+
self: _ArrayT,
1205+
axis: CanIndex | None,
1206+
fill_value: _ScalarLike_co | None,
1207+
out: _ArrayT,
1208+
*,
1209+
keepdims: bool | _NoValueType = ...,
1210+
) -> _ArrayT: ...

0 commit comments

Comments
 (0)