Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,4 @@ notebooks/plots/
notebooks/rollouts/

scripts/run_single_*.txt
scripts/srun_outs_ft/
2 changes: 2 additions & 0 deletions climatem/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
num_months_aggregated: List[int] = [
1
], # Aggregate num_months_aggregated months i.e. if you want yearly temporal resolution set this param to [12]
**kwargs, # accept any new keys in the parameter configs (e.g returned by the class)
):
self.data_dir = data_dir
self.climateset_data = climateset_data
Expand Down Expand Up @@ -134,6 +135,7 @@ def __init__(
patience_post_thresh: int = 50, # NOT SURE: if mapping converges before patience, and for patience_post_thresh it's stable, then optimize everything
valid_freq: int = 5, # get validation metrics every valid_freq iteration
# here valid_freq is critical for updating the parameters of the ALM method as they get updated every valid_freq
**kwargs, # accept any new keys in the parameter configs
):
self.ratio_train = ratio_train
self.ratio_valid = 1 - self.ratio_train
Expand Down
21 changes: 20 additions & 1 deletion climatem/model/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def trace_handler(p):
# Todo propagate the path!
if not self.plot_params.savar:
self.plotter.save_coordinates_and_adjacency_matrices(self)

torch.save(self.model.state_dict(), self.save_path / f"model_{self.iteration}.pth")

# try to use the accelerator.save function here
Expand Down Expand Up @@ -553,6 +554,7 @@ def train_step(self): # noqa: C901
# we have to take care here to make sure that we have the right tensors with requires_grad
for k in range(self.future_timesteps):
nll_bis, recons_bis, kl_bis, y_pred_recons = self.get_nll(x_bis, y[:, k], z)
# sz: nll_bis (positive) is -elbo, which we want to minimize (elbo is negative, so -elbo is positive); recons_bis (negative number) is the reconstruction term, which we want to maximize.
nll += (self.optim_params.loss_decay_future_timesteps**k) * nll_bis
recons += (self.optim_params.loss_decay_future_timesteps**k) * recons_bis
kl += (self.optim_params.loss_decay_future_timesteps**k) * kl_bis
Expand All @@ -573,6 +575,7 @@ def train_step(self): # noqa: C901
h_sparsity = self.get_sparsity_violation(
lower_threshold=0.05, upper_threshold=self.optim_params.sparsity_upper_threshold
)
# sz: upper_threshold = 0.5: half of the edges are connected
sparsity_reg = self.ALM_sparsity.gamma * h_sparsity + 0.5 * self.ALM_sparsity.mu * h_sparsity**2
if self.optim_params.binarize_transition and h_sparsity == 0:
h_sparsity = self.adj_transition_variance()
Expand All @@ -593,7 +596,9 @@ def train_step(self): # noqa: C901

# compute total loss - here we are removing the sparsity regularisation as we are using the constraint here.
loss = nll + connect_reg + sparsity_reg

if not self.no_w_constraint:

if self.constraint_func == "sum":
loss = (
loss + torch.sum(self.ALM_ortho.gamma @ h_ortho) + 0.5 * self.ALM_ortho.mu * torch.sum(h_ortho**2)
Expand All @@ -602,6 +607,7 @@ def train_step(self): # noqa: C901
loss = (
loss + torch.sum(self.ALM_ortho.gamma * h_ortho) + 0.5 * self.ALM_ortho.mu * torch.sum(h_ortho**2)
)

if self.instantaneous:
loss = loss + 0.5 * self.QPM_acyclic.mu * h_acyclic**2

Expand Down Expand Up @@ -645,6 +651,7 @@ def train_step(self): # noqa: C901
f"Scheduling spectrum coefficient at iterations {self.optim_params.scheduler_spectra} at coefficients {self.coefs_scheduler_spectra}"
)
print(f"Updating spectral coefficient to {coef} at iteration {self.iteration}!!")

loss = (
loss
+ self.optim_params.crps_coeff * crps
Expand All @@ -667,7 +674,7 @@ def train_step(self): # noqa: C901
self.optimizer.step() if self.optim_params.optimizer == "rmsprop" else self.optimizer.step()
), self.train_params.lr
# projection of the gradient for w
if self.model.autoencoder.use_grad_project and not self.no_w_constraint:
if not self.no_w_constraint:
with torch.no_grad():
self.model.autoencoder.get_w_decoder().clamp_(min=0.0)

Expand Down Expand Up @@ -1106,7 +1113,9 @@ def log_losses(self):
"""Append in lists values of the losses and more."""
# train
self.train_loss_list.append(-self.train_loss)
# sz: train_loss is a positive number we want to minimize, so -train_loss is negative
self.train_recons_list.append(self.train_recons)
# train_recons is a negative number (log p)
self.train_kl_list.append(self.train_kl)

# here note that train_ortho_cons_list is a torch.sum...
Expand Down Expand Up @@ -1191,6 +1200,7 @@ def get_nll(self, x, y, z=None) -> torch.Tensor:

# this is just running the forward pass of LatentTSDCD...
elbo, recons, kl, preds = self.model(x, y, z, self.iteration)
# elbo = reconstruction - kl; maximizing elbo is equivalent to minimizing -elbo

return -elbo, recons, kl, preds

Expand Down Expand Up @@ -1225,9 +1235,15 @@ def get_ortho_violation(self, w: torch.Tensor) -> float:
# constraint = constraint + torch.norm(w[i].T @ w[i] - torch.eye(k), p=2)
i = 0
# constraint = torch.norm(w[i].T @ w[i] - torch.eye(k), p=2, dim=1)
# col_norms = torch.linalg.norm(w[i], axis=0)
# w_normalized = w[i] / col_norms
# constraint = w_normalized.T @ w_normalized - torch.eye(k)
constraint = w[i].T @ w[i] - torch.eye(k)
# print('What is the ortho constraint shape:', constraint.shape)
h = constraint / self.ortho_normalization

# mask = ~torch.eye(k, dtype=bool, device=w.device)
# h = h*mask
else:
h = torch.as_tensor([0.0])

Expand Down Expand Up @@ -1465,6 +1481,7 @@ def get_spatial_spectral_loss(self, y_true, y_pred, take_log=True):

assert y_true.dim() == 3
assert y_pred.dim() == 3
# print("y_pred shape", y_pred.shape)

if y_true.size(-1) == self.lat * self.lon:

Expand Down Expand Up @@ -1498,6 +1515,8 @@ def get_spatial_spectral_loss(self, y_true, y_pred, take_log=True):
fft_pred = torch.log(torch.abs(fft_pred) + 1e-4)

spectral_loss = torch.mean(torch.abs(fft_pred - fft_true), dim=0)
# print("spectral_loss shape", spectral_loss.shape)
# print("spectral_loss shape final", torch.mean(spectral_loss))
# spectral_loss = torch.mean(torch.nan_to_num(spectral_loss, 0), dim=0)

# Calculate the power spectrum
Expand Down
13 changes: 7 additions & 6 deletions climatem/model/tsdcd_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,15 @@ class LatentTSDCD(nn.Module):

def __init__(
self,
num_layers: int,
num_hidden: int,
num_layers: int, # sz: transition model
num_hidden: int, # sz: transition model
num_input: int,
num_output: int,
num_layers_mixing: int,
num_layers_mixing: int, # sz: encoder/decoder
num_hidden_mixing: int,
position_embedding_dim: int,
transition_param_sharing: bool,
position_embedding_transition: int,
position_embedding_transition: int, # sz: 1NN per location, after sharing: 1NN for all locations
coeff_kl: float,
distr_z0: str,
distr_encoder: str,
Expand Down Expand Up @@ -520,6 +520,7 @@ def forward(self, x, y, gt_z, iteration, xi=None):
else:
px_distr = self.distr_decoder(px_mu, px_std)
recons = torch.mean(torch.sum(px_distr.log_prob(y), dim=[1, 2]))
# sz: log p_theta(y|z): 0 < p < 1, so log p < 0
# compute the KL, the reconstruction and the ELBO
# kl = distr.kl_divergence(q, p).mean()
kl_raw = (
Expand Down Expand Up @@ -920,7 +921,7 @@ def __init__(

def encode(self, x, i):

mask = super().get_encode_mask(x.shape[0])
mask = super().get_encode_mask()
mu = torch.zeros((x.shape[0], self.d_z), device=x.device)

j_values = torch.arange(self.d_z, device=x.device).expand(
Expand Down Expand Up @@ -950,7 +951,7 @@ def encode(self, x, i):

def decode(self, z, i):

mask = super().get_decode_mask(z.shape[0])
mask = super().get_decode_mask()
mu = torch.zeros((z.shape[0], self.d_x), device=z.device)

# Create a tensor of shape (z.shape[0], self.d_x) where each row is a sequence from 0 to self.d_x
Expand Down
12 changes: 12 additions & 0 deletions climatem/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,18 @@ def parse_args():
default="configs/param_file.json",
help="Path to a json file with values for all parameters",
)
parser.add_argument(
"--exp-id",
type=str,
default="var_ts",
help="experiment name for rollout",
)
parser.add_argument(
"--iter-id",
type=int,
default=200000,
help="model saving epoch",
)
# Add an argument for nested keys, this will be handled dynamically later
parser.add_argument("--hp", action="append", metavar="KEY=VALUE", help="Cmd line arguments")
return parser.parse_args()
Expand Down
10 changes: 5 additions & 5 deletions configs/single_param_file.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@
"train_params": {
"ratio_train": 0.9,
"lr": 0.0001,
"lr_scheduler_epochs": [10000, 25000, 50000],
"lr_scheduler_epochs": [10000,25000,50000],
"lr_scheduler_gamma": 1,
"max_iteration": 200000,
"patience": 5000,
"patience_post_thresh": 50,
"valid_freq": 100
"valid_freq": 200
},
"model_params": {
"instantaneous": false,
Expand All @@ -69,12 +69,12 @@
"num_hidden_mixing": 16,
"num_layers_mixing": 2,
"nonlinear_dynamics": true,
"num_hidden": 16,
"num_hidden": 8,
"num_layers": 2,
"num_output": 2,
"position_embedding_dim": 100,
"transition_param_sharing": true,
"position_embedding_transition": 100,
"position_embedding_transition": 60,
"fixed": false,
"fixed_output_fraction": null,
"constraint_func": "trace"
Expand All @@ -85,7 +85,7 @@
"use_sparsity_constraint": true,
"binarize_transition": true,
"crps_coeff": 1,
"spectral_coeff": 1000,
"spectral_coeff": 2000,
"temporal_spectral_coeff": 2000,
"coeff_kl": 1,

Expand Down
12 changes: 6 additions & 6 deletions configs/single_param_file_new.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"exp_params": {
"exp_path": "$SCRATCH/results/latest_runs/",
"exp_path": "$SCRATCH/results/test_debug_small/",
"_target_": "emulator.src.datamodules.climate_datamodule.ClimateDataModule",
"latent": true,
"d_z": 90,
Expand Down Expand Up @@ -59,7 +59,7 @@
"max_iteration": 200000,
"patience": 5000,
"patience_post_thresh": 50,
"valid_freq": 100
"valid_freq": 200
},
"model_params": {
"instantaneous": false,
Expand All @@ -83,7 +83,7 @@
"optimizer": "rmsprop",

"use_sparsity_constraint": true,
"binarize_transition": true,
"binarize_transition": false,
"crps_coeff": 1,
"spectral_coeff": 1000,
"temporal_spectral_coeff": 2000,
Expand All @@ -92,9 +92,9 @@
"loss_decay_future_timesteps": 1,

"fraction_highest_wavenumbers": 0.5,
"fraction_lowest_wavenumbers": null,
"fraction_lowest_wavenumbers": 0.95,
"take_log_spectra": true,
"scheduler_spectra": [100000],
"scheduler_spectra": null,

"reg_coeff": 0.12801,
"reg_coeff_connect": 0,
Expand Down Expand Up @@ -134,7 +134,7 @@
"plot_params": {
"plot_freq": 20000,
"plot_through_time": true,
"print_freq": 1000
"print_freq": 20000
},
"savar_params": {
"n_per_col": 2,
Expand Down
13 changes: 7 additions & 6 deletions configs/single_param_file_savar.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"exp_params": {
"exp_path": "$SCRATCH/results/savar_data_new",
"exp_path": "$SCRATCH/results/SAVAR_DATA_TEST_True",
"_target_": "emulator.src.datamodules.climate_datamodule.ClimateDataModule",
"latent": true,
"d_z": 4,
Expand All @@ -17,8 +17,8 @@
},
"data_params": {
"data_dir": "$SCRATCH/data/SAVAR_DATA_TEST",
"climateset_data": "$SCRATCH/data/icml_processed_data/picontrol/24_ni",
"reload_climate_set_data": false,
"climateset_data": "/network/scratch/j/julien.boussard/data/icml_processed_data/picontrol/24_ni",
"reload_climate_set_data": true,
"icosahedral_coordinates_path": "$CLIMATEMDIR/mappings/vertex_lonlat_mapping.npy",
"in_var_ids": ["savar"],
"out_var_ids": ["savar"],
Expand Down Expand Up @@ -52,7 +52,7 @@
"max_iteration": 100000,
"patience": 5000,
"patience_post_thresh": 50,
"valid_freq": 100
"valid_freq": 200
},
"model_params": {
"instantaneous": false,
Expand Down Expand Up @@ -106,7 +106,7 @@
"sparsity_omega_mu": 0.95,
"sparsity_h_threshold": 1e-4,
"sparsity_min_iter_convergence": 1000,
"sparsity_upper_threshold": 0.1,
"sparsity_upper_threshold": 0.05,

"acyclic_mu_init": 1,
"acyclic_mu_mult_factor": 2,
Expand All @@ -121,9 +121,10 @@
"udpate_ALM_using_nll": false
},
"plot_params": {
"plot_freq": 10000,
"plot_freq": 50000,
"plot_through_time": true,
"print_freq": 1000

},
"savar_params": {
"n_per_col": 2,
Expand Down
Loading