Clamping introduces nan if ar_steps_train > 1 #119

Closed
sadamov opened this issue Feb 10, 2025 · 8 comments · Fixed by #123

sadamov commented Feb 10, 2025

After a few iterations, clamping state variables introduces NaNs in the train loss. Replacing the sigmoid and inverse sigmoid functions with a simple torch.clamp prevents the issue (just as an indication of where the problem lies; a sketch of that workaround follows the log below). Here is a reproducible example based on the danra datastore from the test_examples:

  • First build the datastore.zarr and the graph (I built the hierarchical archetype)
  • Then start the training with neural_lam:
        python -m neural_lam.train_model \
            --config_path ./neural-lam/tests/datastore_examples/mdp/danra_100m_winds/config.yaml \
            --model hi_lam \
            --graph_name hierarchical \
            --hidden_dim 16 \
            --hidden_dim_grid 16 \
            --time_delta_enc_dim 16 \
            --processor_layers 2 \
            --batch_size 1 \
            --epochs 1 \
            --ar_steps_train 2 \
            --ar_steps_eval 2 \
            --val_steps_to_log 1 2
  • Observe the introduction of NaNs:
Training: |          | 0/? [00:00<?, ?it/s]
Training:   0%|          | 0/20 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/20 [00:00<?, ?it/s] 
Epoch 0:   5%|| 1/20 [00:02<00:49,  0.39it/s]
Epoch 0:   5%|| 1/20 [00:02<00:49,  0.39it/s, v_num=pbts, train_loss_step=4.740]
Epoch 0:  10%|| 2/20 [00:03<00:27,  0.65it/s, v_num=pbts, train_loss_step=4.740]
Epoch 0:  10%|| 2/20 [00:03<00:27,  0.65it/s, v_num=pbts, train_loss_step=nan.0]
Epoch 0:  15%|█▌        | 3/20 [00:03<00:20,  0.81it/s, v_num=pbts, train_loss_step=nan.0]
Epoch 0:  15%|█▌        | 3/20 [00:03<00:21,  0.80it/s, v_num=pbts, train_loss_step=nan.0]
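
For reference, a minimal sketch of the torch.clamp workaround mentioned above; prediction, clamp_lower and clamp_upper are placeholder names, not the actual tensors in neural_lam:

    import torch

    # Sketch of the workaround: replace the sigmoid / inverse-sigmoid based soft
    # clamping with a hard clamp of the predicted state (placeholder names only).
    def hard_clamp_state(prediction, clamp_lower, clamp_upper):
        # per-feature bounds broadcast over the batch/grid dimensions
        return torch.clamp(prediction, min=clamp_lower, max=clamp_upper)

Note that torch.clamp has zero gradient outside the bounds, which is presumably why the smooth clamping functions are used in the first place; the point of the workaround is only to show that the NaNs come from them.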
@sadamov added the bug and help wanted labels on Feb 10, 2025
@sadamov added this to the v0.4.0 milestone on Feb 10, 2025
@joeloskarsson

Just so it does not cause confusion: The options

            --hidden_dim_grid 16 \
            --time_delta_enc_dim 16 \

are not available on the main branch, but should have nothing to do with this.


joeloskarsson commented Feb 11, 2025

I'm on a train and didn't have the DANRA data downloaded, so I tried with MEPS data and this configuration:

python -m neural_lam.train_model \
    --config_path tests/datastore_examples/npyfilesmeps/meps_example_reduced/config.yaml \
    --model graph_lam \
    --graph 1level \
    --hidden_dim 4 \
    --processor_layers 1 \
    --batch_size 1 \
    --epochs 50 \
    --ar_steps_train 2 \
    --ar_steps_eval 2 \
    --val_steps_to_log 1 2

Interestingly, this does not give me any NaN train_loss_step. That is of course because no clamping is applied in that config 🤦


sadamov commented Feb 11, 2025

I don't think the config file you are referring to exists on main. Did you base it on the README.md?
But that one doesn't have clamping enabled.

# config.yaml
datastore:
  kind: npyfilesmeps
  config_path: meps.datastore.yaml
training:
  state_feature_weighting:
    __config_class__: ManualStateFeatureWeighting
    values:
      u100m: 1.0
      v100m: 1.0
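
For comparison, the clamping in the danra example config is enabled through a section under training:. The key names below are quoted from memory and should be treated as an assumption, and the variable and bounds are placeholders; it only illustrates the kind of block that the MEPS example above is missing:

    # illustrative only; key names, variable and bounds are assumptions
    training:
      output_clamping:
        lower:
          r2m: 0.0
        upper:
          r2m: 1.0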

@joeloskarsson

Oh, yes you are correct. Don't mind my comment above.

@joeloskarsson

Hmm, this is quite confusing. After 1 batch there are no NaNs in the state resulting from the clamping, i.e. the state output by

new_state = self.get_clamped_new_state(rescaled_delta_mean, prev_state)

However, already on iteration 2 the result of

net_output = self.output_map(

is only NaNs.

This would imply that there are some NaN weights in the network in the second iteration. That will then result in NaNs in the output state. Could it be that the first iteration does not create NaNs in the state, but the gradients w.r.t. something are NaN or inf?
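
One way to test that hypothesis is to inspect the parameters and their gradients right after the first backward pass. A minimal sketch (check_nonfinite is a hypothetical helper, not part of neural-lam):

    import torch

    def check_nonfinite(model: torch.nn.Module) -> None:
        # Report any parameter whose values or gradients contain NaN/Inf
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                print(f"non-finite values in parameter {name}")
            if param.grad is not None and not torch.isfinite(param.grad).all():
                print(f"non-finite values in gradient of {name}")

Calling it between loss.backward() and optimizer.step() on the first batch would show whether the gradients are already NaN/inf before any weights get updated.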

@joeloskarsson

Running with Trainer(detect_anomaly=True, ...) yields

RuntimeError: Function 'Expm1Backward0' returned nan values in its 0th output.

So it first found a NaN in a backward pass (of Expm1, whatever that is). That does not necessarily mean that the NaN originated there, but it could be.
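
For context, Lightning's detect_anomaly=True essentially wraps PyTorch's anomaly detection; the same check in plain PyTorch looks roughly like this (model and batch are placeholders):

    import torch

    with torch.autograd.detect_anomaly():
        loss = model(batch)  # placeholder forward pass producing a scalar loss
        loss.backward()      # raises RuntimeError at the backward op that produced the NaN

The error names the backward function (here Expm1Backward0) and attaches a traceback of the forward call that created it, which helps locate the offending forward line.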

@joeloskarsson

Oh, right, that is probably

torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta

in utils.inverse_softplus
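
For context, that expression is the inverse of softplus: with softplus(x) = log(1 + exp(beta * x)) / beta, solving y = softplus(x) for x gives x = log(exp(beta * y) - 1) / beta = log(expm1(beta * y)) / beta, and the clamp_min only keeps the log argument positive for small y. A sketch of the pair (not the exact neural-lam code):

    import torch

    def softplus(x, beta=1.0):
        # smooth approximation of relu, cf. torch.nn.functional.softplus
        return torch.log1p(torch.exp(beta * x)) / beta

    def inverse_softplus(y, beta=1.0, threshold=1e-6):
        # exp(beta * y) - 1 computed via expm1; the clamp keeps the log argument positive
        return torch.log(torch.clamp_min(torch.expm1(y * beta), threshold)) / beta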


joeloskarsson commented Feb 11, 2025

torch.expm1 seems to be the problem. If I remove just that function call the NaNs disappear. And strangely, it seems like neither its input nor its output contains any NaNs, only its gradient (?).

I tried adjusting the threshold parameter used in inverse_softplus, but it did not change anything.
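
One mechanism that would fit this (a guess, using a made-up input value rather than the actual tensors from the run): for a large enough argument expm1 overflows to inf, the downstream log then contributes a gradient of 1/inf = 0, and the backward of expm1 multiplies that 0 by exp(x * beta) = inf, giving 0 * inf = NaN. The forward values then stay NaN-free while the gradient does not. A standalone reproduction of that pattern:

    import torch

    beta = 1.0
    # large enough that expm1 overflows to inf in float32
    x = torch.tensor([100.0], requires_grad=True)
    y = torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
    y.backward()
    print(y.item())       # inf -> forward is non-finite but not NaN
    print(x.grad.item())  # nan -> 0 (from d log) * inf (from d expm1)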
