forked from pydata/xarray

Commit a15b04d

Handle multiple groupers
1 parent dfdc96a

3 files changed (+39 -13 lines):
  xarray/core/groupby.py
  xarray/groupers.py
  xarray/tests/test_groupby.py

xarray/core/groupby.py (+17 -10)

@@ -22,6 +22,7 @@
 from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
 from xarray.core.concat import concat
 from xarray.core.coordinates import Coordinates
+from xarray.core.duck_array_ops import where
 from xarray.core.formatting import format_array_flat
 from xarray.core.indexes import (
     PandasIndex,
@@ -462,20 +463,26 @@ def factorize(self) -> EncodedGroups:
         # NaNs; as well as values outside the bins are coded by -1
         # Restore these after the raveling
         mask = functools.reduce(np.logical_or, [(code == -1) for code in broadcasted_codes])  # type: ignore[arg-type]
-        _flatcodes[mask] = -1
-
-        midx = pd.MultiIndex.from_product(
-            (grouper.unique_coord.data for grouper in groupers),
-            names=tuple(grouper.name for grouper in groupers),
-        )
-        # Constructing an index from the product is wrong when there are missing groups
-        # (e.g. binning, resampling). Account for that now.
-        midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
+        _flatcodes = where(mask, -1, _flatcodes)
 
         full_index = pd.MultiIndex.from_product(
             (grouper.full_index.values for grouper in groupers),
             names=tuple(grouper.name for grouper in groupers),
         )
+        # This will be unused when grouping by dask arrays, so skip..
+        if not is_chunked_array(_flatcodes):
+            midx = pd.MultiIndex.from_product(
+                (grouper.unique_coord.data for grouper in groupers),
+                names=tuple(grouper.name for grouper in groupers),
+            )
+            # Constructing an index from the product is wrong when there are missing groups
+            # (e.g. binning, resampling). Account for that now.
+            midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
+            group_indices = _codes_to_group_indices(_flatcodes.ravel(), len(full_index))
+        else:
+            midx = full_index
+            group_indices = None
+
         dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers)
 
         coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name)
@@ -484,7 +491,7 @@ def factorize(self) -> EncodedGroups:
         return EncodedGroups(
             codes=first_codes.copy(data=_flatcodes),
             full_index=full_index,
-            group_indices=_codes_to_group_indices(_flatcodes.ravel(), len(full_index)),
+            group_indices=group_indices,
             unique_coord=Variable(dims=(dim_name,), data=midx.values),
             coords=coords,
         )
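
Not part of the commit: a minimal standalone sketch of what the eager branch above does, using plain NumPy/pandas. The index labels and code values here are made up for illustration; the real code derives them from the groupers.

    import numpy as np
    import pandas as pd

    # Full product of two groupers' labels, as pd.MultiIndex.from_product builds it.
    full_index = pd.MultiIndex.from_product([["a", "b"], [0, 1, 2]], names=["x", "y"])

    # Flattened integer codes into that product; -1 marks NaNs / out-of-bin values.
    flatcodes = np.array([0, 2, 2, -1, 5])
    mask = flatcodes == -1

    # Keep only the combinations that actually occur, mirroring
    # midx = midx[np.sort(pd.unique(_flatcodes[~mask]))] in the diff.
    observed = np.sort(pd.unique(flatcodes[~mask]))
    midx = full_index[observed]
    print(midx.tolist())  # [('a', 0), ('a', 2), ('b', 2)]

When the codes are a chunked (e.g. dask) array, the pd.unique call would force evaluation, which is why the diff skips this branch and falls back to the full product index with group_indices=None.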

xarray/groupers.py (+4 -1)

@@ -191,7 +191,10 @@ def factorize(self, group: T_Group) -> EncodedGroups:
         self.group = group
 
         if is_chunked_array(group.data) and self.labels is None:
-            raise ValueError("When grouping by a dask array, `labels` must be passed.")
+            raise ValueError(
+                "When grouping by a dask array, `labels` must be passed using "
+                "a UniqueGrouper object."
+            )
         if self.labels is not None:
             return self._factorize_given_labels(group)
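
Not part of the commit: a hedged sketch of the error path above. It assumes dask is installed and an xarray version that includes this grouper API; the variable and coordinate names are illustrative.

    import numpy as np
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    da = xr.DataArray(
        np.arange(4),
        dims="x",
        coords={"letters": ("x", ["a", "a", "b", "b"])},
    )
    da["letters"] = da["letters"].chunk()  # the group variable is now dask-backed

    try:
        da.groupby(letters=UniqueGrouper())  # no labels: cannot factorize lazily
    except ValueError as err:
        print(err)

    # Pre-declaring the expected labels avoids the error and keeps factorization lazy.
    gb = da.groupby(letters=UniqueGrouper(labels=["a", "b"]))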

xarray/tests/test_groupby.py (+18 -2)

@@ -22,13 +22,15 @@
     TimeResampler,
     UniqueGrouper,
 )
+from xarray.namedarray.pycompat import is_chunked_array
 from xarray.tests import (
     InaccessibleArray,
     assert_allclose,
     assert_equal,
     assert_identical,
     create_test_data,
     has_cftime,
+    has_dask,
     has_flox,
     has_pandas_ge_2_2,
     raise_if_dask_computes,
@@ -2796,7 +2798,7 @@ def test_multiple_groupers(use_flox) -> None:
 
     b = xr.DataArray(
         np.random.RandomState(0).randn(2, 3, 4),
-        coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])},
+        coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})},
         dims=["x", "y", "z"],
     )
     gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
@@ -2813,10 +2815,24 @@ def test_multiple_groupers(use_flox) -> None:
     expected.loc[dict(x=1, xy=1)] = expected.sel(x=1, xy=0).data
     expected.loc[dict(x=1, xy=0)] = np.nan
     expected.loc[dict(x=1, xy=2)] = newval
-    expected["xy"] = ("xy", ["a", "b", "c"])
+    expected["xy"] = ("xy", ["a", "b", "c"], {"foo": "bar"})
     # TODO: is order of dims correct?
     assert_identical(actual, expected.transpose("z", "x", "xy"))
 
+    if has_dask:
+        b["xy"] = b["xy"].chunk()
+        with raise_if_dask_computes():
+            gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"]))
+
+        expected = xr.DataArray(
+            [[[1, 1, 1], [0, 1, 2]]] * 4,
+            dims=("z", "x", "xy"),
+            coords={"xy": ("xy", ["a", "b", "c"], {"foo": "bar"})},
+        )
+        assert_identical(gb.count(), expected)
+        assert is_chunked_array(gb.encoded.codes.data)
+        assert not gb.encoded.group_indices
+
 
 @pytest.mark.parametrize("use_flox", [True, False])
 def test_multiple_groupers_mixed(use_flox) -> None:
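
Not part of the commit: the scenario exercised by the new test, rewritten as a usage sketch. It assumes dask is installed and a version of xarray containing this change; the random data is only there to give the array some values.

    import numpy as np
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    b = xr.DataArray(
        np.random.RandomState(0).randn(2, 3, 4),
        coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})},
        dims=["x", "y", "z"],
    )
    b["xy"] = b["xy"].chunk()  # the "xy" group variable becomes dask-backed

    # With labels pre-declared, factorizing the chunked grouper triggers no compute.
    gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"]))

    print(gb.count())                # counts per (x, xy) combination
    print(gb.encoded.group_indices)  # None: not materialized for chunked codes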
