1010from xarray import DataArray , Dataset , set_options
1111from xarray .tests import (
1212 assert_allclose ,
13- assert_array_equal ,
1413 assert_equal ,
1514 assert_identical ,
1615 has_dask ,
2423]
2524
2625
26+ @pytest .fixture (params = ["numbagg" , "bottleneck" ])
27+ def compute_backend (request ):
28+ if request .param == "bottleneck" :
29+ options = dict (use_bottleneck = True , use_numbagg = False )
30+ elif request .param == "numbagg" :
31+ options = dict (use_bottleneck = False , use_numbagg = True )
32+ else :
33+ raise ValueError
34+
35+ with xr .set_options (** options ):
36+ yield request .param
37+
38+
2739class TestDataArrayRolling :
2840 @pytest .mark .parametrize ("da" , (1 , 2 ), indirect = True )
2941 @pytest .mark .parametrize ("center" , [True , False ])
@@ -87,9 +99,10 @@ def test_rolling_properties(self, da) -> None:
8799 @pytest .mark .parametrize ("center" , (True , False , None ))
88100 @pytest .mark .parametrize ("min_periods" , (1 , None ))
89101 @pytest .mark .parametrize ("backend" , ["numpy" ], indirect = True )
90- def test_rolling_wrapped_bottleneck (self , da , name , center , min_periods ) -> None :
102+ def test_rolling_wrapped_bottleneck (
103+ self , da , name , center , min_periods , compute_backend
104+ ) -> None :
91105 bn = pytest .importorskip ("bottleneck" , minversion = "1.1" )
92-
93106 # Test all bottleneck functions
94107 rolling_obj = da .rolling (time = 7 , min_periods = min_periods )
95108
@@ -98,15 +111,18 @@ def test_rolling_wrapped_bottleneck(self, da, name, center, min_periods) -> None
98111 expected = getattr (bn , func_name )(
99112 da .values , window = 7 , axis = 1 , min_count = min_periods
100113 )
101- assert_array_equal (actual .values , expected )
114+
115+ # Using assert_allclose because we get tiny (1e-17) differences in numbagg.
116+ np .testing .assert_allclose (actual .values , expected )
102117
103118 with pytest .warns (DeprecationWarning , match = "Reductions are applied" ):
104119 getattr (rolling_obj , name )(dim = "time" )
105120
106121 # Test center
107122 rolling_obj = da .rolling (time = 7 , center = center )
108123 actual = getattr (rolling_obj , name )()["time" ]
109- assert_equal (actual , da ["time" ])
124+ # Using assert_allclose because we get tiny (1e-17) differences in numbagg.
125+ assert_allclose (actual , da ["time" ])
110126
111127 @requires_dask
112128 @pytest .mark .parametrize ("name" , ("mean" , "count" ))
@@ -153,7 +169,9 @@ def test_rolling_wrapped_dask_nochunk(self, center) -> None:
153169 @pytest .mark .parametrize ("center" , (True , False ))
154170 @pytest .mark .parametrize ("min_periods" , (None , 1 , 2 , 3 ))
155171 @pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
156- def test_rolling_pandas_compat (self , center , window , min_periods ) -> None :
172+ def test_rolling_pandas_compat (
173+ self , center , window , min_periods , compute_backend
174+ ) -> None :
157175 s = pd .Series (np .arange (10 ))
158176 da = DataArray .from_series (s )
159177
@@ -203,7 +221,9 @@ def test_rolling_construct(self, center: bool, window: int) -> None:
203221 @pytest .mark .parametrize ("min_periods" , (None , 1 , 2 , 3 ))
204222 @pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
205223 @pytest .mark .parametrize ("name" , ("sum" , "mean" , "std" , "max" ))
206- def test_rolling_reduce (self , da , center , min_periods , window , name ) -> None :
224+ def test_rolling_reduce (
225+ self , da , center , min_periods , window , name , compute_backend
226+ ) -> None :
207227 if min_periods is not None and window < min_periods :
208228 min_periods = window
209229
@@ -223,7 +243,9 @@ def test_rolling_reduce(self, da, center, min_periods, window, name) -> None:
223243 @pytest .mark .parametrize ("min_periods" , (None , 1 , 2 , 3 ))
224244 @pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
225245 @pytest .mark .parametrize ("name" , ("sum" , "max" ))
226- def test_rolling_reduce_nonnumeric (self , center , min_periods , window , name ) -> None :
246+ def test_rolling_reduce_nonnumeric (
247+ self , center , min_periods , window , name , compute_backend
248+ ) -> None :
227249 da = DataArray (
228250 [0 , np .nan , 1 , 2 , np .nan , 3 , 4 , 5 , np .nan , 6 , 7 ], dims = "time"
229251 ).isnull ()
@@ -239,7 +261,7 @@ def test_rolling_reduce_nonnumeric(self, center, min_periods, window, name) -> N
239261 assert_allclose (actual , expected )
240262 assert actual .dims == expected .dims
241263
242- def test_rolling_count_correct (self ) -> None :
264+ def test_rolling_count_correct (self , compute_backend ) -> None :
243265 da = DataArray ([0 , np .nan , 1 , 2 , np .nan , 3 , 4 , 5 , np .nan , 6 , 7 ], dims = "time" )
244266
245267 kwargs : list [dict [str , Any ]] = [
@@ -279,7 +301,9 @@ def test_rolling_count_correct(self) -> None:
279301 @pytest .mark .parametrize ("center" , (True , False ))
280302 @pytest .mark .parametrize ("min_periods" , (None , 1 ))
281303 @pytest .mark .parametrize ("name" , ("sum" , "mean" , "max" ))
282- def test_ndrolling_reduce (self , da , center , min_periods , name ) -> None :
304+ def test_ndrolling_reduce (
305+ self , da , center , min_periods , name , compute_backend
306+ ) -> None :
283307 rolling_obj = da .rolling (time = 3 , x = 2 , center = center , min_periods = min_periods )
284308
285309 actual = getattr (rolling_obj , name )()
@@ -560,7 +584,7 @@ def test_rolling_properties(self, ds) -> None:
560584 @pytest .mark .parametrize ("key" , ("z1" , "z2" ))
561585 @pytest .mark .parametrize ("backend" , ["numpy" ], indirect = True )
562586 def test_rolling_wrapped_bottleneck (
563- self , ds , name , center , min_periods , key
587+ self , ds , name , center , min_periods , key , compute_backend
564588 ) -> None :
565589 bn = pytest .importorskip ("bottleneck" , minversion = "1.1" )
566590
@@ -577,12 +601,12 @@ def test_rolling_wrapped_bottleneck(
577601 )
578602 else :
579603 raise ValueError
580- assert_array_equal (actual [key ].values , expected )
604+ np . testing . assert_allclose (actual [key ].values , expected )
581605
582606 # Test center
583607 rolling_obj = ds .rolling (time = 7 , center = center )
584608 actual = getattr (rolling_obj , name )()["time" ]
585- assert_equal (actual , ds ["time" ])
609+ assert_allclose (actual , ds ["time" ])
586610
587611 @pytest .mark .parametrize ("center" , (True , False ))
588612 @pytest .mark .parametrize ("min_periods" , (None , 1 , 2 , 3 ))
0 commit comments