5
5
from collections .abc import Callable , Hashable , Sequence
6
6
from functools import partial
7
7
from numbers import Number
8
- from typing import TYPE_CHECKING , Any , get_args
8
+ from typing import TYPE_CHECKING , Any , Optional , get_args
9
9
10
10
import numpy as np
11
11
import pandas as pd
29
29
if TYPE_CHECKING :
30
30
from xarray .core .dataarray import DataArray
31
31
from xarray .core .dataset import Dataset
32
+ from xarray .core .variable import IndexVariable
32
33
33
34
34
35
def _get_nan_block_lengths (
35
- obj : Dataset | DataArray | Variable , dim : Hashable , index : Variable
36
- ):
36
+ obj : Dataset | DataArray , dim : Hashable , index : Variable
37
+ ) -> Any :
37
38
"""
38
39
Return an object where each NaN element in 'obj' is replaced by the
39
40
length of the gap the element is in.
@@ -66,12 +67,12 @@ class BaseInterpolator:
66
67
cons_kwargs : dict [str , Any ]
67
68
call_kwargs : dict [str , Any ]
68
69
f : Callable
69
- method : str
70
+ method : str | int
70
71
71
- def __call__ (self , x ) :
72
+ def __call__ (self , x : np . ndarray ) -> np . ndarray :
72
73
return self .f (x , ** self .call_kwargs )
73
74
74
- def __repr__ (self ):
75
+ def __repr__ (self ) -> str :
75
76
return f"{ self .__class__ .__name__ } : method={ self .method } "
76
77
77
78
@@ -83,7 +84,14 @@ class NumpyInterpolator(BaseInterpolator):
83
84
numpy.interp
84
85
"""
85
86
86
- def __init__ (self , xi , yi , method = "linear" , fill_value = None , period = None ):
87
+ def __init__ (
88
+ self ,
89
+ xi : Variable ,
90
+ yi : np .ndarray ,
91
+ method : Optional [str ] = "linear" ,
92
+ fill_value = None ,
93
+ period = None ,
94
+ ):
87
95
if method != "linear" :
88
96
raise ValueError ("only method `linear` is valid for the NumpyInterpolator" )
89
97
@@ -104,8 +112,8 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None):
104
112
self ._left = fill_value [0 ]
105
113
self ._right = fill_value [1 ]
106
114
elif is_scalar (fill_value ):
107
- self ._left = fill_value
108
- self ._right = fill_value
115
+ self ._left = fill_value # type: ignore[assignment]
116
+ self ._right = fill_value # type: ignore[assignment]
109
117
else :
110
118
raise ValueError (f"{ fill_value } is not a valid fill_value" )
111
119
@@ -130,14 +138,14 @@ class ScipyInterpolator(BaseInterpolator):
130
138
131
139
def __init__ (
132
140
self ,
133
- xi ,
134
- yi ,
135
- method = None ,
136
- fill_value = None ,
137
- assume_sorted = True ,
138
- copy = False ,
139
- bounds_error = False ,
140
- order = None ,
141
+ xi : Variable ,
142
+ yi : np . ndarray ,
143
+ method : Optional [ str | int ] = None ,
144
+ fill_value : Optional [ float | complex ] = None ,
145
+ assume_sorted : bool = True ,
146
+ copy : bool = False ,
147
+ bounds_error : bool = False ,
148
+ order : Optional [ int ] = None ,
141
149
axis = - 1 ,
142
150
** kwargs ,
143
151
):
@@ -154,18 +162,13 @@ def __init__(
154
162
raise ValueError ("order is required when method=polynomial" )
155
163
method = order
156
164
157
- self .method = method
165
+ self .method : str | int = method
158
166
159
167
self .cons_kwargs = kwargs
160
168
self .call_kwargs = {}
161
169
162
170
nan = np .nan if yi .dtype .kind != "c" else np .nan + np .nan * 1j
163
171
164
- if fill_value is None and method == "linear" :
165
- fill_value = nan , nan
166
- elif fill_value is None :
167
- fill_value = nan
168
-
169
172
self .f = interp1d (
170
173
xi ,
171
174
yi ,
@@ -601,7 +604,12 @@ def _floatize_x(x, new_x):
601
604
return x , new_x
602
605
603
606
604
- def interp (var , indexes_coords , method : InterpOptions , ** kwargs ):
607
+ def interp (
608
+ var : Variable ,
609
+ indexes_coords : dict [str , IndexVariable ],
610
+ method : InterpOptions ,
611
+ ** kwargs ,
612
+ ) -> Variable :
605
613
"""Make an interpolation of Variable
606
614
607
615
Parameters
@@ -662,7 +670,13 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
662
670
return result
663
671
664
672
665
- def interp_func (var , x , new_x , method : InterpOptions , kwargs ):
673
+ def interp_func (
674
+ var : np .ndarray ,
675
+ x : list [IndexVariable ],
676
+ new_x : list [IndexVariable ],
677
+ method : InterpOptions ,
678
+ kwargs : dict ,
679
+ ) -> np .ndarray :
666
680
"""
667
681
multi-dimensional interpolation for array-like. Interpolated axes should be
668
682
located in the last position.
@@ -766,9 +780,14 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
766
780
return _interpnd (var , x , new_x , func , kwargs )
767
781
768
782
769
- def _interp1d (var , x , new_x , func , kwargs ):
783
+ def _interp1d (
784
+ var : np .ndarray ,
785
+ x : IndexVariable ,
786
+ new_x : IndexVariable ,
787
+ func : Callable ,
788
+ kwargs : dict ,
789
+ ) -> np .ndarray :
770
790
# x, new_x are tuples of size 1.
771
- x , new_x = x [0 ], new_x [0 ]
772
791
rslt = func (x , var , ** kwargs )(np .ravel (new_x ))
773
792
if new_x .ndim > 1 :
774
793
return reshape (rslt , (var .shape [:- 1 ] + new_x .shape ))
@@ -777,11 +796,17 @@ def _interp1d(var, x, new_x, func, kwargs):
777
796
return rslt
778
797
779
798
780
- def _interpnd (var , x , new_x , func , kwargs ):
799
+ def _interpnd (
800
+ var : np .ndarray ,
801
+ x : list [IndexVariable ],
802
+ new_x : list [IndexVariable ],
803
+ func : Callable ,
804
+ kwargs : dict ,
805
+ ) -> np .ndarray :
781
806
x , new_x = _floatize_x (x , new_x )
782
807
783
808
if len (x ) == 1 :
784
- return _interp1d (var , x , new_x , func , kwargs )
809
+ return _interp1d (var , x [ 0 ] , new_x [ 0 ] , func , kwargs )
785
810
786
811
# move the interpolation axes to the start position
787
812
var = var .transpose (range (- len (x ), var .ndim - len (x )))
0 commit comments