Skip to content

Commit

Permalink
streamline callback_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
cw-tan committed Jul 9, 2024
1 parent 3e64aa0 commit 9f8fdf4
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 32 deletions.
12 changes: 7 additions & 5 deletions configs/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,16 @@ loss_coeffs:
# In the "schedule" key each entry is a two-element list of:
# - the 1-based epoch index at which to start the new loss coefficients
# - the new loss coefficients as a dict
#
# start_of_epoch_callbacks:
# - !!python/object:nequip.train.callbacks.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
# callbacks:
# start_of_epoch:
# - !!python/object:nequip.train.callbacks.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}

# You can also try using the SoftAdapt strategy for adaptively changing loss coefficients
# (see https://arxiv.org/abs/2403.18122)
# end_of_batch_callbacks:
# - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 5, "beta": 1.0}
#callbacks:
# end_of_batch:
# - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 5, "beta": 1.1}



# output metrics
Expand Down
51 changes: 24 additions & 27 deletions nequip/train/callback_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,32 @@ class CallbackManager:

def __init__(
self,
init_callbacks=[],
start_of_epoch_callbacks=[],
end_of_epoch_callbacks=[],
end_of_batch_callbacks=[],
end_of_train_callbacks=[],
final_callbacks=[],
callbacks={},
):

CALLBACK_TYPES = [
"init",
"start_of_epoch",
"end_of_epoch",
"end_of_batch",
"end_of_train",
"final",
]
# load all callbacks
self.callbacks = {
"init": [load_callable(callback) for callback in init_callbacks],
"start_of_epoch": [
load_callable(callback) for callback in start_of_epoch_callbacks
],
"end_of_epoch": [
load_callable(callback) for callback in end_of_epoch_callbacks
],
"end_of_batch": [
load_callable(callback) for callback in end_of_batch_callbacks
],
"end_of_train": [
load_callable(callback) for callback in end_of_train_callbacks
],
"final": [load_callable(callback) for callback in final_callbacks],
}

for callback_type in self.callbacks:
for callback in self.callbacks.get(callback_type):
assert dataclasses.is_dataclass(callback) or callable(callback)
self.callbacks = {callback_type: [] for callback_type in CALLBACK_TYPES}

for callback_type in callbacks:
if callback_type not in CALLBACK_TYPES:
raise ValueError(
f"{callback_type} is not a supported callback type.\nSupported callback types include "
+ str(CALLBACK_TYPES)
)
# make sure callbacks are either dataclasses or functions
for callback in callbacks[callback_type]:
if not (dataclasses.is_dataclass(callback) or callable(callback)):
raise ValueError(
f"Callbacks must be of type dataclass or callable. Error found on the callback {callback} of type {callback_type}"
)
self.callbacks[callback_type].append(load_callable(callback))

def apply(self, trainer, callback_type: str):

Expand Down
1 change: 1 addition & 0 deletions nequip/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def __init__(
# initialize callback manager
self.callback_manager, _ = instantiate(
builder=CallbackManager,
prefix="callbacks",
all_args=self.kwargs,
)

Expand Down

0 comments on commit 9f8fdf4

Please sign in to comment.