Skip to content

Commit 98732e7

Browse files
dhruvak001, DHRUVA KUMAR KAUSHAL, pre-commit-ci[bot], and dcherian authored
Support rechunking to seasonal frequency with SeasonResampler (#10519)
Co-authored-by: DHRUVA KUMAR KAUSHAL <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]>
1 parent e3cd724 commit 98732e7

File tree

6 files changed

+263
-38
lines changed

6 files changed

+263
-38
lines changed

doc/api/groupby.rst

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,18 @@ Grouper Objects
7979
:toctree: ../generated/
8080

8181
groupers.BinGrouper
82-
groupers.UniqueGrouper
83-
groupers.TimeResampler
8482
groupers.SeasonGrouper
83+
groupers.UniqueGrouper
84+
85+
86+
Resampler Objects
87+
-----------------
88+
89+
.. autosummary::
90+
:toctree: ../generated/
91+
8592
groupers.SeasonResampler
93+
groupers.SeasonResampler.compute_chunks
94+
95+
groupers.TimeResampler
96+
groupers.TimeResampler.compute_chunks

doc/whats-new.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ v2025.08.1 (unreleased)
1212

1313
New Features
1414
~~~~~~~~~~~~
15+
- Support rechunking by :py:class:`~xarray.groupers.SeasonResampler` for seasonal data analysis (:issue:`10425`, :pull:`10519`).
16+
By `Dhruva Kumar Kaushal <https://github.com/dhruvak001>`_.
1517
- Add convenience methods to :py:class:`~xarray.Coordinates` (:pull:`10318`)
1618
By `Justus Magin <https://github.com/keewis>`_.
1719
- Added :py:func:`load_datatree` for loading ``DataTree`` objects into memory
@@ -157,7 +159,7 @@ Bug fixes
157159
creates extra variables that don't match the provided coordinate names, instead
158160
of silently ignoring them. The error message suggests using the factory method
159161
pattern with :py:meth:`xarray.Coordinates.from_xindex` and
160-
:py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`).
162+
:py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`, :pull:`10503`).
161163
By `Dhruva Kumar Kaushal <https://github.com/dhruvak001>`_.
162164

163165
Documentation

xarray/core/dataset.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2486,13 +2486,16 @@ def chunk(
24862486
sizes along that dimension will not be updated; non-dask arrays will be
24872487
converted into dask arrays with a single block.
24882488
2489-
Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.
2489+
Along datetime-like dimensions, a :py:class:`Resampler` object
2490+
(e.g. :py:class:`groupers.TimeResampler` or :py:class:`groupers.SeasonResampler`)
2491+
is also accepted.
24902492
24912493
Parameters
24922494
----------
2493-
chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional
2495+
chunks : int, tuple of int, "auto" or mapping of hashable to int or a Resampler, optional
24942496
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
2495-
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``.
2497+
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}`` or
2498+
``{"time": SeasonResampler(["DJF", "MAM", "JJA", "SON"])}``.
24962499
name_prefix : str, default: "xarray-"
24972500
Prefix for the name of any new dask arrays.
24982501
token : str, optional
@@ -2527,8 +2530,7 @@ def chunk(
25272530
xarray.unify_chunks
25282531
dask.array.from_array
25292532
"""
2530-
from xarray.core.dataarray import DataArray
2531-
from xarray.groupers import TimeResampler
2533+
from xarray.groupers import Resampler
25322534

25332535
if chunks is None and not chunks_kwargs:
25342536
warnings.warn(
@@ -2556,41 +2558,29 @@ def chunk(
25562558
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
25572559
)
25582560

2559-
def _resolve_frequency(
2560-
name: Hashable, resampler: TimeResampler
2561-
) -> tuple[int, ...]:
2561+
def _resolve_resampler(name: Hashable, resampler: Resampler) -> tuple[int, ...]:
25622562
variable = self._variables.get(name, None)
25632563
if variable is None:
25642564
raise ValueError(
2565-
f"Cannot chunk by resampler {resampler!r} for virtual variables."
2565+
f"Cannot chunk by resampler {resampler!r} for virtual variable {name!r}."
25662566
)
2567-
elif not _contains_datetime_like_objects(variable):
2567+
if variable.ndim != 1:
25682568
raise ValueError(
2569-
f"chunks={resampler!r} only supported for datetime variables. "
2570-
f"Received variable {name!r} with dtype {variable.dtype!r} instead."
2569+
f"chunks={resampler!r} only supported for 1D variables. "
2570+
f"Received variable {name!r} with {variable.ndim} dimensions instead."
25712571
)
2572-
2573-
assert variable.ndim == 1
2574-
chunks = (
2575-
DataArray(
2576-
np.ones(variable.shape, dtype=int),
2577-
dims=(name,),
2578-
coords={name: variable},
2572+
newchunks = resampler.compute_chunks(variable, dim=name)
2573+
if sum(newchunks) != variable.shape[0]:
2574+
raise ValueError(
2575+
f"Logic bug in rechunking variable {name!r} using {resampler!r}. "
2576+
"New chunks tuple does not match size of data. Please open an issue."
25792577
)
2580-
.resample({name: resampler})
2581-
.sum()
2582-
)
2583-
# When bins (binning) or time periods are missing (resampling)
2584-
# we can end up with NaNs. Drop them.
2585-
if chunks.dtype.kind == "f":
2586-
chunks = chunks.dropna(name).astype(int)
2587-
chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
2588-
return chunks_tuple
2578+
return newchunks
25892579

25902580
chunks_mapping_ints: Mapping[Any, T_ChunkDim] = {
25912581
name: (
2592-
_resolve_frequency(name, chunks)
2593-
if isinstance(chunks, TimeResampler)
2582+
_resolve_resampler(name, chunks)
2583+
if isinstance(chunks, Resampler)
25942584
else chunks
25952585
)
25962586
for name, chunks in chunks_mapping.items()

xarray/core/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from xarray.core.indexes import Index, Indexes
3333
from xarray.core.utils import Frozen
3434
from xarray.core.variable import IndexVariable, Variable
35-
from xarray.groupers import Grouper, TimeResampler
35+
from xarray.groupers import Grouper, Resampler
3636
from xarray.structure.alignment import Aligner
3737

3838
GroupInput: TypeAlias = (
@@ -201,7 +201,7 @@ def copy(
201201
# FYI in some cases we don't allow `None`, which this doesn't take account of.
202202
# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
203203
T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None # noqa: PYI051
204-
T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
204+
T_ChunkDimFreq: TypeAlias = Union["Resampler", T_ChunkDim]
205205
T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq]
206206
# We allow the tuple form of this (though arguably we could transition to named dims only)
207207
T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]

xarray/groupers.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import operator
1313
from abc import ABC, abstractmethod
1414
from collections import defaultdict
15-
from collections.abc import Mapping, Sequence
15+
from collections.abc import Hashable, Mapping, Sequence
1616
from dataclasses import dataclass, field
1717
from itertools import chain, pairwise
1818
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -52,6 +52,8 @@
5252
"EncodedGroups",
5353
"Grouper",
5454
"Resampler",
55+
"SeasonGrouper",
56+
"SeasonResampler",
5557
"TimeResampler",
5658
"UniqueGrouper",
5759
]
@@ -169,7 +171,26 @@ class Resampler(Grouper):
169171
Currently used for TimeResampler and SeasonResampler, but could be used for SpaceResampler in the future.
170172
"""
171173

172-
pass
174+
def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
175+
"""
176+
Compute chunk sizes for this resampler.
177+
178+
This method should be implemented by subclasses to provide appropriate
179+
chunking behavior for their specific resampling strategy.
180+
181+
Parameters
182+
----------
183+
variable : Variable
184+
The variable being chunked.
185+
dim : Hashable
186+
The name of the dimension being chunked.
187+
188+
Returns
189+
-------
190+
tuple[int, ...]
191+
A tuple of chunk sizes for the dimension.
192+
"""
193+
raise NotImplementedError("Subclasses must implement compute_chunks method")
173194

174195

175196
@dataclass
@@ -565,6 +586,49 @@ def factorize(self, group: T_Group) -> EncodedGroups:
565586
coords=coordinates_from_variable(unique_coord),
566587
)
567588

589+
def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
590+
"""
591+
Compute chunk sizes for this time resampler.
592+
593+
This method is used during chunking operations to determine appropriate
594+
chunk sizes for the given variable when using this resampler.
595+
596+
Parameters
597+
----------
598+
dim : Hashable
599+
The name of the dimension being chunked.
600+
variable : Variable
601+
The variable being chunked.
602+
603+
Returns
604+
-------
605+
tuple[int, ...]
606+
A tuple of chunk sizes for the dimension.
607+
"""
608+
from xarray.core.dataarray import DataArray
609+
610+
if not _contains_datetime_like_objects(variable):
611+
raise ValueError(
612+
f"Computing chunks with {type(self)!r} only supported for datetime variables. "
613+
f"Received variable with dtype {variable.dtype!r} instead."
614+
)
615+
616+
chunks = (
617+
DataArray(
618+
np.ones(variable.shape, dtype=int),
619+
dims=(dim,),
620+
coords={dim: variable},
621+
)
622+
.resample({dim: self})
623+
.sum()
624+
)
625+
# When bins (binning) or time periods are missing (resampling)
626+
# we can end up with NaNs. Drop them.
627+
if chunks.dtype.kind == "f":
628+
chunks = chunks.dropna(dim).astype(int)
629+
chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
630+
return chunks_tuple
631+
568632

569633
def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray:
570634
# Copied from flox
@@ -967,5 +1031,58 @@ def get_label(year, season):
9671031

9681032
return EncodedGroups(codes=codes, full_index=full_index)
9691033

1034+
def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
1035+
"""
1036+
Compute chunk sizes for this season resampler.
1037+
1038+
This method is used during chunking operations to determine appropriate
1039+
chunk sizes for the given variable when using this resampler.
1040+
1041+
Parameters
1042+
----------
1043+
dim : Hashable
1044+
The name of the dimension being chunked.
1045+
variable : Variable
1046+
The variable being chunked.
1047+
1048+
Returns
1049+
-------
1050+
tuple[int, ...]
1051+
A tuple of chunk sizes for the dimension.
1052+
"""
1053+
from xarray.core.dataarray import DataArray
1054+
1055+
if not _contains_datetime_like_objects(variable):
1056+
raise ValueError(
1057+
f"Computing chunks with {type(self)!r} only supported for datetime variables. "
1058+
f"Received variable with dtype {variable.dtype!r} instead."
1059+
)
1060+
1061+
if len("".join(self.seasons)) != 12:
1062+
raise ValueError(
1063+
"Cannot rechunk with a SeasonResampler that does not cover all 12 months. "
1064+
f"Received `seasons={self.seasons!r}`."
1065+
)
1066+
1067+
# Create a temporary resampler that ignores drop_incomplete for chunking
1068+
# This prevents data from being silently dropped during chunking
1069+
resampler_for_chunking = type(self)(seasons=self.seasons, drop_incomplete=False)
1070+
1071+
chunks = (
1072+
DataArray(
1073+
np.ones(variable.shape, dtype=int),
1074+
dims=(dim,),
1075+
coords={dim: variable},
1076+
)
1077+
.resample({dim: resampler_for_chunking})
1078+
.sum()
1079+
)
1080+
# When bins (binning) or time periods are missing (resampling)
1081+
# we can end up with NaNs. Drop them.
1082+
if chunks.dtype.kind == "f":
1083+
chunks = chunks.dropna(dim).astype(int)
1084+
chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
1085+
return chunks_tuple
1086+
9701087
def reset(self) -> Self:
9711088
return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete)

0 commit comments

Comments
 (0)