Skip to content

Commit bba476d

Browse files
langmoreWeatherbench2 authors
authored andcommitted
LongitudeScheme options added to regridding.
BUGFIX: Previous regridding did not output the right dimension values More tests added, including accuracy tests. PiperOrigin-RevId: 736568976
1 parent 66b236b commit bba476d

File tree

6 files changed

+803
-92
lines changed

6 files changed

+803
-92
lines changed

scripts/compute_probabilistic_climatological_forecasts.py

Lines changed: 167 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,36 @@
3030
2. For each ti, create a forecast (indexed by "prediction_timedelta") comprised
3131
of the historical weather starting at ti.
3232
33-
Each "ti" is selected to be a perturbation of the output init time T as
33+
Each source time "ti" is a perturbation of the output init time T:
3434
3535
* t.minute = T.minute
3636
* t.hour = T.hour
3737
* t.year ~ Uniform({CLIMATOLOGY_START_YEAR,..., CLIMATOLOGY_END_YEAR})
3838
* t.day = (T.day + δ) % [days in t.year], where the day offset δ is uniform:
3939
δ ~ Uniform(-DAY_WINDOW_SIZE // 2, DAY_WINDOW_SIZE // 2) + DAY_WINDOW_SIZE % 2
4040
41+
The (T.day + δ) % [days in t.year] step is the default behavior indicated by
42+
INITIAL_TIME_EDGE_BEHAVIOR=WRAP_YEAR. This is needed to ensure every year and
43+
dayofyear is sampled with/without replacement. If instead,
44+
INITIAL_TIME_EDGE_BEHAVIOR=REFLECT_RANGE, then t.day = T.day + δ, except at
45+
the climatology start/end boundary, where T.day is reflected back into bounds.
46+
47+
By default, every initial time has its day and year sampled independently. This
48+
means every single forecast could come from an entirely different season.
49+
SAMPLE_HOLD_DAYS provides the ability to alter this behavior, by making each
50+
realization fix the number of days between the output time and source time,
51+
(T - t).days, for SAMPLE_HOLD_DAYS days in a row. After SAMPLE_HOLD_DAYS days,
52+
each realization selects (independently) a new (T - t).days. This option is most
53+
useful when used with INITIAL_TIME_EDGE_BEHAVIOR=REFLECT_RANGE. In that case,
54+
55+
* SAMPLE_HOLD_DAYS=365 means realizations each come from a random season (year
56+
and time of year), and this random season is changed only once every 365 days.
57+
This emulates a forecast model that may or may not have the seasonal trends
58+
(e.g. ENSO) correct.
59+
* SAMPLE_HOLD_DAYS=50 emulates forecast model that may or may not have the
60+
subseasonal trends correct.
61+
62+
4163
Example Usage:
4264
4365
```
@@ -141,6 +163,32 @@
141163
' hour.'
142164
),
143165
)
166+
SAMPLE_HOLD_DAYS = flags.DEFINE_integer(
167+
'sample_hold_days',
168+
0,
169+
help=(
170+
'Non-negative multiple of INITIAL_TIME_SPACING. 0 means no hold. '
171+
'If nonzero, the total days perturbation is constant for'
172+
' this time. Warning: The "hold" means observations must'
173+
' be available SAMPLE_HOLD_DAYS days further in to the future.'
174+
),
175+
)
176+
WRAP_YEAR = 'WRAP_YEAR'
177+
REFLECT_RANGE = 'REFLECT_RANGE'
178+
INITIAL_TIME_EDGE_BEHAVIOR = flags.DEFINE_enum(
179+
'initial_time_edge_behavior',
180+
WRAP_YEAR,
181+
enum_values=[WRAP_YEAR, REFLECT_RANGE],
182+
help=(
183+
'What to do when the day perturbation would select a time before or'
184+
f' after the sampled year. "{WRAP_YEAR}" means e.g. YYYY-12-31 + 5 days'
185+
f' becomes YYYY-01-05, for ever year YYYY. "{REFLECT_RANGE}" means we'
186+
' reflect the perturbation, but only only do this at the start/end'
187+
' year. So, if start/end is 1990, 2000, then 2000-12-31 + 5 days ='
188+
' 2000-12-31 - 5 days = 2000-12-26, but 1995-12-31 + 5 days ='
189+
' 1996-01-05.'
190+
),
191+
)
144192
FORECAST_DURATION = flags.DEFINE_string(
145193
'forecast_duration', '15 days', help='Length of forecasts.'
146194
)
@@ -169,7 +217,9 @@
169217
DAY_WINDOW_SIZE = flags.DEFINE_integer(
170218
'day_window_size',
171219
10,
172-
help='Width of window (in days) to take samples from.',
220+
help=(
221+
'Width of window (in days) to take samples from. Must be in [0, 2*364].'
222+
),
173223
)
174224
ENSEMBLE_SIZE = flags.DEFINE_integer(
175225
'ensemble_size',
@@ -179,13 +229,17 @@
179229
' the same as ensemble_size = "number of possible day perturbations" x'
180230
' "number of possible years." If WITH_REPLACEMENT=False as well, this'
181231
' means every possible day and year combination will be used exactly'
182-
' once.'
232+
f' once (if INITIAL_TIME_EDGE_BEHAVIOR="{WRAP_YEAR}").'
183233
),
184234
)
185235
WITH_REPLACEMENT = flags.DEFINE_boolean(
186236
'with_replacement',
187237
True,
188-
help='Whether sampling is done with or without replacement.',
238+
help=(
239+
'Whether sampling is done with or without replacement. Warning: If'
240+
f' INITIAL_TIME_EDGE_BEHAVIOR="{REFLECT_RANGE}", then some samples may'
241+
' be repeated near the climatological boundary.'
242+
),
189243
)
190244
SEED = flags.DEFINE_integer(
191245
'seed', 802701, help='Seed for the random number generator.'
@@ -307,6 +361,8 @@ def _get_sampled_init_times(
307361
day_window_size: int,
308362
ensemble_size: int,
309363
with_replacement: bool,
364+
sample_hold_days: int,
365+
initial_time_edge_behavior: str,
310366
seed: int,
311367
) -> np.ndarray:
312368
"""For each output time, get the times to sample from observations.
@@ -333,13 +389,22 @@ def _get_sampled_init_times(
333389
day_window_size: Size of window, in dayofyear, to grab samples.
334390
ensemble_size: Number of samples (per init time) to grab.
335391
with_replacement: Whether to sample with or without replacement.
392+
sample_hold_days: How long consecutive initial times use the same
393+
perturbation. 0 means switch perturbations every consecutive init time.
394+
initial_time_edge_behavior: How to deal with perturbations that move the
395+
sampled day outside of sampled year.
336396
seed: Integer seed for the RNG.
337397
338398
Returns:
339399
Shape [ensemble_size, len(output_times)] array of np.datetime64[ns].
340400
"""
341401
rng = np.random.default_rng(seed)
342402

403+
if day_window_size > 2 * 364:
404+
# This complicates the REFLECT_RANGE behavior, and no sensible human would
405+
# want this.
406+
raise ValueError(f'{day_window_size=} > 2 * 364, which is not allowed.')
407+
343408
# The scheme below samples uniformly over initial day (ignoring leap years).
344409
# Conceptually, think of each climatology year as a circle. The days
345410
# [0, ..., 365] with 0 and 365 (or 366) connected. This sampler
@@ -417,16 +482,36 @@ def _get_sampled_init_times(
417482
)
418483
# End of get sampled years and day_perturbations.
419484

420-
# If output_times is near the start or end of the year, we want the
421-
# perturbation to wrap around and find a date within the same year.
422485
dayofyears = output_times.dayofyear.values + day_perturbations
423-
for year in range(climatology_start_year, climatology_end_year + 1):
424-
mask = years == year
425-
dayofyears[mask] = (dayofyears[mask] - 1) % (
426-
365 + calendar.isleap(year)
427-
) + 1
428486

429-
return (
487+
if initial_time_edge_behavior == WRAP_YEAR:
488+
for year in range(climatology_start_year, climatology_end_year + 1):
489+
mask = years == year
490+
days_in_this_year = 365 + calendar.isleap(year)
491+
dayofyears[mask] = (dayofyears[mask] - 1) % days_in_this_year + 1
492+
493+
elif initial_time_edge_behavior == REFLECT_RANGE:
494+
for year in {climatology_start_year, climatology_end_year}:
495+
mask = years == year
496+
days_in_this_year = 365 + calendar.isleap(year)
497+
if year == climatology_start_year:
498+
# Transform e.g. 1 --> 1, 0 --> 2, -1 --> 3
499+
dayofyears[mask] = np.where(
500+
dayofyears[mask] >= 1,
501+
dayofyears[mask],
502+
np.abs(dayofyears[mask]) + 2,
503+
)
504+
elif year == climatology_end_year:
505+
dayofyears[mask] = np.where(
506+
dayofyears[mask] <= days_in_this_year,
507+
dayofyears[mask],
508+
# If d > 365, set to 2*365 - d = 365 - (d - 365)
509+
2 * days_in_this_year - dayofyears[mask],
510+
)
511+
else:
512+
raise ValueError(f'Unhandled {initial_time_edge_behavior=}')
513+
514+
sampled_times = (
430515
# Years is always defined in years since the epoch.
431516
np.array(years - 1970, dtype='datetime64[Y]')
432517
# Add daysofyears - 1 to year, since e.g. if dayofyear = 1, then we will
@@ -435,6 +520,56 @@ def _get_sampled_init_times(
435520
+ np.array(output_times.hour, dtype='timedelta64[h]')
436521
).astype('datetime64[ns]')
437522

523+
if sample_hold_days:
524+
output_time_strides = set(output_times.diff()[1:])
525+
if len(output_time_strides) > 1:
526+
raise ValueError(
527+
f'Cannot sample hold with more than one {output_time_strides=}'
528+
)
529+
output_time_stride = output_time_strides.pop()
530+
hold_dt = pd.Timedelta(f'{sample_hold_days}d')
531+
hold_stride = hold_dt // output_time_stride
532+
if output_time_stride * hold_stride != hold_dt:
533+
raise ValueError(
534+
f'{sample_hold_days=} was not a multiple of {output_time_stride=}'
535+
)
536+
hold_idx = np.repeat(
537+
# E.g. hold_idx = [0, 0, ..., 0, 1, 1, ..., 1, 2, ...]
538+
np.arange(len(output_times) // hold_stride + 1)[:, np.newaxis],
539+
hold_stride,
540+
axis=1,
541+
).ravel()[: len(output_times)]
542+
543+
# Convert np datetimes into δ days, sample-hold, then add back to datetimes.
544+
delta_days = np.array(
545+
pd.to_timedelta((sampled_times - output_times.values).ravel()).days,
546+
dtype=np.int64,
547+
).reshape(sampled_times.shape)
548+
549+
delta_days = np.take(delta_days, hold_idx, axis=1)
550+
sampled_times = output_times.values + np.array(
551+
delta_days, dtype='timedelta64[D]'
552+
)
553+
554+
return sampled_times
555+
556+
557+
def _check_times_in_dataset(
558+
times: np.ndarray | pd.DatetimeIndex, ds: xr.Dataset, user_err: bool
559+
) -> None:
560+
"""Checks that `times` are in `ds` and gives a nice error if not."""
561+
missing_times = pd.to_datetime(times).difference(ds[TIME_DIM.value])
562+
if missing_times.size and user_err:
563+
raise flags.ValidationError(
564+
'Time flags (CLIMATOLOGY_START_YEAR, CLIMATOLOGY_END_YEAR,'
565+
' TIMEDELTA_SPACING, SAMPLE_HOLD_DAYS) asked for values in INPUT that '
566+
f'are not available. {missing_times=}.'
567+
)
568+
elif missing_times.size and not user_err:
569+
raise AssertionError(
570+
f'Calculation of times needed is wrong. File a bug! {missing_times=}'
571+
)
572+
438573

439574
def _check_input_spacing_and_time_flags(input_ds: xr.Dataset) -> None:
440575
"""Validates input spacing, TIMEDELTA_SPACING, and INITIAL_TIME_SPACING."""
@@ -544,8 +679,13 @@ def main(argv: abc.Sequence[str]) -> None:
544679
)
545680

546681
# Select all needed samples from INPUT.
547-
time_buffer = pd.to_timedelta(FORECAST_DURATION.value) + pd.to_timedelta(
548-
f'{DAY_WINDOW_SIZE.value}d'
682+
max_num_leapyears = np.ceil(
683+
CLIMATOLOGY_END_YEAR.value - CLIMATOLOGY_START_YEAR.value
684+
)
685+
time_buffer = (
686+
pd.to_timedelta(FORECAST_DURATION.value)
687+
+ pd.to_timedelta(f'{DAY_WINDOW_SIZE.value}d')
688+
+ pd.to_timedelta(f'{SAMPLE_HOLD_DAYS.value + int(max_num_leapyears)}d')
549689
)
550690
sample_spacing = min(
551691
ONE_DAY,
@@ -557,13 +697,7 @@ def main(argv: abc.Sequence[str]) -> None:
557697
pd.to_datetime(f'{CLIMATOLOGY_END_YEAR.value}-12-31') + time_buffer,
558698
freq=sample_spacing,
559699
)
560-
missing_times = times_needed_for_sampling.difference(input_ds[TIME_DIM.value])
561-
if missing_times.size:
562-
raise flags.ValidationError(
563-
'Time flags (CLIMATOLOGY_START_YEAR, CLIMATOLOGY_END_YEAR,'
564-
' TIMEDELTA_SPACING) asked for values in INPUT that are not available.'
565-
f' {missing_times=}.'
566-
)
700+
_check_times_in_dataset(times_needed_for_sampling, input_ds, user_err=True)
567701
input_ds = input_ds.sel({TIME_DIM.value: times_needed_for_sampling})
568702

569703
# Define output times and the template.
@@ -608,20 +742,24 @@ def main(argv: abc.Sequence[str]) -> None:
608742
# _get_sampled_init_times returns shape [ensemble_size, n_times] array of
609743
# np.datetime64. These are use as initial times for output samples.
610744
# Ravel it into shape [ensemble_size * n_times] set of times.
611-
output_init_times,
612-
CLIMATOLOGY_START_YEAR.value,
613-
CLIMATOLOGY_END_YEAR.value,
614-
DAY_WINDOW_SIZE.value,
615-
ensemble_size,
616-
WITH_REPLACEMENT.value,
617-
SEED.value,
745+
output_times=output_init_times,
746+
climatology_start_year=CLIMATOLOGY_START_YEAR.value,
747+
climatology_end_year=CLIMATOLOGY_END_YEAR.value,
748+
day_window_size=DAY_WINDOW_SIZE.value,
749+
ensemble_size=ensemble_size,
750+
with_replacement=WITH_REPLACEMENT.value,
751+
initial_time_edge_behavior=INITIAL_TIME_EDGE_BEHAVIOR.value,
752+
sample_hold_days=SAMPLE_HOLD_DAYS.value,
753+
seed=SEED.value,
618754
).ravel()
619755

620756
def sampled_times_for_timedelta(timedelta: pd.Timedelta) -> np.ndarray:
621757
"""Times to grab from input for forecasts at this timedelta."""
622758
# Simply add the timedelta to the sampled_init_times, ensuring the forecasts
623759
# are continuous in time.
624-
return sampled_init_times + timedelta.to_numpy()
760+
times = sampled_init_times + timedelta.to_numpy()
761+
_check_times_in_dataset(times.ravel(), input_ds, user_err=False)
762+
return times
625763

626764
# init_time_offsets[i] is the (init_time, realization) offset to use with
627765
# sampled_init_times[i].

0 commit comments

Comments
 (0)