Skip to content

Commit b3cd8eb

Browse files
committed
Duck array ops for all and any
1 parent 49502fc commit b3cd8eb

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

xarray/core/duck_array_ops.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
import numpy as np
1818
import pandas as pd
19-
from numpy import all as array_all # noqa: F401
20-
from numpy import any as array_any # noqa: F401
2119
from numpy import ( # noqa: F401
2220
isclose,
2321
isnat,
@@ -319,7 +317,7 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
319317
if lazy_equiv is None:
320318
with warnings.catch_warnings():
321319
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
322-
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
320+
return bool(array_all(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True)))
323321
else:
324322
return lazy_equiv
325323

@@ -333,7 +331,7 @@ def array_equiv(arr1, arr2):
333331
with warnings.catch_warnings():
334332
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
335333
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
336-
return bool(flag_array.all())
334+
return bool(array_all(flag_array))
337335
else:
338336
return lazy_equiv
339337

@@ -349,7 +347,7 @@ def array_notnull_equiv(arr1, arr2):
349347
with warnings.catch_warnings():
350348
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
351349
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
352-
return bool(flag_array.all())
350+
return bool(array_all(flag_array))
353351
else:
354352
return lazy_equiv
355353

@@ -536,6 +534,16 @@ def f(values, axis=None, skipna=None, **kwargs):
536534
cumsum_1d.numeric_only = True
537535

538536

537+
def array_all(array, axis=None, keepdims=False):
538+
xp = get_array_namespace(array)
539+
return xp.all(array, axis=axis, keepdims=keepdims)
540+
541+
542+
def array_any(array, axis=None, keepdims=False):
543+
xp = get_array_namespace(array)
544+
return xp.any(array, axis=axis, keepdims=keepdims)
545+
546+
539547
_mean = _create_nan_agg_method("mean", invariant_0d=True)
540548

541549

xarray/core/weighted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None:
171171

172172
def _weight_check(w):
173173
# Ref https://github.com/pydata/xarray/pull/4559/files#r515968670
174-
if duck_array_ops.isnull(w).any():
174+
if duck_array_ops.array_any(duck_array_ops.isnull(w)):
175175
raise ValueError(
176176
"`weights` cannot contain missing values. "
177177
"Missing values can be replaced by `weights.fillna(0)`."

0 commit comments

Comments
 (0)