Skip to content

Commit

Permalink
Use lazy weights for climate_statistics and axis_statistics (#2346)
Browse files Browse the repository at this point in the history
  • Loading branch information
schlunma authored Feb 26, 2024
1 parent bb7866e commit c9a5982
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
10 changes: 7 additions & 3 deletions esmvalcore/preprocessor/_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,10 @@ def get_time_weights(cube: Cube) -> np.ndarray | da.core.Array:
Returns
-------
np.ndarray or da.core.Array
Array of time weights for averaging.
np.ndarray or da.Array
Array of time weights for averaging. Returns a
:class:`dask.array.Array` if the input cube has lazy data; a
:class:`numpy.ndarray` otherwise.
"""
time = cube.coord('time')
Expand All @@ -408,7 +410,9 @@ def get_time_weights(cube: Cube) -> np.ndarray | da.core.Array:
)

# Extract 1D time weights (= lengths of time intervals)
time_weights = time.core_bounds()[:, 1] - time.core_bounds()[:, 0]
time_weights = time.lazy_bounds()[:, 1] - time.lazy_bounds()[:, 0]
if not cube.has_lazy_data():
time_weights = time_weights.compute()
return time_weights


Expand Down
5 changes: 4 additions & 1 deletion esmvalcore/preprocessor/_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,11 @@ def axis_statistics(

def _add_axis_stats_weights_coord(cube, coord, coord_dims):
"""Add weights for axis_statistics to cube (in-place)."""
weights = np.abs(coord.lazy_bounds()[:, 1] - coord.lazy_bounds()[:, 0])
if not cube.has_lazy_data():
weights = weights.compute()
weights_coord = AuxCoord(
np.abs(coord.core_bounds()[..., 1] - coord.core_bounds()[..., 0]),
weights,
long_name='_axis_statistics_weights_',
units=coord.units,
)
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/preprocessor/_time/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,17 @@ def test_get_time_weights():
"""Test ``get_time_weights`` for complex cube."""
cube = _make_cube()
weights = get_time_weights(cube)
assert isinstance(weights, np.ndarray)
assert weights.shape == (2, )
np.testing.assert_allclose(weights, [15.0, 30.0])


def test_get_time_weights_lazy():
"""Test ``get_time_weights`` for complex cube with lazy data."""
cube = _make_cube()
cube.data = cube.lazy_data()
weights = get_time_weights(cube)
assert isinstance(weights, da.Array)
assert weights.shape == (2, )
np.testing.assert_allclose(weights, [15.0, 30.0])

Expand Down
24 changes: 24 additions & 0 deletions tests/unit/preprocessor/_volume/test_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import tests
from esmvalcore.preprocessor._volume import (
_add_axis_stats_weights_coord,
axis_statistics,
calculate_volume,
depth_integration,
Expand Down Expand Up @@ -106,6 +107,29 @@ def setUp(self):
iris.util.guess_coord_axis(self.grid_4d_2.coord('zcoord'))
iris.util.guess_coord_axis(self.grid_4d_z.coord('zcoord'))

def test_add_axis_stats_weights_coord(self):
"""Test _add_axis_stats_weights_coord."""
assert not self.grid_4d.coords('_axis_statistics_weights_')
coord = self.grid_4d.coord('zcoord')
coord_dims = self.grid_4d.coord_dims('zcoord')
_add_axis_stats_weights_coord(self.grid_4d, coord, coord_dims)
weights_coord = self.grid_4d.coord('_axis_statistics_weights_')
assert not weights_coord.has_lazy_points()
assert weights_coord.units == 'm'
np.testing.assert_allclose(weights_coord.points, [2.5, 22.5, 225.0])

def test_add_axis_stats_weights_coord_lazy(self):
"""Test _add_axis_stats_weights_coord."""
self.grid_4d.data = self.grid_4d.lazy_data()
assert not self.grid_4d.coords('_axis_statistics_weights_')
coord = self.grid_4d.coord('zcoord')
coord_dims = self.grid_4d.coord_dims('zcoord')
_add_axis_stats_weights_coord(self.grid_4d, coord, coord_dims)
weights_coord = self.grid_4d.coord('_axis_statistics_weights_')
assert weights_coord.has_lazy_points()
assert weights_coord.units == 'm'
np.testing.assert_allclose(weights_coord.points, [2.5, 22.5, 225.0])

def test_axis_statistics_mean(self):
"""Test axis statistics with operator mean."""
data = np.ma.arange(1, 25).reshape(2, 3, 2, 2)
Expand Down

0 comments on commit c9a5982

Please sign in to comment.