diff --git a/CHANGELOG.md b/CHANGELOG.md index e1c228d05..556e70ab4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. -## Unreleased +## Unreleased - 0.7.0 +### Added +- `--override` now supported as a `nequip-train` flag (similar to its use in `nequip-deploy`) +- add SoftAdapt (https://arxiv.org/abs/2403.18122) callback option + +### Changed +- [Breaking] training restart behavior altered: file-wise consistency checks performed between original config and config passed to `nequip-train` on restart (instead of checking the config dicts) +- [Breaking] config format for callbacks changed (see `configs/full.yaml` for an example) +### Fixed +- fixed `wandb_watch` bug ## [0.6.1] - 2024-7-9 ### Added diff --git a/configs/full.yaml b/configs/full.yaml index 765d90d8a..91684a36d 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -256,10 +256,17 @@ loss_coeffs: # In the "schedule" key each entry is a two-element list of: # - the 1-based epoch index at which to start the new loss coefficients # - the new loss coefficients as a dict -# -# start_of_epoch_callbacks: -# - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]} -# +# callbacks: +# start_of_epoch: +# - !!python/object:nequip.train.callbacks.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]} + +# You can also try using the SoftAdapt strategy for adaptively changing loss coefficients +# (see https://arxiv.org/abs/2403.18122) +#callbacks: +# end_of_batch: +# - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 5, "beta": 1.1} + + # output metrics metrics_components: diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index e83bd299a..d91aad3cc 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -3,6 +3,9 @@ import logging import argparse import warnings +import shutil +import difflib +import yaml # This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch. # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance. 
@@ -29,6 +32,8 @@ root="./", tensorboard=False, wandb=False, + wandb_watch=False, + wandb_watch_kwargs={}, model_builders=[ "SimpleIrrepsConfig", "EnergyModel", @@ -46,7 +51,7 @@ equivariance_test=False, grad_anomaly_mode=False, gpu_oom_offload=False, - append=False, + append=True, warn_unused=False, _jit_bailout_depth=2, # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286 # Quote from eelison in PyTorch slack: @@ -68,32 +73,61 @@ def main(args=None, running_as_script: bool = True): - config = parse_command_line(args) + config, path_to_config, override_options = parse_command_line(args) if running_as_script: set_up_script_logger(config.get("log", None), config.verbose) - found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth") + train_dir = f"{config.root}/{config.run_name}" + found_restart_file = exists(f"{train_dir}/trainer.pth") if found_restart_file and not config.append: raise RuntimeError( - f"Training instance exists at {config.root}/{config.run_name}; " + f"Training instance exists at {train_dir}; " "either set append to True or use a different root or runname" ) - elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"): + elif not found_restart_file and isdir(train_dir): # output directory exists but no ``trainer.pth`` file, suggesting previous run crash during # first training epoch (usually due to memory): warnings.warn( - f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model " + f"Previous run folder at {train_dir} exists, but a saved model " f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will " f"be started." ) - rmtree(f"{config.root}/{config.run_name}") + rmtree(train_dir) - # for fresh new train - if not found_restart_file: + if not found_restart_file: # fresh start + # update config with override parameters for setting up train-dir + config.update(override_options) trainer = fresh_start(config) - else: - trainer = restart(config) + # copy original config to training directory + shutil.copyfile(path_to_config, f"{train_dir}/original_config.yaml") + else: # restart + # perform string matching for original config and restart config + # throw error if they are different + with ( + open(f"{train_dir}/original_config.yaml") as orig_f, + open(path_to_config) as current_f, + ): + diffs = [ + x + for x in difflib.Differ().compare( + orig_f.readlines(), current_f.readlines() + ) + if x[0] in ("+", "-") + ] + if diffs: + raise RuntimeError( + f"Config {path_to_config} used for restart differs from original config for training run in {train_dir}.\n" + + "The following differences were found:\n\n" + + "".join(diffs) + + "\n" + + "If you intend to override the original config parameters, use the --override flag. For example, use\n" + + f'`nequip-train {path_to_config} --override "max_epochs: 42"`\n' + + 'on the command line to override the config parameter "max_epochs"\n' + + "BE WARNED that use of the --override flag is not protected by consistency checks performed by NequIP." + ) + else: + trainer = restart(config, override_options) # Train trainer.save() @@ -157,6 +191,12 @@ def parse_command_line(args=None): help="Warn instead of error when the config contains unused keys", action="store_true", ) + parser.add_argument( + "--override", + help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. 
Unless you know why you need to, do not use this option.", + type=str, + default=None, + ) args = parser.parse_args(args=args) config = Config.from_file(args.config, defaults=default_config) @@ -169,10 +209,26 @@ def parse_command_line(args=None): ): config[flag] = getattr(args, flag) or config[flag] - return config + # Set override options before _set_global_options so that things like allow_tf32 are correctly handled + if args.override is not None: + override_options = yaml.load(args.override, Loader=yaml.Loader) + assert isinstance( + override_options, dict + ), "--override's YAML string must define a dictionary of top-level options" + overridden_keys = set(config.keys()).intersection(override_options.keys()) + set_keys = set(override_options.keys()) - set(overridden_keys) + logging.info( + f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}" + ) + del overridden_keys, set_keys + else: + override_options = {} + + return config, args.config, override_options def fresh_start(config): + # we use add_to_config cause it's a fresh start and need to record it check_code_version(config, add_to_config=True) _set_global_options(config) @@ -267,7 +323,7 @@ def _unused_check(): return trainer -def restart(config): +def restart(config, override_options): # load the dictionary restart_file = f"{config.root}/{config.run_name}/trainer.pth" dictionary = load_file( @@ -276,20 +332,6 @@ def restart(config): enforced_format="torch", ) - # compare dictionary to config and update stop condition related arguments - for k in config.keys(): - if config[k] != dictionary.get(k, ""): - if k == "max_epochs": - dictionary[k] = config[k] - logging.info(f'Update "{k}" to {dictionary[k]}') - elif k.startswith("early_stop"): - dictionary[k] = config[k] - logging.info(f'Update "{k}" to {dictionary[k]}') - elif isinstance(config[k], type(dictionary.get(k, ""))): - raise ValueError( - f'Key "{k}" is different in config and the result trainer.pth file. Please double check' - ) - # note, "trainer.pth"/dictionary also store code versions, # which will not be stored in config and thus not checked here check_code_version(config) @@ -299,6 +341,10 @@ def restart(config): config = Config(dictionary, exclude_keys=["state_dict", "progress"]) + # override configs loaded from save + dictionary.update(override_options) + config.update(override_options) + # dtype, etc. _set_global_options(config) diff --git a/nequip/train/callback_manager.py b/nequip/train/callback_manager.py new file mode 100644 index 000000000..e5cfaf747 --- /dev/null +++ b/nequip/train/callback_manager.py @@ -0,0 +1,49 @@ +from nequip.utils import load_callable +import dataclasses + + +class CallbackManager: + """Parent callback class + + Centralized object to manage various callbacks that can be added-on. + """ + + def __init__( + self, + callbacks={}, + ): + CALLBACK_TYPES = [ + "init", + "start_of_epoch", + "end_of_epoch", + "end_of_batch", + "end_of_train", + "final", + ] + # load all callbacks + self.callbacks = {callback_type: [] for callback_type in CALLBACK_TYPES} + + for callback_type in callbacks: + if callback_type not in CALLBACK_TYPES: + raise ValueError( + f"{callback_type} is not a supported callback type.\nSupported callback types include " + + str(CALLBACK_TYPES) + ) + # make sure callbacks are either dataclasses or functions + for callback in callbacks[callback_type]: + if not (dataclasses.is_dataclass(callback) or callable(callback)): + raise ValueError( + f"Callbacks must be of type dataclass or callable. 
Error found on the callback {callback} of type {callback_type}" + ) + self.callbacks[callback_type].append(load_callable(callback)) + + def apply(self, trainer, callback_type: str): + + for callback in self.callbacks.get(callback_type): + callback(trainer) + + def state_dict(self): + return {"callback_manager_obj_callbacks": self.callbacks} + + def load_state_dict(self, state_dict): + self.callbacks = state_dict.get("callback_manager_obj_callbacks") diff --git a/nequip/train/callbacks/__init__.py b/nequip/train/callbacks/__init__.py new file mode 100644 index 000000000..8e7f70cf2 --- /dev/null +++ b/nequip/train/callbacks/__init__.py @@ -0,0 +1,4 @@ +from .adaptive_loss_weights import SoftAdapt +from .loss_schedule import SimpleLossSchedule + +__all__ = [SoftAdapt, SimpleLossSchedule] diff --git a/nequip/train/callbacks/adaptive_loss_weights.py b/nequip/train/callbacks/adaptive_loss_weights.py new file mode 100644 index 000000000..1e0f59476 --- /dev/null +++ b/nequip/train/callbacks/adaptive_loss_weights.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass + +from nequip.train import Trainer + +from nequip.train._key import ABBREV +import torch + +# Making this a dataclass takes care of equality operators, handing restart consistency checks + + +@dataclass +class SoftAdapt: + """Adaptively modify `loss_coeffs` through a training run using the SoftAdapt scheme (https://arxiv.org/abs/2403.18122) + + To use this in a training, set in your YAML file: + + end_of_batch_callbacks: + - !!python/object:nequip.train.callbacks.adaptive_loss_weights.SoftAdapt {"batches_per_update": 20, "beta": 1.0} + + This funny syntax tells PyYAML to construct an object of this class. + + Main hyperparameters are: + - how often the loss weights are updated, `batches_per_update` + - how sensitive the new loss weights are to the change in loss components, `beta` + """ + + # user-facing parameters + batches_per_update: int = None + beta: float = None + eps: float = 1e-8 # small epsilon to avoid division by zero + # attributes for internal tracking + batch_counter: int = -1 + prev_losses: torch.Tensor = None + cached_weights = None + + def __call__(self, trainer: Trainer): + + # --- CORRECTNESS CHECKS --- + assert self in trainer.callback_manager.callbacks["end_of_batch"] + assert self.batches_per_update >= 1 + + # track batch number + self.batch_counter += 1 + + # empty list of cached weights to store for next cycle + if self.batch_counter % self.batches_per_update == 0: + self.cached_weights = [] + + # --- MAIN LOGIC THAT RUNS EVERY EPOCH --- + + # collect loss for each training target + losses = [] + for key in trainer.loss.coeffs.keys(): + losses.append(trainer.batch_losses[f"loss_{ABBREV.get(key)}"]) + new_losses = torch.tensor(losses) + + # compute and cache new loss weights over the update cycle + if self.prev_losses is None: + self.prev_losses = new_losses + return + else: + # compute normalized loss change + loss_change = new_losses - self.prev_losses + loss_change = torch.nn.functional.normalize( + loss_change, dim=0, eps=self.eps + ) + self.prev_losses = new_losses + # compute new weights with softmax + exps = torch.exp(self.beta * loss_change) + self.cached_weights.append(exps.div(exps.sum() + self.eps)) + + # --- average weights over previous cycle and update --- + if self.batch_counter % self.batches_per_update == 1: + softadapt_weights = torch.stack(self.cached_weights, dim=-1).mean(-1) + counter = 0 + for key in trainer.loss.coeffs.keys(): + trainer.loss.coeffs.update({key: softadapt_weights[counter]}) 
+ counter += 1 diff --git a/nequip/train/callbacks/loss_schedule.py b/nequip/train/callbacks/loss_schedule.py index edd6f173a..666c1e66e 100644 --- a/nequip/train/callbacks/loss_schedule.py +++ b/nequip/train/callbacks/loss_schedule.py @@ -25,9 +25,10 @@ class SimpleLossSchedule: def __call__(self, trainer: Trainer): assert ( - self in trainer._start_of_epoch_callbacks + self in trainer.callback_manager.callbacks["start_of_epoch"] ), "must be start not end of epoch" # user-facing 1 based indexing of epochs rather than internal zero based + iepoch: int = trainer.iepoch + 1 if iepoch < 1: # initial validation epoch is 0 in user-facing indexing return diff --git a/nequip/train/loss.py b/nequip/train/loss.py index fe5144dac..9cd85d6c3 100644 --- a/nequip/train/loss.py +++ b/nequip/train/loss.py @@ -109,6 +109,15 @@ def __call__(self, pred: dict, ref: dict): return loss, contrib + def state_dict(self): + # verbose key names to avoid repetition/clashes + dictionary = {"loss_obj_coeffs": self.coeffs} + return dictionary + + def load_state_dict(self, state_dict): + # only need to save/load loss weights (or coefficients) + self.coeffs = state_dict.get("loss_obj_coeffs") + class LossStat: """ diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 2c257785c..381c5bd15 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -42,7 +42,6 @@ instantiate, save_file, load_file, - load_callable, atomic_write, finish_all_writes, atomic_write_group, @@ -56,6 +55,7 @@ from .metrics import Metrics from ._key import ABBREV, LOSS_KEY, TRAIN, VALIDATION from .early_stopping import EarlyStopping +from .callback_manager import CallbackManager class Trainer: @@ -217,7 +217,14 @@ class Trainer: """ stop_keys = ["max_epochs", "early_stopping", "early_stopping_kwargs"] - object_keys = ["lr_sched", "optim", "ema", "early_stopping_conds"] + object_keys = [ + "loss", + "callback_manager", + "lr_sched", + "optim", + "ema", + "early_stopping_conds", + ] lr_scheduler_module = torch.optim.lr_scheduler optim_module = torch.optim @@ -258,12 +265,6 @@ def __init__( train_idcs: Optional[list] = None, val_idcs: Optional[list] = None, train_val_split: str = "random", - init_callbacks: list = [], - start_of_epoch_callbacks: list = [], - end_of_epoch_callbacks: list = [], - end_of_batch_callbacks: list = [], - end_of_train_callbacks: list = [], - final_callbacks: list = [], log_batch_freq: int = 100, log_epoch_freq: int = 1, save_checkpoint_freq: int = -1, @@ -343,29 +344,18 @@ def __init__( ) self.loss_stat = LossStat(self.loss) + # initialize callback manager + self.callback_manager, _ = instantiate( + builder=CallbackManager, + prefix="callbacks", + all_args=self.kwargs, + ) + # what do we train on? 
self.train_on_keys = self.loss.keys if train_on_keys is not None: assert set(train_on_keys) == set(self.train_on_keys) - # load all callbacks - self._init_callbacks = [load_callable(callback) for callback in init_callbacks] - self._start_of_epoch_callbacks = [ - load_callable(callback) for callback in start_of_epoch_callbacks - ] - self._end_of_epoch_callbacks = [ - load_callable(callback) for callback in end_of_epoch_callbacks - ] - self._end_of_batch_callbacks = [ - load_callable(callback) for callback in end_of_batch_callbacks - ] - self._end_of_train_callbacks = [ - load_callable(callback) for callback in end_of_train_callbacks - ] - self._final_callbacks = [ - load_callable(callback) for callback in final_callbacks - ] - self.init() def init_objects(self): @@ -761,8 +751,7 @@ def train(self): if not self._initialized: self.init() - for callback in self._init_callbacks: - callback(self) + self.callback_manager.apply(self, "init") self.init_log() self.wall = perf_counter() @@ -784,8 +773,7 @@ def train(self): self.epoch_step() self.end_of_epoch_save() - for callback in self._final_callbacks: - callback(self) + self.callback_manager.apply(self, "final") self.final_log() @@ -891,8 +879,7 @@ def reset_metrics(self): self.metrics.to(self.torch_device) def epoch_step(self): - for callback in self._start_of_epoch_callbacks: - callback(self) + self.callback_manager.apply(self, "start_of_epoch") dataloaders = {TRAIN: self.dl_train, VALIDATION: self.dl_val} categories = [TRAIN, VALIDATION] if self.iepoch >= 0 else [VALIDATION] @@ -921,14 +908,12 @@ def epoch_step(self): validation=(category == VALIDATION), ) self.end_of_batch_log(batch_type=category) - for callback in self._end_of_batch_callbacks: - callback(self) + self.callback_manager.apply(self, "end_of_batch") self.metrics_dict[category] = self.metrics.current_result() self.loss_dict[category] = self.loss_stat.current_result() if category == TRAIN: - for callback in self._end_of_train_callbacks: - callback(self) + self.callback_manager.apply(self, "end_of_train") self.iepoch += 1 @@ -941,8 +926,7 @@ def epoch_step(self): if self.iepoch > 0 and self.lr_scheduler_name == "ReduceLROnPlateau": self.lr_sched.step(metrics=self.mae_dict[self.metrics_key]) - for callback in self._end_of_epoch_callbacks: - callback(self) + self.callback_manager.apply(self, "end_of_epoch") def end_of_batch_log(self, batch_type: str): """ diff --git a/nequip/utils/_global_options.py b/nequip/utils/_global_options.py index 3a08e55ec..711c663a2 100644 --- a/nequip/utils/_global_options.py +++ b/nequip/utils/_global_options.py @@ -92,9 +92,10 @@ def _set_global_options(config, warn_on_override: bool = False) -> None: # k = "PYTORCH_JIT_USE_NNC_NOT_NVFUSER" if k in os.environ: - warnings.warn( - "Do NOT manually set PYTORCH_JIT_USE_NNC_NOT_NVFUSER=0 unless you know exactly what you're doing!" - ) + if os.environ[k] != "1": + warnings.warn( + "Do NOT manually set PYTORCH_JIT_USE_NNC_NOT_NVFUSER=0 unless you know exactly what you're doing!" 
+ ) else: os.environ[k] = "1" diff --git a/tests/integration/test_train.py b/tests/integration/test_train.py index b9935b3c2..889f366fb 100644 --- a/tests/integration/test_train.py +++ b/tests/integration/test_train.py @@ -148,7 +148,6 @@ def test_requeue(nequip_dataset, BENCHMARK_ROOT, conffile): for irun in range(3): - true_config["max_epochs"] = 2 * (irun + 1) config_path = tmpdir + "/conf.yaml" with open(config_path, "w+") as fp: yaml.dump(true_config, fp) @@ -162,7 +161,13 @@ def test_requeue(nequip_dataset, BENCHMARK_ROOT, conffile): retcode = subprocess.run( # Supress the warning cause we use general config for all the fake models - ["nequip-train", "conf.yaml", "--warn-unused"], + [ + "nequip-train", + "conf.yaml", + "--warn-unused", + "--override", + f"max_epochs: {2 * (irun + 1)}", + ], cwd=tmpdir, env=env, stdout=subprocess.PIPE, @@ -178,4 +183,4 @@ def test_requeue(nequip_dataset, BENCHMARK_ROOT, conffile): dtype=None, ) - assert len(dat["epoch"]) == true_config["max_epochs"] + assert len(dat["epoch"]) == (2 * (irun + 1))
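
The new `--override` flag added to `nequip-train` above accepts a YAML string that must parse to a dict of top-level config keys; on a fresh start the dict is merged into the config before `fresh_start`, and on a restart it is merged over the dictionary loaded from `trainer.pth`, bypassing the file-wise consistency check. A minimal sketch of that parse-and-merge step, using a plain dict in place of the real `Config` object (the values here are invented):

    import yaml

    # stand-in for the trainer config; in nequip this is a Config object
    config = {"max_epochs": 100, "learning_rate": 0.005}

    # what `--override "max_epochs: 42"` hands to parse_command_line
    override_options = yaml.load("max_epochs: 42", Loader=yaml.Loader)
    assert isinstance(override_options, dict), "--override must define a dict of top-level options"

    overridden_keys = set(config).intersection(override_options)  # keys whose values change
    new_keys = set(override_options) - overridden_keys            # keys not present before

    config.update(override_options)
    print(config["max_epochs"])  # 42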
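
The per-hook callback lists on the Trainer are replaced by the single CallbackManager, which is built from the `callbacks` config key (via `instantiate` with prefix `callbacks`) and dispatched with `callback_manager.apply(self, <callback_type>)` at each hook point. The sketch below drives a manager directly, outside a real training loop, to show the dispatch; `PrintLoss` and `FakeTrainer` are hypothetical names invented for illustration, and the import assumes a nequip install that includes this patch:

    from dataclasses import dataclass

    from nequip.train.callback_manager import CallbackManager


    @dataclass
    class PrintLoss:
        """Hypothetical end-of-batch callback: report a field off the trainer."""

        prefix: str = "batch losses:"

        def __call__(self, trainer):
            print(self.prefix, trainer.batch_losses)


    class FakeTrainer:
        """Stand-in with just the attribute the callback above reads."""

        batch_losses = {"loss": 0.1}


    manager = CallbackManager(callbacks={"end_of_batch": [PrintLoss()]})
    manager.apply(FakeTrainer(), "end_of_batch")  # prints: batch losses: {'loss': 0.1}

Writing callbacks as dataclasses (rather than bare functions) is what gives them the value-based equality the restart consistency checks rely on, as noted in the comment above SoftAdapt.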
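
The arithmetic inside `SoftAdapt.__call__` mirrors the operations in the patch: take the change in each loss component since the previous batch, L2-normalize it, apply a softmax with temperature `beta` so that components whose loss is growing (or shrinking more slowly) receive larger weights, then average the per-batch weights over an update cycle before writing them into `trainer.loss.coeffs`. The same steps in isolation, on invented example values for a two-term (forces, total_energy) loss:

    import torch

    beta, eps = 1.1, 1e-8

    prev_losses = torch.tensor([0.50, 0.20])  # forces, total_energy at the last batch
    new_losses = torch.tensor([0.45, 0.25])   # forces improved, energy got worse

    # normalized change in each loss component, as in SoftAdapt.__call__
    loss_change = torch.nn.functional.normalize(new_losses - prev_losses, dim=0, eps=eps)

    # temperature-scaled softmax: the growing energy term gets the larger weight
    exps = torch.exp(beta * loss_change)
    weights = exps / (exps.sum() + eps)

    # over a cycle, the per-batch weights are stacked and averaged before updating loss.coeffs
    cached_weights = [weights, weights]  # pretend two batches of the cycle agreed
    softadapt_weights = torch.stack(cached_weights, dim=-1).mean(-1)
    print(softadapt_weights)  # ~tensor([0.17, 0.83]); the weights sum to 1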