Skip to content

Commit 762df72

Browse files
Merge pull request #253 from PolymathicAI/merge_public_expedite_internel
Merge public expedite internel
2 parents f0c9082 + c290d58 commit 762df72

File tree

6 files changed

+199
-11
lines changed

6 files changed

+199
-11
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
name: Bug Report
2+
description: Create a report to help us reproduce and fix the bug
3+
title: "[Bug]: "
4+
labels: ["bug"]
5+
6+
body:
7+
- type: markdown
8+
attributes:
9+
value: >
10+
Thank you for taking the time to file a bug report.
11+
Before creating a new issue, you can have a quick look to the [FAQ](https://github.com/PolymathicAI/the_well/discussions/categories/q-a?discussions_q=category%3AQ%26A+) and [existing issues](https://github.com/PolymathicAI/the_well/issues).
12+
- type: textarea
13+
attributes:
14+
label: "Describe the issue:"
15+
placeholder: |
16+
<< your issue description here >>
17+
validations:
18+
required: true
19+
- type: textarea
20+
attributes:
21+
label: "Code to reproduce the issue:"
22+
description: >
23+
A short code example that reproduces the problem/missing feature.
24+
It should be self-contained.
25+
placeholder: |
26+
<< your code here >>
27+
render: python
28+
validations:
29+
required: true
30+
- type: textarea
31+
attributes:
32+
label: "Version"
33+
description: |
34+
Which version of the Well are you using?
35+
You can obtain the version by running the following command:
36+
```sh
37+
python -c "import the_well; print(the_well.__version__)"
38+
```
39+
placeholder: |
40+
<< your version here >>
41+
validations:
42+
required: true
43+
- type: textarea
44+
attributes:
45+
label: "Environment"
46+
description: |
47+
Which environment are you using? List the packages you have installed along the Well.
48+
In case you use pip, you can obtain the list of installed packages by running the following command:
49+
```sh
50+
pip freeze
51+
```
52+
placeholder: |
53+
<< your environment here >>
54+
validations:
55+
required: true
56+
- type: textarea
57+
attributes:
58+
label: "Context for the issue:"
59+
description: |
60+
Please explain how this issue affects your intended use of the Well.
61+
You can also provide additional context that you think might be relevant.
62+
placeholder: |
63+
<< your explanation here >>
64+
validations:
65+
required: false
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../datasets/rayleigh_benard_uniform/README.md

tests/data/test_normalization.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import math
2+
3+
import torch
4+
5+
from the_well.data.normalization import RMSNormalization, ZScoreNormalization
6+
7+
8+
def test_zscore_normalization():
9+
"""Test the ZScoreNormalization actually provides the correct normalization.
10+
We consider fields whose mean and std are given by a linear function of the field index.
11+
12+
"""
13+
n_fields = 4
14+
h = 64
15+
w = 64
16+
t = 10
17+
batch_size = 64
18+
tol = 1e-2
19+
std = torch.arange(n_fields) + 1.0
20+
mean = torch.arange(n_fields)
21+
delta_mean = torch.zeros_like(mean)
22+
delta_std = math.sqrt(2) * std
23+
stats = {
24+
"mean": {f"field_{i}": mean[i] for i in range(n_fields)},
25+
"std": {f"field_{i}": std[i] for i in range(n_fields)},
26+
"mean_delta": {f"field_{i}": delta_mean[i] for i in range(n_fields)},
27+
"std_delta": {f"field_{i}": delta_std[i] for i in range(n_fields)},
28+
}
29+
normalization = ZScoreNormalization(
30+
stats=stats,
31+
core_field_names=[f"field_{i}" for i in range(n_fields)],
32+
core_constant_field_names=[],
33+
)
34+
35+
input_tensor = std * torch.randn(batch_size, h, w, t, n_fields) + mean
36+
delta_input_tensor = input_tensor[..., 1:, :] - input_tensor[..., :-1, :]
37+
for i in range(n_fields):
38+
normalized_tensor = normalization.normalize(input_tensor[..., i], f"field_{i}")
39+
assert normalized_tensor.shape == (batch_size, h, w, t)
40+
assert torch.allclose(
41+
torch.mean(normalized_tensor), torch.tensor(0.0), atol=tol
42+
)
43+
assert torch.allclose(torch.std(normalized_tensor), torch.tensor(1.0), atol=tol)
44+
45+
normalized_delta_tensor = normalization.delta_normalize(
46+
delta_input_tensor[..., i], f"field_{i}"
47+
)
48+
assert normalized_delta_tensor.shape == (batch_size, h, w, t - 1)
49+
assert torch.allclose(
50+
torch.mean(normalized_delta_tensor), torch.tensor(0.0), atol=tol
51+
)
52+
assert torch.allclose(
53+
torch.std(normalized_delta_tensor), torch.tensor(1.0), atol=tol
54+
)
55+
56+
57+
def test_rms_normalization():
58+
"""Test the RMSNormalization actually provides the correct normalization.
59+
We consider fields whose mean and std are given by a linear function of the field index.
60+
"""
61+
n_fields = 4
62+
h = 64
63+
w = 64
64+
t = 10
65+
batch_size = 64
66+
tol = 1e-2
67+
std = torch.arange(n_fields) + 1.0
68+
mean = torch.arange(n_fields)
69+
delta_std = math.sqrt(2) * std
70+
stats = {
71+
"rms": {f"field_{i}": std[i] for i in range(n_fields)},
72+
"rms_delta": {f"field_{i}": delta_std[i] for i in range(n_fields)},
73+
}
74+
normalization = RMSNormalization(
75+
stats=stats,
76+
core_field_names=[f"field_{i}" for i in range(n_fields)],
77+
core_constant_field_names=[],
78+
)
79+
80+
input_tensor = std * torch.randn(batch_size, h, w, t, n_fields) + mean
81+
delta_input_tensor = input_tensor[..., 1:, :] - input_tensor[..., :-1, :]
82+
for i in range(n_fields):
83+
normalized_tensor = normalization.normalize(input_tensor[..., i], f"field_{i}")
84+
assert normalized_tensor.shape == (batch_size, h, w, t)
85+
assert torch.allclose(
86+
torch.mean(normalized_tensor),
87+
mean[i].float() / std[i].float(),
88+
atol=tol,
89+
)
90+
assert torch.allclose(
91+
torch.std(normalized_tensor),
92+
torch.tensor(1.0),
93+
atol=tol,
94+
)
95+
96+
normalized_delta_tensor = normalization.delta_normalize(
97+
delta_input_tensor[..., i], f"field_{i}"
98+
)
99+
assert normalized_delta_tensor.shape == (batch_size, h, w, t - 1)
100+
assert torch.allclose(
101+
torch.mean(normalized_delta_tensor), torch.tensor(0.0), atol=tol
102+
)
103+
assert torch.allclose(
104+
torch.std(normalized_delta_tensor), torch.tensor(1.0), atol=tol
105+
)

the_well/benchmark/trainer/training.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,11 @@ def validation_loop(
312312
loss_dict = {}
313313
time_logs = {}
314314
count = 0
315-
denom = len(dataloader) if full else self.short_validation_length
315+
denom = (
316+
len(dataloader)
317+
if full
318+
else min(self.short_validation_length, len(dataloader))
319+
)
316320
with torch.autocast(
317321
self.device.type, enabled=self.enable_amp, dtype=self.amp_type
318322
):
@@ -398,7 +402,7 @@ def train_one_epoch(self, epoch: int, dataloader: DataLoader) -> float:
398402
backward_time = time.time() - batch_start - forward_time - batch_time
399403
total_time = time.time() - batch_start
400404
logger.info(
401-
f"Epoch {epoch}, Batch {i+1}/{len(dataloader)}: loss {loss.item()}, total_time {total_time}, batch time {batch_time}, forward time {forward_time}, backward time {backward_time}"
405+
f"Epoch {epoch}, Batch {i + 1}/{len(dataloader)}: loss {loss.item()}, total_time {total_time}, batch time {batch_time}, forward time {forward_time}, backward time {backward_time}"
402406
)
403407
batch_start = time.time()
404408
train_logs["time_per_train_iter"] = (time.time() - start_time) / len(dataloader)
@@ -458,6 +462,7 @@ def train(self):
458462
self.save_model(
459463
epoch, val_loss, os.path.join(self.checkpoint_folder, "best.pt")
460464
)
465+
self.best_val_loss = val_loss
461466
# Check if time for expensive validation - periodic or final
462467
if epoch % self.rollout_val_frequency == 0 or (epoch == self.max_epoch):
463468
logger.info(

the_well/benchmark/utils/experiment_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,17 @@ def configure_experiment(
9898
)
9999
folder_path = osp.join(experiment_folder, "extended_config.yaml")
100100
if osp.isfile(checkpoint_path):
101-
logger.info(f"Config file exists relative to checkpoint override provided, \
102-
using config file {checkpoint_path}")
101+
logger.info(
102+
f"Config file exists relative to checkpoint override provided, \
103+
using config file {checkpoint_path}"
104+
)
103105
elif osp.isfile(folder_path):
104-
logger.warn(f"Config file not found in checkpoint override path. \
106+
logger.warn(
107+
f"Config file not found in checkpoint override path. \
105108
Found in experiment folder, using config file {folder_path}. \
106109
This could lead to weight compatibility issues if the checkpoints do not align with \
107-
the specified folder.")
110+
the specified folder."
111+
)
108112
else:
109113
logger.warn(
110114
"Checkpoint override provided, but config file not found in checkpoint override path \

the_well/data/datasets.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ class WellDataset(Dataset):
136136
Whether to normalize data in the dataset
137137
normlization_type:
138138
What type of dataset normalization. Callable Options: ZSCORE and RMS
139+
max_rollout_steps:
140+
Maximum number of output steps to return in a single sample. Return the full trajectory if larger than its actual length.
139141
n_steps_input:
140142
Number of steps to include in each sample
141143
n_steps_output:
@@ -583,7 +585,9 @@ def _pad_axes(
583585
expand_dims = expand_dims + (1,) * tensor_order
584586
return torch.tile(field_data, expand_dims)
585587

586-
def _reconstruct_fields(self, file, cache, sample_idx, time_idx, n_steps, dt):
588+
def _reconstruct_fields(
589+
self, file: h5.File, cache, sample_idx, time_idx, n_steps, dt
590+
):
587591
"""Reconstruct space fields starting at index sample_idx, time_idx, with
588592
n_steps and dt stride."""
589593
variable_fields = {0: {}, 1: {}, 2: {}}
@@ -634,7 +638,9 @@ def _reconstruct_fields(self, file, cache, sample_idx, time_idx, n_steps, dt):
634638

635639
return (variable_fields, constant_fields)
636640

637-
def _reconstruct_scalars(self, file, cache, sample_idx, time_idx, n_steps, dt):
641+
def _reconstruct_scalars(
642+
self, file: h5.File, cache, sample_idx, time_idx, n_steps, dt
643+
):
638644
"""Reconstruct scalar values (not fields) starting at index sample_idx, time_idx, with
639645
n_steps and dt stride."""
640646
variable_scalars = {}
@@ -670,7 +676,9 @@ def _reconstruct_scalars(self, file, cache, sample_idx, time_idx, n_steps, dt):
670676

671677
return (variable_scalars, constant_scalars)
672678

673-
def _reconstruct_grids(self, file, cache, sample_idx, time_idx, n_steps, dt):
679+
def _reconstruct_grids(
680+
self, file: h5.File, cache, sample_idx, time_idx, n_steps, dt
681+
):
674682
"""Reconstruct grid values starting at index sample_idx, time_idx, with
675683
n_steps and dt stride."""
676684
# Time
@@ -705,7 +713,7 @@ def _reconstruct_grids(self, file, cache, sample_idx, time_idx, n_steps, dt):
705713
self._check_cache(cache, "space_grid", space_grid)
706714
return space_grid, time_grid
707715

708-
def _padding_bcs(self, file, cache, sample_idx, time_idx, n_steps, dt):
716+
def _padding_bcs(self, file: h5.File, cache, sample_idx, time_idx, n_steps, dt):
709717
"""Handles BC case where BC corresponds to a specific padding type
710718
711719
Note/TODO - currently assumes boundaries to be axis-aligned and cover the entire
@@ -753,7 +761,7 @@ def _padding_bcs(self, file, cache, sample_idx, time_idx, n_steps, dt):
753761
self._check_cache(cache, "boundary_output", boundary_output)
754762
return boundary_output
755763

756-
def _reconstruct_bcs(self, file, cache, sample_idx, time_idx, n_steps, dt):
764+
def _reconstruct_bcs(self, file: h5.File, cache, sample_idx, time_idx, n_steps, dt):
757765
"""Needs work to support arbitrary BCs.
758766
759767
Currently supports finite set of boundary condition types that describe

0 commit comments

Comments
 (0)