Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize state diff stats in mdp datastore. Change variable name. #122

Merged
merged 10 commits on
Feb 13, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fix bug where the inverse_softplus used in clamping caused nans in the gradients [\#123](https://github.com/mllam/neural-lam/pull/123) @SimonKamuk

- Add standardization to state diff stats from mdp datastore [\#122](https://github.com/mllam/neural-lam/pull/122) @SimonKamuk

### Maintenance
- update ci/cd testing setup to install torch version compatible with neural-lam
dependencies [\#115](https://github.com/mllam/neural-lam/pull/115), @leifdenby
Expand Down
13 changes: 7 additions & 6 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,13 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
Return the standardization (i.e. scaling to mean of 0.0 and standard
deviation of 1.0) dataarray for the given category. This should contain
a `{category}_mean` and `{category}_std` variable for each variable in
the category. For `category=="state"`, the dataarray should also
contain a `state_diff_mean` and `state_diff_std` variable for the one-
step differences of the state variables. The returned dataarray should
at least have dimensions of `({category}_feature)`, but can also
include for example `grid_index` (if the standardization is done per
grid point for example).
the category.
For `category=="state"`, the dataarray should also contain a
`state_diff_mean_standardized` and `state_diff_std_standardized`
variable for the one-step differences of the state variables.
The returned dataarray should at least have dimensions of
`({category}_feature)`, but can also include for example `grid_index`
(if the standardization is done per grid point for example).

Parameters
----------
Expand Down
22 changes: 16 additions & 6 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,10 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
Return the standardization dataarray for the given category. This
should contain a `{category}_mean` and `{category}_std` variable for
each variable in the category. For `category=="state"`, the dataarray
should also contain a `state_diff_mean` and `state_diff_std` variable
for the one- step differences of the state variables.
each variable in the category.
For `category=="state"`, the dataarray should also contain a
`state_diff_mean_standardized` and `state_diff_std_standardized`
variable for the one-step differences of the state variables.

Parameters
----------
Expand All @@ -321,12 +322,21 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
stats_variables = {
f"{category}__{split}__{op}": f"{category}_{op}" for op in ops
}

ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)

# Add standardized state diff stats
if category == "state":
stats_variables.update(
{f"state__{split}__diff_{op}": f"state_diff_{op}" for op in ops}
ds_stats = ds_stats.assign(
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved
**{
f"state_diff_{op}_standardized": self._ds[
f"state__{split}__diff_{op}"
]
/ ds_stats["state_std"]
for op in ops
}
)

ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)
return ds_stats

@cached_property
Expand Down
17 changes: 12 additions & 5 deletions neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,9 +706,10 @@ def boundary_mask(self) -> xr.DataArray:
def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""Return the standardization dataarray for the given category. This
should contain a `{category}_mean` and `{category}_std` variable for
each variable in the category. For `category=="state"`, the dataarray
should also contain a `state_diff_mean` and `state_diff_std` variable
for the one- step differences of the state variables.
each variable in the category.
For `category=="state"`, the dataarray should also contain a
`state_diff_mean_standardized` and `state_diff_std_standardized`
variable for the one-step differences of the state variables.

Parameters
----------
Expand Down Expand Up @@ -769,8 +770,14 @@ def load_pickled_tensor(fn):
}

if mean_diff_values is not None and std_diff_values is not None:
variables["state_diff_mean"] = (feature_dim_name, mean_diff_values)
variables["state_diff_std"] = (feature_dim_name, std_diff_values)
variables["state_diff_mean_standardized"] = (
feature_dim_name,
mean_diff_values,
)
variables["state_diff_std_standardized"] = (
feature_dim_name,
std_diff_values,
)

ds_norm = xr.Dataset(
variables,
Expand Down
8 changes: 6 additions & 2 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,15 @@ def __init__(
"state_std": torch.tensor(
da_state_stats.state_std.values, dtype=torch.float32
),
# Note that the one-step-diff stats (diff_mean and diff_std) are
# for differences computed on standardized data
"diff_mean": torch.tensor(
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved
da_state_stats.state_diff_mean.values, dtype=torch.float32
da_state_stats.state_diff_mean_standardized.values,
dtype=torch.float32,
),
"diff_std": torch.tensor(
da_state_stats.state_diff_std.values, dtype=torch.float32
da_state_stats.state_diff_std_standardized.values,
dtype=torch.float32,
),
}

Expand Down
15 changes: 8 additions & 7 deletions tests/dummy_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,13 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
Return the standardization (i.e. scaling to mean of 0.0 and standard
deviation of 1.0) dataarray for the given category. This should contain
a `{category}_mean` and `{category}_std` variable for each variable in
the category. For `category=="state"`, the dataarray should also
contain a `state_diff_mean` and `state_diff_std` variable for the one-
step differences of the state variables. The returned dataarray should
at least have dimensions of `({category}_feature)`, but can also
include for example `grid_index` (if the standardization is done per
grid point for example).
the category.
For `category=="state"`, the dataarray should also contain a
`state_diff_mean_standardized` and `state_diff_std_standardized`
variable for the one-step differences of the state variables.
The returned dataarray should at least have dimensions of
`({category}_feature)`, but can also include for example `grid_index`
(if the standardization is done per grid point for example).

Parameters
----------
Expand All @@ -292,7 +293,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:

ops = ["mean", "std"]
if category == "state":
ops += ["diff_mean", "diff_std"]
ops += ["diff_mean_standardized", "diff_std_standardized"]

for op in ops:
da_op = xr.ones_like(self.ds[f"{category}_feature"]).astype(float)
Expand Down
9 changes: 7 additions & 2 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_get_vars(datastore_name):

@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_normalization_dataarray(datastore_name):
"""Check that the `datastore.get_normalization_dataa rray` method is
"""Check that the `datastore.get_normalization_dataarray` method is
implemented."""
datastore = init_datastore_example(datastore_name)

Expand All @@ -144,7 +144,12 @@ def test_get_normalization_dataarray(datastore_name):
assert isinstance(ds_stats, xr.Dataset)

if category == "state":
ops = ["mean", "std", "diff_mean", "diff_std"]
ops = [
"mean",
"std",
"diff_mean_standardized",
"diff_std_standardized",
]
elif category == "forcing":
ops = ["mean", "std"]
elif category == "static":
Expand Down
Loading