|
55 | 55 | align,
|
56 | 56 | )
|
57 | 57 | from xarray.core.arithmetic import DatasetArithmetic
|
| 58 | +from xarray.core.array_api_compat import to_like_array |
58 | 59 | from xarray.core.common import (
|
59 | 60 | DataWithCoords,
|
60 | 61 | _contains_datetime_like_objects,
|
|
127 | 128 | calculate_dimensions,
|
128 | 129 | )
|
129 | 130 | from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
|
130 |
| -from xarray.namedarray.pycompat import array_type, is_chunked_array |
| 131 | +from xarray.namedarray.pycompat import array_type, is_chunked_array, to_numpy |
131 | 132 | from xarray.plot.accessor import DatasetPlotAccessor
|
132 | 133 | from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
|
133 | 134 |
|
@@ -6620,7 +6621,7 @@ def dropna(
|
6620 | 6621 | array = self._variables[k]
|
6621 | 6622 | if dim in array.dims:
|
6622 | 6623 | dims = [d for d in array.dims if d != dim]
|
6623 |
| - count += np.asarray(array.count(dims)) |
| 6624 | + count += to_numpy(array.count(dims).data) |
6624 | 6625 | size += math.prod([self.sizes[d] for d in dims])
|
6625 | 6626 |
|
6626 | 6627 | if thresh is not None:
|
@@ -8734,16 +8735,17 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False):
|
8734 | 8735 | coord_names.add(k)
|
8735 | 8736 | else:
|
8736 | 8737 | if k in self.data_vars and dim in v.dims:
|
| 8738 | + coord_data = to_like_array(coord_var.data, like=v.data) |
8737 | 8739 | if _contains_datetime_like_objects(v):
|
8738 | 8740 | v = datetime_to_numeric(v, datetime_unit=datetime_unit)
|
8739 | 8741 | if cumulative:
|
8740 | 8742 | integ = duck_array_ops.cumulative_trapezoid(
|
8741 |
| - v.data, coord_var.data, axis=v.get_axis_num(dim) |
| 8743 | + v.data, coord_data, axis=v.get_axis_num(dim) |
8742 | 8744 | )
|
8743 | 8745 | v_dims = v.dims
|
8744 | 8746 | else:
|
8745 | 8747 | integ = duck_array_ops.trapz(
|
8746 |
| - v.data, coord_var.data, axis=v.get_axis_num(dim) |
| 8748 | + v.data, coord_data, axis=v.get_axis_num(dim) |
8747 | 8749 | )
|
8748 | 8750 | v_dims = list(v.dims)
|
8749 | 8751 | v_dims.remove(dim)
|
|
0 commit comments