Skip to content

Commit 9ba1d5f

Browse files
Add SimpleLossSchedule
1 parent 7fcd45d commit 9ba1d5f

File tree

4 files changed

+73
-5
lines changed

4 files changed

+73
-5
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Most recent change on the bottom.
2222
- `include_file_as_baseline_config` for simple modifications of existing configs
2323
- `nequip-deploy --using-dataset` to support data-dependent deployment steps
2424
- Support for Gaussian Mixture Model uncertainty quantification (https://doi.org/10.1063/5.0136574)
25+
- `start_of_epoch_callbacks`
26+
- `nequip.train.callbacks.loss_schedule.SimpleLossSchedule` for changing the loss coefficients at specified epochs
2527
- `nequip-deploy build --checkpoint` and `--override` to avoid many largely duplicated YAML files
2628
- matscipy neighborlist support enabled with `NEQUIP_MATSCIPY_NL` environment variable
2729

configs/full.yaml

+11-5
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@ early_stopping_upper_bounds:
212212

213213
# loss function
214214
loss_coeffs: # different weights to use in a weighted loss function
215-
forces: 1 # if using PerAtomMSELoss, a default weight of 1:1 on each should work well
215+
forces: 1.0 # if using PerAtomMSELoss, a default weight of 1:1 on each should work well
216216
total_energy:
217-
- 1
217+
- 1.0
218218
- PerAtomMSELoss
219219
# note that the ratio between force and energy loss matters for the training process. One may consider using 1:1 with the PerAtomMSELoss. If the energy loss still significantly dominates the loss function in the initial epochs, lowering the energy loss weight helps the training a lot.
220220

@@ -249,6 +249,15 @@ loss_coeffs:
249249
# - L1Loss
250250
# forces: 1.0
251251

252+
# You can schedule changes in the loss coefficients using a callback:
253+
# In the "schedule" key each entry is a two-element list of:
254+
# - the 1-based epoch index at which to start the new loss coefficients
255+
# - the new loss coefficients as a dict
256+
#
257+
# start_of_epoch_callbacks:
258+
# - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
259+
#
260+
252261
# output metrics
253262
metrics_components:
254263
- - forces # key
@@ -371,6 +380,3 @@ global_rescale_scale_trainable: false
371380
# per_species_rescale_shifts: null
372381
# per_species_rescale_scales: null
373382

374-
# Options for e3nn's set_optimization_defaults. A dict:
375-
# e3nn_optimization_defaults:
376-
# explicit_backward: True
+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np

from nequip.train import Trainer, Loss
6+
7+
# Making this a dataclass takes care of equality operators, handling restart consistency checks


@dataclass
class SimpleLossSchedule:
    """Schedule `loss_coeffs` through a training run.

    To use this in a training, set in your YAML file:

        start_of_epoch_callbacks:
         - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 1.0, "total_energy": 0.0}], [30, {"forces": 0.0, "total_energy": 1.0}]]}

    This funny syntax tells PyYAML to construct an object of this class.

    Each entry in the schedule is a tuple of the 1-based epoch index to start
    that loss coefficient set at, and a dict of loss coefficients.  Start
    epochs must be >= 2 and strictly ascending.
    """

    # [(start_epoch, {loss_term_name: coefficient}), ...]; None only until
    # YAML object construction fills it in.
    schedule: Optional[List[Tuple[int, Dict[str, float]]]] = None

    def __call__(self, trainer: "Trainer") -> None:
        """Apply the loss coefficients scheduled for the epoch that is starting.

        Must be registered as a *start*-of-epoch callback: the coefficients it
        sets are the ones used for the epoch about to run.  Reapplying the
        current entry every epoch makes restarts mid-schedule pick up the
        correct coefficient set.

        Raises:
            RuntimeError: if registered as anything but a start-of-epoch callback.
            ValueError: if the schedule is missing, empty, or malformed.
        """
        if self not in trainer._start_of_epoch_callbacks:
            raise RuntimeError("must be start not end of epoch")
        # user-facing 1 based indexing of epochs rather than internal zero based
        iepoch: int = trainer.iepoch + 1
        if iepoch < 1:  # initial validation epoch is 0 in user-facing indexing
            return
        loss_function: "Loss" = trainer.loss

        # Validate with real exceptions, not `assert`, so the checks survive
        # running under `python -O`.
        if self.schedule is None or len(self.schedule) == 0:
            raise ValueError("`schedule` must be a non-empty list of (epoch, coeffs) entries")
        schedule_start_epochs = np.asarray([e[0] for e in self.schedule])
        if schedule_start_epochs[0] < 2:
            raise ValueError("schedule must start at epoch 2 or later")
        # make sure they are strictly ascending
        if not np.all(np.diff(schedule_start_epochs) > 0):
            raise ValueError("schedule start epochs must be strictly ascending")
        # We are running at the _start_ of the epoch, so we need to apply the
        # right change for the current epoch.
        # searchsorted 3 in [2, 10, 19] would return 1, for example,
        # but searching 2 in [2, 10, 19] gives 0, so we actually search
        # iepoch + 1 to always be ahead of the start.
        current_change_idx = np.searchsorted(schedule_start_epochs, iepoch + 1) - 1
        # apply the current change (even if it started earlier) to handle restarts
        if current_change_idx >= 0:
            new_coeffs = self.schedule[current_change_idx][1]
            if loss_function.coeffs.keys() != new_coeffs.keys():
                raise ValueError("all coeff schedules must contain all loss terms")
            loss_function.coeffs.update(new_coeffs)

nequip/train/trainer.py

+6
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def __init__(
258258
val_idcs: Optional[list] = None,
259259
train_val_split: str = "random",
260260
init_callbacks: list = [],
261+
start_of_epoch_callbacks: list = [],
261262
end_of_epoch_callbacks: list = [],
262263
end_of_batch_callbacks: list = [],
263264
end_of_train_callbacks: list = [],
@@ -348,6 +349,9 @@ def __init__(
348349

349350
# load all callbacks
350351
self._init_callbacks = [load_callable(callback) for callback in init_callbacks]
352+
self._start_of_epoch_callbacks = [
353+
load_callable(callback) for callback in start_of_epoch_callbacks
354+
]
351355
self._end_of_epoch_callbacks = [
352356
load_callable(callback) for callback in end_of_epoch_callbacks
353357
]
@@ -887,6 +891,8 @@ def reset_metrics(self):
887891
self.metrics.to(self.torch_device)
888892

889893
def epoch_step(self):
894+
for callback in self._start_of_epoch_callbacks:
895+
callback(self)
890896

891897
dataloaders = {TRAIN: self.dl_train, VALIDATION: self.dl_val}
892898
categories = [TRAIN, VALIDATION] if self.iepoch >= 0 else [VALIDATION]

0 commit comments

Comments
 (0)