
Commit 4571b4c

Merge pull request #3640 from flairNLP/GH-3444-save-optimizer
Add option to save optimizer and scheduler state during training, and to resume training from these states
2 parents bbb7b66 + 0c733f0 commit 4571b4c
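
In practice the new option is meant to be used roughly as sketched below (a hedged sketch, not from the commit: the corpus, model setup and output paths are placeholders, and it assumes the save_optimizer_state flag is exposed through ModelTrainer.fine_tune/train just as it is through train_custom in the diff further down).

    from flair.datasets import TREC_6
    from flair.embeddings import TransformerDocumentEmbeddings
    from flair.models import TextClassifier
    from flair.trainers import ModelTrainer

    # first run: embed optimizer + scheduler state into the saved model files
    corpus = TREC_6()
    label_dict = corpus.make_label_dictionary(label_type="question_class")
    classifier = TextClassifier(
        TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True),
        label_dictionary=label_dict,
        label_type="question_class",
    )
    trainer = ModelTrainer(classifier, corpus)
    trainer.fine_tune("resources/run-1", max_epochs=2, save_optimizer_state=True)

    # later: resume from the saved checkpoint; the trainer picks up the stored
    # optimizer_state_dict / scheduler_state_dict from the loaded model
    resumed = TextClassifier.load("resources/run-1/final-model.pt")
    trainer = ModelTrainer(resumed, corpus)
    trainer.fine_tune("resources/run-2", max_epochs=2, save_optimizer_state=True)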

File tree

5 files changed: +96 -10 lines changed

  flair/nn/model.py
  flair/trainers/plugins/functional/anneal_on_plateau.py
  flair/trainers/plugins/functional/checkpoints.py
  flair/trainers/plugins/functional/linear_scheduler.py
  flair/trainers/trainer.py

flair/nn/model.py

+38 -2

@@ -32,7 +32,17 @@ class Model(torch.nn.Module, typing.Generic[DT], ABC):
     Every new type of model must implement these methods.
     """
 
-    model_card: Optional[dict[str, Any]] = None
+    def __init__(self) -> None:
+        super().__init__()
+
+        # The model card can contain training parameters and metadata
+        self.model_card: Optional[dict[str, Any]] = None
+
+        # Optimizer and scheduler states are only set during training when save_optimizer_state=True
+        # is passed to the ModelTrainer. These states allow resuming training from a checkpoint
+        # with the exact same optimizer and learning rate scheduler states.
+        self.optimizer_state_dict: Optional[dict[str, Any]] = None
+        self.scheduler_state_dict: Optional[dict[str, Any]] = None
 
     @property
     @abstractmethod
@@ -86,10 +96,26 @@ def evaluate(
         raise NotImplementedError
 
     def _get_state_dict(self) -> dict:
-        """Returns the state dictionary for this model."""
+        """Returns the state dictionary for this model.
+
+        The state dictionary contains:
+        - "state_dict": The model's parameters state dictionary
+        - "__cls__": The class name of the model for loading
+        - "optimizer_state_dict": The optimizer's state dictionary (if it exists)
+        - "scheduler_state_dict": The scheduler's state dictionary (if it exists)
+        - "model_card": Training parameters and metadata (if set)
+        """
         # Always include the name of the Model class for which the state dict holds
         state_dict = {"state_dict": self.state_dict(), "__cls__": self.__class__.__name__}
 
+        # Add optimizer state dict if it exists
+        if hasattr(self, "optimizer_state_dict") and self.optimizer_state_dict is not None:
+            state_dict["optimizer_state_dict"] = self.optimizer_state_dict
+
+        # Add scheduler state dict if it exists
+        if hasattr(self, "scheduler_state_dict") and self.scheduler_state_dict is not None:
+            state_dict["scheduler_state_dict"] = self.scheduler_state_dict
+
         return state_dict
 
     @classmethod
@@ -105,6 +131,16 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
 
         model.load_state_dict(state["state_dict"])
 
+        # load optimizer state if it exists in the state dict
+        if "optimizer_state_dict" in state:
+            log.debug(f"Found optimizer state in model file with keys: {state['optimizer_state_dict'].keys()}")
+            model.optimizer_state_dict = state["optimizer_state_dict"]
+
+        # load scheduler state if it exists in the state dict
+        if "scheduler_state_dict" in state:
+            log.debug(f"Found scheduler state in model file with keys: {state['scheduler_state_dict'].keys()}")
+            model.scheduler_state_dict = state["scheduler_state_dict"]
+
         return model
 
     @staticmethod
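
A small sketch (not from the commit) of what the extended load path gives you: a checkpoint written with save_optimizer_state=True repopulates the two new attributes when the model is loaded again. The model class and path below are placeholders.

    from flair.models import SequenceTagger

    tagger = SequenceTagger.load("resources/run-1/final-model.pt")  # hypothetical checkpoint

    # populated by _init_model_with_state_dict if the file was saved with optimizer state
    if tagger.optimizer_state_dict is not None:
        print("optimizer state entries:", list(tagger.optimizer_state_dict.keys()))
    if tagger.scheduler_state_dict is not None:
        print("scheduler state entries:", list(tagger.scheduler_state_dict.keys()))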

flair/trainers/plugins/functional/anneal_on_plateau.py

+14 -1

@@ -58,21 +58,34 @@ def after_setup(
         anneal_mode = "min" if train_with_dev else "max"
 
         # instantiate the scheduler
-        self.scheduler: AnnealOnPlateau = AnnealOnPlateau(
+        self.scheduler = AnnealOnPlateau(
             factor=self.anneal_factor,
             patience=self.patience,
             initial_extra_patience=self.initial_extra_patience,
             mode=anneal_mode,
             optimizer=self.trainer.optimizer,
         )
 
+        # Load scheduler state if it exists
+        if hasattr(self.trainer.model, "scheduler_state_dict") and self.trainer.model.scheduler_state_dict is not None:
+            try:
+                log.info("Found saved scheduler state, loading it...")
+                self.scheduler.load_state_dict(self.trainer.model.scheduler_state_dict)
+                log.info("Scheduler state loaded successfully!")
+            except Exception as e:
+                log.warning(f"Could not load scheduler state: {e}")
+
         self.store_learning_rate()
 
     @TrainerPlugin.hook
     def after_evaluation(self, current_model_is_best, validation_scores, **kw):
         """Scheduler step of AnnealOnPlateau."""
         reduced_learning_rate: bool = self.scheduler.step(*validation_scores)
 
+        # Save scheduler state after step
+        if hasattr(self.trainer.model, "save_optimizer_state") and self.trainer.model.save_optimizer_state:
+            self.trainer.model.scheduler_state_dict = self.scheduler.state_dict()
+
         self.store_learning_rate()
 
         bad_epochs = self.scheduler.num_bad_epochs
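
For orientation (not part of the commit): AnnealOnPlateau is flair's own scheduler, but the diff relies on the standard state_dict/load_state_dict contract, illustrated here with torch's analogous ReduceLROnPlateau. Restoring this state keeps the patience counter and best-so-far score, so a resumed run does not silently restart the annealing schedule.

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=3)

    scheduler.step(0.80)  # first metric becomes the best-so-far score
    scheduler.step(0.79)  # worse than best -> one "bad epoch" is counted
    saved = scheduler.state_dict()  # contains num_bad_epochs, best, cooldown_counter, ...

    fresh = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=3)
    fresh.load_state_dict(saved)
    print(fresh.num_bad_epochs, fresh.best)  # counters carried over instead of reset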

flair/trainers/plugins/functional/checkpoints.py

+3 -1

@@ -27,7 +27,9 @@ def after_training_epoch(self, epoch, **kw):
                 f"was set"
             )
             model_name = "model_epoch_" + str(epoch) + ".pt"
-            self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
+
+            # Use trainer's _save_model method - we have access to trainer through self.trainer
+            self.trainer._save_model(self.base_path / model_name, save_optimizer_state=self.save_optimizer_state)
 
     @property
     def attach_to_all_processes(self) -> bool:
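
Given a ModelTrainer set up as in the first sketch near the top of this page (all names and paths remain placeholders), the periodic epoch checkpoints written by this plugin now also carry the optimizer/scheduler state:

    trainer.train(
        "resources/run-1",
        max_epochs=10,
        save_model_each_k_epochs=2,   # CheckpointPlugin writes model_epoch_2.pt, model_epoch_4.pt, ...
        save_optimizer_state=True,    # each of those checkpoints now embeds optimizer + scheduler state
    )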

flair/trainers/plugins/functional/linear_scheduler.py

+12 -0

@@ -45,6 +45,15 @@ def after_setup(
             num_train_steps=num_train_steps, num_warmup_steps=num_warmup_steps, optimizer=self.trainer.optimizer
         )
 
+        # Load scheduler state if it exists
+        if hasattr(self.trainer.model, "scheduler_state_dict") and self.trainer.model.scheduler_state_dict is not None:
+            try:
+                log.info("Found saved scheduler state, loading it...")
+                self.scheduler.load_state_dict(self.trainer.model.scheduler_state_dict)
+                log.info("Scheduler state loaded successfully!")
+            except Exception as e:
+                log.warning(f"Could not load scheduler state: {e}")
+
         self.store_learning_rate()
 
     @TrainerPlugin.hook
@@ -60,6 +69,9 @@ def after_training_batch(self, optimizer_was_run: bool, **kwargs):
         if not optimizer_was_run:
             return
         self.scheduler.step()
+        # Save scheduler state after step
+        if hasattr(self.trainer.model, "save_optimizer_state") and self.trainer.model.save_optimizer_state:
+            self.trainer.model.scheduler_state_dict = self.scheduler.state_dict()
         self.store_learning_rate()
 
     def __str__(self) -> str:
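
Why the state is saved after every batch: this scheduler steps per batch, and its state records the step counter, so restoring it lets a resumed run continue the warmup/decay curve at the right position instead of from step 0. A stand-in sketch with a plain PyTorch LambdaLR (flair's warmup scheduler is assumed to behave analogously; the numbers are illustrative):

    import torch

    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
    warmup, total = 100, 1000
    lr_lambda = lambda step: min(step / warmup, max(0.0, (total - step) / (total - warmup)))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    for _ in range(250):
        optimizer.step()
        scheduler.step()

    saved = scheduler.state_dict()          # includes last_epoch == 250 (the step counter)
    resumed = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    resumed.load_state_dict(saved)          # continues at step 250, not step 0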

flair/trainers/trainer.py

+29 -6

@@ -500,6 +500,15 @@ def train_custom(
         else:
             self.optimizer = optimizer(params=self.model.parameters(), **kwargs)
 
+        # load optimizer state if it exists
+        optimizer_state_loaded = False
+        if hasattr(self.model, "optimizer_state_dict") and self.model.optimizer_state_dict is not None:
+            try:
+                self.optimizer.load_state_dict(self.model.optimizer_state_dict)
+                optimizer_state_loaded = True
+            except Exception as e:
+                log.warning(f"Found saved optimizer state from previous training but could not load it: {e}")
+
         # initialize sampler if provided
         if sampler is not None:
             # init with default values if only class is provided
@@ -561,13 +570,17 @@ def train_custom(
         log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})")
         log_line(log)
         log.info("Training Params:")
+        log.info(f' - optimizer: "{optimizer}" ')
         log.info(
             f' - learning_rate: "{learning_rate}" '
             f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}'
         )
         log.info(f' - mini_batch_size: "{mini_batch_size}"')
         log.info(f' - max_epochs: "{max_epochs}"')
         log.info(f' - shuffle: "{shuffle}"')
+        if optimizer_state_loaded:
+            log_line(log)
+            log.info("Optimizer state loaded from previous training!")
         log_line(log)
         log.info("Plugins:")
         for plugin in plugins:
@@ -813,14 +826,14 @@ def wrapped_forward_loss(*args, **kwargs2):
 
                 if save_best_model and current_epoch_has_best_model_so_far:
                     log.info("saving best model")
-                    self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state)
+                    self._save_model(base_path / "best-model.pt", save_optimizer_state=save_optimizer_state)
 
             # - SWAPlugin -> restores SGD weights from SWA
             self.dispatch("after_training_loop")
 
             # if we do not use dev data for model selection, save final model
             if save_final_model:
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", save_optimizer_state=save_optimizer_state)
 
         except KeyboardInterrupt:
             log_line(log)
@@ -830,7 +843,7 @@ def wrapped_forward_loss(*args, **kwargs2):
 
             if save_final_model:
                 log.info("Saving model ...")
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", save_optimizer_state=save_optimizer_state)
                 log.info("Done.")
 
         except TrainingInterrupt as exc:
@@ -841,7 +854,7 @@ def wrapped_forward_loss(*args, **kwargs2):
 
             if save_final_model:
                 log.info("Saving model ...")
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", save_optimizer_state=save_optimizer_state)
                 log.info("Done.")
 
         except Exception:
@@ -989,9 +1002,19 @@ def _record(self, metric):
     def _load_model(self, model_file: Union[str, Path]) -> None:
         self.model.load_state_dict(self.model.load(model_file).state_dict())
 
-    def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
+    def _save_model(self, model_file: Union[str, Path], save_optimizer_state: bool = False) -> None:
         if is_main_process():
-            self.model.save(model_file, checkpoint)
+            if save_optimizer_state:
+                # Save optimizer state
+                self.model.optimizer_state_dict = self.optimizer.state_dict()
+
+                # Save scheduler state from active plugins
+                for plugin in self.plugins:
+                    if hasattr(plugin, "scheduler"):
+                        self.model.scheduler_state_dict = plugin.scheduler.state_dict()
+                        break  # Only save the first scheduler we find
+
+            self.model.save(model_file)
         if torch.distributed.is_initialized():
            torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
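
Since Model.save() serializes the payload produced by _get_state_dict(), a checkpoint written through the new _save_model path can be inspected directly; a hedged sketch (the path is a placeholder):

    import torch

    # flair checkpoints are pickled objects; on torch >= 2.6 you may need weights_only=False
    state = torch.load("resources/run-1/final-model.pt", map_location="cpu")
    print(sorted(state.keys()))
    # with save_optimizer_state=True this should include '__cls__', 'state_dict',
    # 'optimizer_state_dict' and 'scheduler_state_dict' (plus 'model_card' if set)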
