Add test that checks the integral (w.r.t. threshold) of DebiasedBrierScore = CRPS (on average).

PiperOrigin-RevId: 718869486
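For reference, the identity the new test exercises (the standard threshold decomposition of CRPS, stated here as a sketch rather than taken from the library's docs): for a forecast CDF F and observation y,

\mathrm{CRPS}(F, y) = \int_{-\infty}^{\infty} \big(F(\tau) - \mathbf{1}\{y \le \tau\}\big)^2 \, d\tau = \int_{-\infty}^{\infty} \mathrm{BS}_\tau(F, y) \, d\tau,

where \mathrm{BS}_\tau is the Brier score for the threshold event \{y \le \tau\}. The test below checks the finite-ensemble, debiased versions of the two scores against this identity numerically, hence "= CRPS (on average)".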
langmore authored and Weatherbench2 authors committed Jan 23, 2025
1 parent 2f849ab commit 06db45b
Showing 1 changed file with 91 additions and 2 deletions.
93 changes: 91 additions & 2 deletions weatherbench2/metrics_test.py
@@ -26,7 +26,12 @@


 def get_random_truth_and_forecast(
-    variables=('geopotential',), ensemble_size=None, seed=802701, **data_kwargs
+    variables=('geopotential',),
+    ensemble_size=None,
+    seed=802701,
+    lead_start='0 day',
+    lead_stop='10 day',
+    **data_kwargs,
 ):
   """Makes the tuple (truth, forecast) from kwargs."""
   data_kwargs_to_use = dict(
@@ -43,7 +48,10 @@ def get_random_truth_and_forecast(
   )
   forecast = utils.random_like(
       schema.mock_forecast_data(
-          ensemble_size=ensemble_size, **data_kwargs_to_use
+          ensemble_size=ensemble_size,
+          lead_start=lead_start,
+          lead_stop=lead_stop,
+          **data_kwargs_to_use,
       ),
       seed=seed + 1,
   )
@@ -1147,6 +1155,87 @@ def test_versus_large_ensemble_and_ensure_skipna_works(self):
atol=4 * stderr,
)

def test_integral_of_brier_score_is_crps(self):
# The integral over threshold of the debiased Brier score is the unbiased CRPS.
truth, forecast = get_random_truth_and_forecast(
ensemble_size=2,
spatial_resolution_in_degrees=60,
time_start='2019-01-01',
time_stop='2019-12-31',
time_resolution='12 hours',
lead_start='0 day',
lead_stop='0 day',
levels=[500, 700, 850],
)

# Make forecasts (i) have a different mean/variance than truth, and
# (ii) depend on level.
forecast = (
forecast
+ np.abs(forecast) ** 0.2
+ xr.DataArray(
[-1, 0, 1], dims=['level'], coords={'level': forecast.level.data}
)
)

# Climatology has the same stats as Normal(0, 1), so truth/forecast should
# be "perfect".
climatology_mean = xr.zeros_like(
truth.isel(time=0, drop=True).expand_dims(dayofyear=366)
)
climatology_std = xr.ones_like(
truth.isel(time=0, drop=True)
.expand_dims(
dayofyear=366,
)
.rename({'geopotential': 'geopotential_std'})
)
climatology = xr.merge([climatology_mean, climatology_std])
quantiles = np.linspace(0.005, 0.995, num=200)
threshold_objects = [
thresholds.GaussianQuantileThreshold(
climatology=climatology, quantile=q
)
for q in quantiles
]
bs = metrics.DebiasedEnsembleBrierScore(threshold_objects).compute(
forecast, truth
)['geopotential']

# Now integrate BS with respect to the threshold. To do that, we first
# build a DataArray of thresholds corresponding to the quantiles.
precip_thresholds = []
for q, thresh in zip(quantiles, threshold_objects):
t = thresh.compute(truth)['geopotential']
# To simplify integration, we ensured threshold depends only on level.
# This "assert_array_less" checks that we did this correctly.
np.testing.assert_array_less(
t.std(['time', 'longitude', 'latitude']), 1e-4
)
precip_thresholds.append(
t.isel(time=0, longitude=0, latitude=0, drop=True).expand_dims(
quantile=[q]
)
)
precip_thresholds = xr.concat(precip_thresholds, dim='quantile')

# Second, do the integral, one level at a time.
bs = bs.assign_coords(threshold=precip_thresholds)
integrals = []
for level in bs.level:
integrals.append(bs.sel(level=level).integrate('threshold'))
bs_integral = xr.concat(integrals, dim='level')

crps = metrics.CRPS().compute(forecast, truth)['geopotential']

# Tolerance is due to (i) finite samples, and (ii) integration error. The
# integration error is going to be tiny, due to using 200 points to
# interpolate a function that we know is bounded to ≈ [-5, 5].
stderr = 1 / np.sqrt(
np.prod([v for k, v in truth.sizes.items() if k != 'level'])
)
xr.testing.assert_allclose(bs_integral, crps, atol=4 * stderr)


class EnsembleIgnoranceScoreTest(parameterized.TestCase):

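A minimal standalone sketch of the identity the test above exercises, using plain NumPy rather than weatherbench2's API. The fair Brier score and fair CRPS formulas below are the standard finite-ensemble estimators and are an assumption about what DebiasedEnsembleBrierScore and CRPS compute; names such as fair_brier are illustrative only.

import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000  # number of (truth, ensemble forecast) pairs
m = 2               # ensemble size, matching the test above

truth = rng.normal(size=n_samples)
ensemble = rng.normal(size=(n_samples, m))


def fair_brier(tau):
  """Fair (debiased) Brier score at threshold tau, averaged over samples."""
  p_hat = (ensemble > tau).mean(axis=1)    # ensemble exceedance probability
  obs = (truth > tau).astype(float)        # binary outcome
  raw = (p_hat - obs) ** 2
  debias = p_hat * (1 - p_hat) / (m - 1)   # removes the finite-ensemble bias
  return np.mean(raw - debias)


# Fair CRPS: mean |x_i - y| minus the pairwise ensemble-spread term.
abs_err = np.abs(ensemble - truth[:, None]).mean(axis=1)
pairwise = np.abs(ensemble[:, :, None] - ensemble[:, None, :]).sum(axis=(1, 2))
crps = np.mean(abs_err - pairwise / (2 * m * (m - 1)))

# Integrate the fair Brier score over thresholds with the trapezoid rule.
taus = np.linspace(-5.0, 5.0, 201)
bs = np.array([fair_brier(t) for t in taus])
bs_integral = np.sum(0.5 * (bs[1:] + bs[:-1]) * np.diff(taus))

# The two numbers should agree closely; any residual difference comes from
# the finite threshold grid and the truncation of the integral to [-5, 5].
print(f'integral of fair BS = {bs_integral:.4f}, fair CRPS = {crps:.4f}')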
