Skip to content

Commit c3e7a48

Browse files
committed
initial typing
1 parent 5c64182 commit c3e7a48

File tree

1 file changed

+54
-29
lines changed

1 file changed

+54
-29
lines changed

xarray/core/missing.py

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Callable, Hashable, Sequence
66
from functools import partial
77
from numbers import Number
8-
from typing import TYPE_CHECKING, Any, get_args
8+
from typing import TYPE_CHECKING, Any, Optional, get_args
99

1010
import numpy as np
1111
import pandas as pd
@@ -29,11 +29,12 @@
2929
if TYPE_CHECKING:
3030
from xarray.core.dataarray import DataArray
3131
from xarray.core.dataset import Dataset
32+
from xarray.core.variable import IndexVariable
3233

3334

3435
def _get_nan_block_lengths(
35-
obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable
36-
):
36+
obj: Dataset | DataArray, dim: Hashable, index: Variable
37+
) -> Any:
3738
"""
3839
Return an object where each NaN element in 'obj' is replaced by the
3940
length of the gap the element is in.
@@ -66,12 +67,12 @@ class BaseInterpolator:
6667
cons_kwargs: dict[str, Any]
6768
call_kwargs: dict[str, Any]
6869
f: Callable
69-
method: str
70+
method: str | int
7071

71-
def __call__(self, x):
72+
def __call__(self, x: np.ndarray) -> np.ndarray:
7273
return self.f(x, **self.call_kwargs)
7374

74-
def __repr__(self):
75+
def __repr__(self) -> str:
7576
return f"{self.__class__.__name__}: method={self.method}"
7677

7778

@@ -83,7 +84,14 @@ class NumpyInterpolator(BaseInterpolator):
8384
numpy.interp
8485
"""
8586

86-
def __init__(self, xi, yi, method="linear", fill_value=None, period=None):
87+
def __init__(
88+
self,
89+
xi: Variable,
90+
yi: np.ndarray,
91+
method: Optional[str] = "linear",
92+
fill_value=None,
93+
period=None,
94+
):
8795
if method != "linear":
8896
raise ValueError("only method `linear` is valid for the NumpyInterpolator")
8997

@@ -104,8 +112,8 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None):
104112
self._left = fill_value[0]
105113
self._right = fill_value[1]
106114
elif is_scalar(fill_value):
107-
self._left = fill_value
108-
self._right = fill_value
115+
self._left = fill_value # type: ignore[assignment]
116+
self._right = fill_value # type: ignore[assignment]
109117
else:
110118
raise ValueError(f"{fill_value} is not a valid fill_value")
111119

@@ -130,14 +138,14 @@ class ScipyInterpolator(BaseInterpolator):
130138

131139
def __init__(
132140
self,
133-
xi,
134-
yi,
135-
method=None,
136-
fill_value=None,
137-
assume_sorted=True,
138-
copy=False,
139-
bounds_error=False,
140-
order=None,
141+
xi: Variable,
142+
yi: np.ndarray,
143+
method: Optional[str | int] = None,
144+
fill_value: Optional[float | complex] = None,
145+
assume_sorted: bool = True,
146+
copy: bool = False,
147+
bounds_error: bool = False,
148+
order: Optional[int] = None,
141149
axis=-1,
142150
**kwargs,
143151
):
@@ -154,18 +162,13 @@ def __init__(
154162
raise ValueError("order is required when method=polynomial")
155163
method = order
156164

157-
self.method = method
165+
self.method: str | int = method
158166

159167
self.cons_kwargs = kwargs
160168
self.call_kwargs = {}
161169

162170
nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j
163171

164-
if fill_value is None and method == "linear":
165-
fill_value = nan, nan
166-
elif fill_value is None:
167-
fill_value = nan
168-
169172
self.f = interp1d(
170173
xi,
171174
yi,
@@ -601,7 +604,12 @@ def _floatize_x(x, new_x):
601604
return x, new_x
602605

603606

604-
def interp(var, indexes_coords, method: InterpOptions, **kwargs):
607+
def interp(
608+
var: Variable,
609+
indexes_coords: dict[str, IndexVariable],
610+
method: InterpOptions,
611+
**kwargs,
612+
) -> Variable:
605613
"""Make an interpolation of Variable
606614
607615
Parameters
@@ -662,7 +670,13 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
662670
return result
663671

664672

665-
def interp_func(var, x, new_x, method: InterpOptions, kwargs):
673+
def interp_func(
674+
var: np.ndarray,
675+
x: list[IndexVariable],
676+
new_x: list[IndexVariable],
677+
method: InterpOptions,
678+
kwargs: dict,
679+
) -> np.ndarray:
666680
"""
667681
multi-dimensional interpolation for array-like. Interpolated axes should be
668682
located in the last position.
@@ -766,9 +780,14 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
766780
return _interpnd(var, x, new_x, func, kwargs)
767781

768782

769-
def _interp1d(var, x, new_x, func, kwargs):
783+
def _interp1d(
784+
var: np.ndarray,
785+
x: IndexVariable,
786+
new_x: IndexVariable,
787+
func: Callable,
788+
kwargs: dict,
789+
) -> np.ndarray:
770790
# x, new_x are tuples of size 1.
771-
x, new_x = x[0], new_x[0]
772791
rslt = func(x, var, **kwargs)(np.ravel(new_x))
773792
if new_x.ndim > 1:
774793
return reshape(rslt, (var.shape[:-1] + new_x.shape))
@@ -777,11 +796,17 @@ def _interp1d(var, x, new_x, func, kwargs):
777796
return rslt
778797

779798

780-
def _interpnd(var, x, new_x, func, kwargs):
799+
def _interpnd(
800+
var: np.ndarray,
801+
x: list[IndexVariable],
802+
new_x: list[IndexVariable],
803+
func: Callable,
804+
kwargs: dict,
805+
) -> np.ndarray:
781806
x, new_x = _floatize_x(x, new_x)
782807

783808
if len(x) == 1:
784-
return _interp1d(var, x, new_x, func, kwargs)
809+
return _interp1d(var, x[0], new_x[0], func, kwargs)
785810

786811
# move the interpolation axes to the start position
787812
var = var.transpose(range(-len(x), var.ndim - len(x)))

0 commit comments

Comments
 (0)