
Commit 3bc51bd: Add GroupBy.shuffle()
Parent: e2981d3

File tree: 4 files changed, +81 -1 lines


xarray/core/duck_array_ops.py (+15)

@@ -831,3 +831,18 @@ def chunked_nanfirst(darray, axis):
 
 def chunked_nanlast(darray, axis):
     return _chunked_first_or_last(darray, axis, op=nputils.nanlast)
+
+
+def shuffle_array(array, indices: list[list[int]], axis: int):
+    # TODO: do chunk manager dance here.
+    if is_duck_dask_array(array):
+        if not module_available("dask", minversion="2024.08.0"):
+            raise ValueError(
+                "This method is very inefficient on dask<2024.08.0. Please upgrade."
+            )
+        # TODO: handle dimensions
+        return array.shuffle(indexer=indices, axis=axis)
+    else:
+        indexer = np.concatenate(indices)
+        # TODO: Do the array API thing here.
+        return np.take(array, indices=indexer, axis=axis)
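For intuition, the non-dask branch is just a concatenate-and-gather. A minimal standalone sketch of that path (the array values and group indices below are made up for illustration):

import numpy as np

array = np.array([10, 20, 30, 40, 50, 60])
indices = [[0, 3], [1, 4], [2, 5]]   # positional indices of each group's members

# mirrors the else branch of shuffle_array above
indexer = np.concatenate(indices)    # array([0, 3, 1, 4, 2, 5])
shuffled = np.take(array, indices=indexer, axis=0)
# shuffled is array([10, 40, 20, 50, 30, 60]): each group is now contiguous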

xarray/core/groupby.py (+48)

@@ -517,6 +517,54 @@ def sizes(self) -> Mapping[Hashable, int]:
             self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes
 
+    def shuffle(self) -> None:
+        """
+        Shuffle the underlying object so that all members in a group occur sequentially.
+
+        The order of appearance is not guaranteed. This method modifies the underlying Xarray
+        object in place.
+
+        Use this method first if you need to map a function that requires all members of a group
+        be in a single chunk.
+        """
+        from xarray.core.dataarray import DataArray
+        from xarray.core.dataset import Dataset
+        from xarray.core.duck_array_ops import shuffle_array
+
+        (grouper,) = self.groupers
+        dim = self._group_dim
+
+        # Slices mean this is already sorted. E.g. resampling ops, _DummyGroup
+        if all(isinstance(idx, slice) for idx in self._group_indices):
+            return
+
+        was_array = isinstance(self._obj, DataArray)
+        as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
+
+        shuffled = Dataset()
+        for name, var in as_dataset._variables.items():
+            if dim not in var.dims:
+                shuffled[name] = var
+                continue
+            shuffled_data = shuffle_array(
+                var._data, list(self._group_indices), axis=var.get_axis_num(dim)
+            )
+            shuffled[name] = var._replace(data=shuffled_data)
+
+        # Replace self._group_indices with slices
+        slices = []
+        start = 0
+        for idxr in self._group_indices:
+            slices.append(slice(start, start + len(idxr)))
+            start += len(idxr)
+        # TODO: we have now broken the invariant
+        # self._group_indices ≠ self.groupers[0].group_indices
+        self._group_indices = tuple(slices)
+        if was_array:
+            self._obj = self._obj._from_temp_dataset(shuffled)
+        else:
+            self._obj = shuffled
+
     def map(
         self,
         func: Callable,
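The new method reorders every variable that contains the grouped dimension and then records the groups as contiguous slices. A hedged usage sketch of the intended workflow (the data, coordinate name, and mapped function are hypothetical; the dask-backed path requires dask >= 2024.08.0 per the guard in shuffle_array):

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(12.0),
    dims="time",
    coords={"month": ("time", np.tile([1, 2, 3], 4))},
).chunk(time=3)

gb = da.groupby("month")
gb.shuffle()  # in place: all members of each month now occur sequentially
anom = gb.map(lambda x: x - x.mean())  # mapped function sees whole groups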

xarray/tests/__init__.py (+1)

@@ -106,6 +106,7 @@ def _importorskip(
 has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
 has_cftime, requires_cftime = _importorskip("cftime")
 has_dask, requires_dask = _importorskip("dask")
+has_dask_ge_2024_08_0, _ = _importorskip("dask", minversion="2024.08.0")
 with warnings.catch_warnings():
     warnings.filterwarnings(
         "ignore",

xarray/tests/test_groupby.py (+17 -1)

@@ -21,6 +21,7 @@
     assert_identical,
     create_test_data,
     has_cftime,
+    has_dask_ge_2024_08_0,
     has_flox,
     requires_cftime,
     requires_dask,
@@ -1293,11 +1294,26 @@ def test_groupby_sum(self) -> None:
         assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y"))
         assert_allclose(expected_sum_axis1, grouped.sum("y"))
 
+    @pytest.mark.parametrize(
+        "shuffle",
+        [
+            pytest.param(
+                True,
+                marks=pytest.mark.skipif(
+                    not has_dask_ge_2024_08_0, reason="dask too old"
+                ),
+            ),
+            False,
+        ],
+    )
     @pytest.mark.parametrize("method", ["sum", "mean", "median"])
-    def test_groupby_reductions(self, method) -> None:
+    def test_groupby_reductions(self, method: str, shuffle: bool) -> None:
         array = self.da
         grouped = array.groupby("abc")
 
+        if shuffle:
+            grouped.shuffle()
+
         reduction = getattr(np, method)
         expected = Dataset(
             {
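The shuffle=True case of this test boils down to a round-trip property: shuffling only reorders the underlying data, so group reductions should be unchanged by it. A standalone sketch of that check with made-up, numpy-backed data (no dask required):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0), dims="x", coords={"label": ("x", list("abcabc"))})
expected = da.groupby("label").sum()

gb = da.groupby("label")
gb.shuffle()       # reorder so each label's members are contiguous
actual = gb.sum()  # the reduction result should be unaffected
xr.testing.assert_identical(expected, actual)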
