|
12 | 12 | import operator
|
13 | 13 | from abc import ABC, abstractmethod
|
14 | 14 | from collections import defaultdict
|
15 |
| -from collections.abc import Mapping, Sequence |
| 15 | +from collections.abc import Hashable, Mapping, Sequence |
16 | 16 | from dataclasses import dataclass, field
|
17 | 17 | from itertools import chain, pairwise
|
18 | 18 | from typing import TYPE_CHECKING, Any, Literal, cast
|
|
52 | 52 | "EncodedGroups",
|
53 | 53 | "Grouper",
|
54 | 54 | "Resampler",
|
| 55 | + "SeasonGrouper", |
| 56 | + "SeasonResampler", |
55 | 57 | "TimeResampler",
|
56 | 58 | "UniqueGrouper",
|
57 | 59 | ]
|
@@ -169,7 +171,26 @@ class Resampler(Grouper):
|
169 | 171 | Currently only used for TimeResampler, but could be used for SpaceResampler in the future.
|
170 | 172 | """
|
171 | 173 |
|
172 |
| - pass |
def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
    """
    Return the chunk sizes along ``dim`` implied by this resampler.

    This base implementation is a placeholder that always raises;
    subclasses are expected to supply the chunking behavior appropriate
    to their own resampling strategy.

    Parameters
    ----------
    variable : Variable
        The variable being chunked.
    dim : Hashable
        The name of the dimension being chunked.

    Returns
    -------
    tuple[int, ...]
        A tuple of chunk sizes for the dimension.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    message = "Subclasses must implement compute_chunks method"
    raise NotImplementedError(message)
173 | 194 |
|
174 | 195 |
|
175 | 196 | @dataclass
|
@@ -565,6 +586,49 @@ def factorize(self, group: T_Group) -> EncodedGroups:
|
565 | 586 | coords=coordinates_from_variable(unique_coord),
|
566 | 587 | )
|
567 | 588 |
|
def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
    """
    Compute chunk sizes for this time resampler.

    This method is used during chunking operations to determine appropriate
    chunk sizes for the given variable when using this resampler. Each chunk
    size is the number of elements of ``variable`` that fall into one
    resampling period.

    Parameters
    ----------
    variable : Variable
        The variable being chunked. Must contain datetime-like values.
    dim : Hashable
        The name of the dimension being chunked.

    Returns
    -------
    tuple[int, ...]
        A tuple of chunk sizes for the dimension.

    Raises
    ------
    ValueError
        If ``variable`` does not contain datetime-like objects.
    """
    from xarray.core.dataarray import DataArray

    if not _contains_datetime_like_objects(variable):
        raise ValueError(
            f"Computing chunks with {type(self)!r} only supported for datetime variables. "
            f"Received variable with dtype {variable.dtype!r} instead."
        )

    # Resample an array of ones indexed by `variable`: the per-period sum is
    # the number of elements in that period, i.e. the chunk size.
    counts = (
        DataArray(
            np.ones(variable.shape, dtype=int),
            dims=(dim,),
            coords={dim: variable},
        )
        .resample({dim: self})
        .sum()
    )
    # When time periods are missing from the data, the resampled result can
    # contain NaNs (and is then float-typed). Drop them and restore ints.
    if counts.dtype.kind == "f":
        counts = counts.dropna(dim).astype(int)
    chunks_tuple: tuple[int, ...] = tuple(counts.data.tolist())
    return chunks_tuple
| 631 | + |
568 | 632 |
|
569 | 633 | def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray:
|
570 | 634 | # Copied from flox
|
@@ -967,5 +1031,58 @@ def get_label(year, season):
|
967 | 1031 |
|
968 | 1032 | return EncodedGroups(codes=codes, full_index=full_index)
|
969 | 1033 |
|
def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
    """
    Compute chunk sizes for this season resampler.

    This method is used during chunking operations to determine appropriate
    chunk sizes for the given variable when using this resampler. Each chunk
    size is the number of elements of ``variable`` that fall into one
    season. ``drop_incomplete`` is ignored here so that no data is dropped
    while chunking.

    Parameters
    ----------
    variable : Variable
        The variable being chunked. Must contain datetime-like values.
    dim : Hashable
        The name of the dimension being chunked.

    Returns
    -------
    tuple[int, ...]
        A tuple of chunk sizes for the dimension.

    Raises
    ------
    ValueError
        If ``variable`` does not contain datetime-like objects, or if the
        season labels in ``self.seasons`` do not add up to 12 month
        initials in total.
    """
    from xarray.core.dataarray import DataArray

    if not _contains_datetime_like_objects(variable):
        raise ValueError(
            f"Computing chunks with {type(self)!r} only supported for datetime variables. "
            f"Received variable with dtype {variable.dtype!r} instead."
        )

    # Each season label is a string of month initials, so the joined length
    # is the total number of months the seasons span.
    if len("".join(self.seasons)) != 12:
        raise ValueError(
            "Cannot rechunk with a SeasonResampler that does not cover all 12 months. "
            f"Received `seasons={self.seasons!r}`."
        )

    # Create a temporary resampler that ignores drop_incomplete for chunking.
    # This prevents data from being silently dropped during chunking.
    resampler_for_chunking = type(self)(seasons=self.seasons, drop_incomplete=False)

    # Resample an array of ones indexed by `variable`: the per-season sum is
    # the number of elements in that season, i.e. the chunk size.
    counts = (
        DataArray(
            np.ones(variable.shape, dtype=int),
            dims=(dim,),
            coords={dim: variable},
        )
        .resample({dim: resampler_for_chunking})
        .sum()
    )
    # When time periods are missing from the data, the resampled result can
    # contain NaNs (and is then float-typed). Drop them and restore ints.
    if counts.dtype.kind == "f":
        counts = counts.dropna(dim).astype(int)
    chunks_tuple: tuple[int, ...] = tuple(counts.data.tolist())
    return chunks_tuple
| 1086 | + |
def reset(self) -> Self:
    """
    Return a new instance of the same type, built from this object's
    ``seasons`` and ``drop_incomplete`` values.
    """
    cls = type(self)
    return cls(seasons=self.seasons, drop_incomplete=self.drop_incomplete)
|
0 commit comments