Skip to content

Commit 830fa00

Browse files
committed
add type hints
1 parent 55e3bff commit 830fa00

File tree

1 file changed

+99
-44
lines changed

1 file changed

+99
-44
lines changed

pandas/core/missing.py

Lines changed: 99 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
Routines for filling missing data.
33
"""
44
from functools import partial
5-
from typing import Any, List, Optional, Set, Union
5+
from typing import Any, Callable, List, Optional, Set, Tuple, Union
66

77
import numpy as np
88

99
from pandas._libs import algos, lib
10-
from pandas._typing import ArrayLike, Axis, DtypeObj
10+
from pandas._typing import ArrayLike, Axis, DtypeObj, IndexLabel, Scalar
1111
from pandas.compat._optional import import_optional_dependency
1212

1313
from pandas.core.dtypes.cast import infer_dtype_from_array
@@ -20,7 +20,9 @@
2020
from pandas.core.dtypes.missing import isna
2121

2222

23-
def mask_missing(arr: ArrayLike, values_to_mask) -> np.ndarray:
23+
def mask_missing(
24+
arr: ArrayLike, values_to_mask: Union[List, Tuple, Scalar]
25+
) -> np.ndarray:
2426
"""
2527
Return a masking array of same size/shape as arr
2628
with entries equaling any member of values_to_mask set to True
@@ -58,7 +60,7 @@ def mask_missing(arr: ArrayLike, values_to_mask) -> np.ndarray:
5860
return mask
5961

6062

61-
def clean_fill_method(method, allow_nearest: bool = False):
63+
def clean_fill_method(method: str, allow_nearest: bool = False) -> Optional[str]:
6264
# asfreq is compat for resampling
6365
if method in [None, "asfreq"]:
6466
return None
@@ -117,7 +119,7 @@ def clean_interp_method(method: str, **kwargs) -> str:
117119
return method
118120

119121

120-
def find_valid_index(values, how: str):
122+
def find_valid_index(values: ArrayLike, how: str) -> Optional[int]:
121123
"""
122124
Retrieves the index of the first valid value.
123125
@@ -165,7 +167,7 @@ def interpolate_1d(
165167
bounds_error: bool = False,
166168
order: Optional[int] = None,
167169
**kwargs,
168-
):
170+
) -> np.ndarray:
169171
"""
170172
Logic for the 1-d interpolation. The result should be 1-d, inputs
171173
xvalues and yvalues will each be 1-d arrays of the same length.
@@ -218,8 +220,13 @@ def interpolate_1d(
218220

219221
# These are sets of index pointers to invalid values, e.g. {0, 1, ...}
220222
all_nans = set(np.flatnonzero(invalid))
221-
start_nans = set(range(find_valid_index(yvalues, "first")))
222-
end_nans = set(range(1 + find_valid_index(yvalues, "last"), len(valid)))
223+
224+
start_nan_idx = find_valid_index(yvalues, "first")
225+
start_nans = set() if start_nan_idx is None else set(range(start_nan_idx))
226+
227+
end_nan_idx = find_valid_index(yvalues, "last")
228+
end_nans = set() if end_nan_idx is None else set(range(1 + end_nan_idx, len(valid)))
229+
223230
mid_nans = all_nans - start_nans - end_nans
224231

225232
# Like the sets above, preserve_nans contains indices of invalid values,
@@ -294,8 +301,15 @@ def interpolate_1d(
294301

295302

296303
def _interpolate_scipy_wrapper(
297-
x, y, new_x, method, fill_value=None, bounds_error=False, order=None, **kwargs
298-
):
304+
x,
305+
y,
306+
new_x,
307+
method: Optional[str],
308+
fill_value: Optional[Scalar] = None,
309+
bounds_error: bool = False,
310+
order: Optional[int] = None,
311+
**kwargs,
312+
) -> np.ndarray:
299313
"""
300314
Passed off to scipy.interpolate.interp1d. method is scipy's kind.
301315
Returns an array interpolated at new_x. Add any new methods to
@@ -326,7 +340,7 @@ def _interpolate_scipy_wrapper(
326340
elif method == "cubicspline":
327341
alt_methods["cubicspline"] = _cubicspline_interpolate
328342

329-
interp1d_methods = [
343+
interp1d_methods: List[str] = [
330344
"nearest",
331345
"zero",
332346
"slinear",
@@ -335,18 +349,18 @@ def _interpolate_scipy_wrapper(
335349
"polynomial",
336350
]
337351
if method in interp1d_methods:
338-
if method == "polynomial":
339-
method = order
352+
kind = order if method == "polynomial" else method
340353
terp = interpolate.interp1d(
341-
x, y, kind=method, fill_value=fill_value, bounds_error=bounds_error
354+
x, y, kind=kind, fill_value=fill_value, bounds_error=bounds_error
342355
)
343356
new_y = terp(new_x)
344357
elif method == "spline":
345358
# GH #10633, #24014
346-
if isna(order) or (order <= 0):
347-
raise ValueError(
348-
f"order needs to be specified and greater than 0; got order: {order}"
349-
)
359+
if isna(order):
360+
raise ValueError(f"order needs to be specified; got order: {order}")
361+
assert isinstance(order, int)
362+
if order <= 0:
363+
raise ValueError(f"order needs to be greater than 0; got order: {order}")
350364
terp = interpolate.UnivariateSpline(x, y, k=order, **kwargs)
351365
new_y = terp(new_x)
352366
else:
@@ -358,12 +372,21 @@ def _interpolate_scipy_wrapper(
358372
y = y.copy()
359373
if not new_x.flags.writeable:
360374
new_x = new_x.copy()
361-
method = alt_methods[method]
362-
new_y = method(x, y, new_x, **kwargs)
375+
376+
assert isinstance(method, str)
377+
alt_method = alt_methods[method]
378+
new_y = alt_method(x, y, new_x, **kwargs)
363379
return new_y
364380

365381

366-
def _from_derivatives(xi, yi, x, order=None, der=0, extrapolate=False):
382+
def _from_derivatives(
383+
xi: np.ndarray,
384+
yi: np.ndarray,
385+
x: Union[Scalar, ArrayLike],
386+
order: Optional[Union[int, List[int]]] = None,
387+
der: Union[int, List[int]] = 0,
388+
extrapolate: bool = False,
389+
) -> np.ndarray:
367390
"""
368391
Convenience function for interpolate.BPoly.from_derivatives.
369392
@@ -406,7 +429,13 @@ def _from_derivatives(xi, yi, x, order=None, der=0, extrapolate=False):
406429
return m(x)
407430

408431

409-
def _akima_interpolate(xi, yi, x, der=0, axis=0):
432+
def _akima_interpolate(
433+
xi: ArrayLike,
434+
yi: ArrayLike,
435+
x: Union[Scalar, ArrayLike],
436+
der: Optional[int] = 0,
437+
axis: Optional[int] = 0,
438+
) -> Union[Scalar, ArrayLike]:
410439
"""
411440
Convenience function for akima interpolation.
412441
xi and yi are arrays of values used to approximate some function f,
@@ -449,7 +478,14 @@ def _akima_interpolate(xi, yi, x, der=0, axis=0):
449478
return P(x, nu=der)
450479

451480

452-
def _cubicspline_interpolate(xi, yi, x, axis=0, bc_type="not-a-knot", extrapolate=None):
481+
def _cubicspline_interpolate(
482+
xi: ArrayLike,
483+
yi: ArrayLike,
484+
x: Union[ArrayLike, Scalar],
485+
axis: Optional[int] = 0,
486+
bc_type: Union[str, Tuple] = "not-a-knot",
487+
extrapolate: Optional[Union[bool, str]] = None,
488+
) -> Union[ArrayLike, Scalar]:
453489
"""
454490
Convenience function for cubic spline data interpolator.
455491
@@ -557,6 +593,8 @@ def _interpolate_with_limit_area(
557593
first = find_valid_index(values, "first")
558594
last = find_valid_index(values, "last")
559595

596+
assert first is not None and last is not None
597+
560598
values = interpolate_2d(
561599
values,
562600
method=method,
@@ -574,12 +612,12 @@ def _interpolate_with_limit_area(
574612

575613

576614
def interpolate_2d(
577-
values,
615+
values: np.ndarray,
578616
method: str = "pad",
579617
axis: Axis = 0,
580618
limit: Optional[int] = None,
581619
limit_area: Optional[str] = None,
582-
):
620+
) -> np.ndarray:
583621
"""
584622
Perform an actual interpolation of values. Values will be made 2-d if
585623
needed; fills inplace and returns the result.
@@ -625,7 +663,10 @@ def interpolate_2d(
625663
raise AssertionError("cannot interpolate on a ndim == 1 with axis != 0")
626664
values = values.reshape(tuple((1,) + values.shape))
627665

628-
method = clean_fill_method(method)
666+
method_cleaned = clean_fill_method(method)
667+
assert isinstance(method_cleaned, str)
668+
method = method_cleaned
669+
629670
tvalues = transf(values)
630671
if method == "pad":
631672
result = _pad_2d(tvalues, limit=limit)
@@ -644,7 +685,9 @@ def interpolate_2d(
644685
return result
645686

646687

647-
def _cast_values_for_fillna(values, dtype: DtypeObj, has_mask: bool):
688+
def _cast_values_for_fillna(
689+
values: ArrayLike, dtype: DtypeObj, has_mask: bool
690+
) -> ArrayLike:
648691
"""
649692
Cast values to a dtype that algos.pad and algos.backfill can handle.
650693
"""
@@ -663,34 +706,41 @@ def _cast_values_for_fillna(values, dtype: DtypeObj, has_mask: bool):
663706
return values
664707

665708

666-
def _fillna_prep(values, mask=None):
709+
def _fillna_prep(
710+
values: np.ndarray, mask: Optional[np.ndarray] = None
711+
) -> Tuple[np.ndarray, np.ndarray]:
667712
# boilerplate for _pad_1d, _backfill_1d, _pad_2d, _backfill_2d
668-
dtype = values.dtype
669713

670714
has_mask = mask is not None
671-
if not has_mask:
672-
# This needs to occur before datetime/timedeltas are cast to int64
673-
mask = isna(values)
674715

675-
values = _cast_values_for_fillna(values, dtype, has_mask)
716+
# This needs to occur before datetime/timedeltas are cast to int64
717+
mask = isna(values) if mask is None else mask
676718

719+
values = _cast_values_for_fillna(values, values.dtype, has_mask)
677720
mask = mask.view(np.uint8)
721+
678722
return values, mask
679723

680724

681-
def _pad_1d(values, limit=None, mask=None):
725+
def _pad_1d(
726+
values: np.ndarray, limit: Optional[int] = None, mask: Optional[np.ndarray] = None
727+
) -> np.ndarray:
682728
values, mask = _fillna_prep(values, mask)
683729
algos.pad_inplace(values, mask, limit=limit)
684730
return values
685731

686732

687-
def _backfill_1d(values, limit=None, mask=None):
733+
def _backfill_1d(
734+
values: np.ndarray, limit: Optional[int] = None, mask: Optional[np.ndarray] = None
735+
) -> np.ndarray:
688736
values, mask = _fillna_prep(values, mask)
689737
algos.backfill_inplace(values, mask, limit=limit)
690738
return values
691739

692740

693-
def _pad_2d(values, limit=None, mask=None):
741+
def _pad_2d(
742+
values: np.ndarray, limit: Optional[int] = None, mask: Optional[np.ndarray] = None
743+
) -> np.ndarray:
694744
values, mask = _fillna_prep(values, mask)
695745

696746
if np.all(values.shape):
@@ -701,7 +751,9 @@ def _pad_2d(values, limit=None, mask=None):
701751
return values
702752

703753

704-
def _backfill_2d(values, limit=None, mask=None):
754+
def _backfill_2d(
755+
values: np.ndarray, limit: Optional[int] = None, mask: Optional[np.ndarray] = None
756+
) -> np.ndarray:
705757
values, mask = _fillna_prep(values, mask)
706758

707759
if np.all(values.shape):
@@ -715,16 +767,19 @@ def _backfill_2d(values, limit=None, mask=None):
715767
_fill_methods = {"pad": _pad_1d, "backfill": _backfill_1d}
716768

717769

718-
def get_fill_func(method):
719-
method = clean_fill_method(method)
720-
return _fill_methods[method]
770+
def get_fill_func(method: str) -> Callable:
771+
method_cleaned = clean_fill_method(method)
772+
assert isinstance(method_cleaned, str)
773+
return _fill_methods[method_cleaned]
721774

722775

723-
def clean_reindex_fill_method(method):
776+
def clean_reindex_fill_method(method: str) -> Optional[str]:
724777
return clean_fill_method(method, allow_nearest=True)
725778

726779

727-
def _interp_limit(invalid, fw_limit, bw_limit):
780+
def _interp_limit(
781+
invalid: np.ndarray, fw_limit: Optional[int], bw_limit: Optional[int]
782+
) -> Set[IndexLabel]:
728783
"""
729784
Get indexers of values that won't be filled
730785
because they exceed the limits.
@@ -759,7 +814,7 @@ def _interp_limit(invalid, fw_limit, bw_limit):
759814
f_idx = set()
760815
b_idx = set()
761816

762-
def inner(invalid, limit):
817+
def inner(invalid: np.ndarray, limit: int) -> Set[IndexLabel]:
763818
limit = min(limit, N)
764819
windowed = _rolling_window(invalid, limit + 1).all(1)
765820
idx = set(np.where(windowed)[0] + limit) | set(
@@ -789,7 +844,7 @@ def inner(invalid, limit):
789844
return f_idx & b_idx
790845

791846

792-
def _rolling_window(a: np.ndarray, window: int):
847+
def _rolling_window(a: np.ndarray, window: int) -> np.ndarray:
793848
"""
794849
[True, True, False, True, False], 2 ->
795850

0 commit comments

Comments
 (0)