Skip to content

Commit

Permalink
Adapt to AR_predictions changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ghiggi committed Aug 5, 2021
1 parent ae55ff7 commit b85757a
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 9 deletions.
3 changes: 2 additions & 1 deletion train_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ def main(cfg_path, exp_dir, data_dir):
da_dynamic = da_test_dynamic,
da_static = da_static,
da_bc = da_test_bc,
scaler = scaler,
scaler_transform = scaler,
scaler_inverse = scaler,
# Dataloader options
device = device,
batch_size = 50, # number of forecasts per batch
Expand Down
3 changes: 2 additions & 1 deletion train_state_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,8 @@ def main(cfg_path, exp_dir, data_dir, nb_models):
da_dynamic = da_test_dynamic,
da_static = da_static,
da_bc = da_test_bc,
scaler = scaler,
scaler_transform = scaler,
scaler_inverse = scaler,
# Dataloader options
device = device,
batch_size = 50, # number of forecasts per batch
Expand Down
3 changes: 2 additions & 1 deletion train_state_increment.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ def main(cfg_path, exp_dir, data_dir):
da_dynamic = da_test_dynamic,
da_static = da_static,
da_bc = da_test_bc,
scaler = scaler,
scaler_transform = scaler,
scaler_inverse = scaler,
# Dataloader options
device = device,
batch_size = 50, # number of forecasts per batch
Expand Down
3 changes: 2 additions & 1 deletion train_state_short.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ def main(cfg_path, exp_dir, data_dir):
da_dynamic = da_test_dynamic,
da_static = da_static,
da_bc = da_test_bc,
scaler = scaler,
scaler_transform = scaler,
scaler_inverse = scaler,
# Dataloader options
device = device,
batch_size = 50, # number of forecasts per batch
Expand Down
11 changes: 6 additions & 5 deletions train_zarr_100km.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,13 @@ def main(cfg_path, exp_dir, data_dir):
print_model_description(cfg)
print_dim_info(dim_info)

from modules.utils_config import get_pytorch_model
model = get_pytorch_model(module_with_custom_models = my_architectures,
model_settings = model_settings,
training_settings = training_settings)
##------------------------------------------------------------------------.
### Define the model architecture
# - TODO: improve with utils_config.get_pytorch_model
# from modules.utils_config import get_pytorch_model
# model = get_pytorch_model(module_with_custom_models = my_architectures,
# model_settings = model_settings,
# training_settings = training_settings)
DeepSphereModelClass = getattr(my_architectures, model_settings['architecture_name'])
# - Retrieve required model arguments
model_keys = ['dim_info', 'sampling', 'resolution',
Expand Down Expand Up @@ -449,7 +449,8 @@ def main(cfg_path, exp_dir, data_dir):
da_dynamic = da_test_dynamic,
da_static = da_static,
da_bc = da_test_bc,
scaler = scaler,
scaler_transform = scaler,
scaler_inverse = scaler,
# Dataloader options
device = device,
batch_size = 50, # number of forecasts per batch
Expand Down

0 comments on commit b85757a

Please sign in to comment.