|
| 1 | +from typing import Dict, List, Tuple |
| 2 | +from dataclasses import dataclass |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from nequip.train import Trainer, Loss |
| 6 | + |
| 7 | +# Making this a dataclass takes care of equality operators, handing restart consistency checks |
| 8 | + |
| 9 | + |
| 10 | +@dataclass |
| 11 | +class SimpleLossSchedule: |
| 12 | + """Schedule `loss_coeffs` through a training run. |
| 13 | +
|
| 14 | + To use this in a training, set in your YAML file: |
| 15 | +
|
| 16 | + start_of_epoch_callbacks: |
| 17 | + - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[30, {"forces": 1.0, "total_energy": 0.0}], [30, {"forces": 0.0, "total_energy": 1.0}]]} |
| 18 | +
|
| 19 | + This funny syntax tells PyYAML to construct an object of this class. |
| 20 | +
|
| 21 | + Each entry in the schedule is a tuple of the 1-based epoch index to start that loss coefficient set at, and a dict of loss coefficients. |
| 22 | + """ |
| 23 | + |
| 24 | + schedule: List[Tuple[int, Dict[str, float]]] = None |
| 25 | + |
| 26 | + def __call__(self, trainer: Trainer): |
| 27 | + assert ( |
| 28 | + self in trainer._start_of_epoch_callbacks |
| 29 | + ), "must be start not end of epoch" |
| 30 | + # user-facing 1 based indexing of epochs rather than internal zero based |
| 31 | + iepoch: int = trainer.iepoch + 1 |
| 32 | + if iepoch < 1: # initial validation epoch is 0 in user-facing indexing |
| 33 | + return |
| 34 | + loss_function: Loss = trainer.loss |
| 35 | + |
| 36 | + assert self.schedule is not None |
| 37 | + schedule_start_epochs = np.asarray([e[0] for e in self.schedule]) |
| 38 | + # make sure they are ascending |
| 39 | + assert len(schedule_start_epochs) >= 1 |
| 40 | + assert schedule_start_epochs[0] >= 2, "schedule must start at epoch 2 or later" |
| 41 | + assert np.all( |
| 42 | + (schedule_start_epochs[1:] - schedule_start_epochs[:-1]) > 0 |
| 43 | + ), "schedule start epochs must be strictly ascending" |
| 44 | + # we are running at _start_ of epoch, so we need to apply the right change for the current epoch |
| 45 | + current_change_idex = np.searchsorted(schedule_start_epochs, iepoch + 1) - 1 |
| 46 | + # ^ searchsorted 3 in [2, 10, 19] would return 1, for example |
| 47 | + # but searching 2 in [2, 10, 19] gives 0, so we actually search iepoch + 1 to always be ahead of the start |
| 48 | + # apply the current change to handle restarts |
| 49 | + if current_change_idex >= 0: |
| 50 | + new_coeffs = self.schedule[current_change_idex][1] |
| 51 | + assert ( |
| 52 | + loss_function.coeffs.keys() == new_coeffs.keys() |
| 53 | + ), "all coeff schedules must contain all loss terms" |
| 54 | + loss_function.coeffs.update(new_coeffs) |
0 commit comments