@@ -500,6 +500,15 @@ def train_custom(
         else:
             self.optimizer = optimizer(params=self.model.parameters(), **kwargs)
 
+        # load optimizer state if it exists
+        optimizer_state_loaded = False
+        if hasattr(self.model, "optimizer_state_dict") and self.model.optimizer_state_dict is not None:
+            try:
+                self.optimizer.load_state_dict(self.model.optimizer_state_dict)
+                optimizer_state_loaded = True
+            except Exception as e:
+ log .warning (f"Found saved optimizer state from previous training but coult not load: { e } " )
511
+
         # initialize sampler if provided
         if sampler is not None:
             # init with default values if only class is provided
@@ -561,13 +570,17 @@ def train_custom(
         log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})")
         log_line(log)
         log.info("Training Params:")
+        log.info(f' - optimizer: "{optimizer}" ')
         log.info(
             f' - learning_rate: "{learning_rate}" '
             f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}'
         )
         log.info(f' - mini_batch_size: "{mini_batch_size}"')
         log.info(f' - max_epochs: "{max_epochs}"')
         log.info(f' - shuffle: "{shuffle}"')
+        if optimizer_state_loaded:
+            log_line(log)
+            log.info("Optimizer state loaded from previous training!")
         log_line(log)
         log.info("Plugins:")
         for plugin in plugins:
@@ -813,14 +826,14 @@ def wrapped_forward_loss(*args, **kwargs2):
 
                 if save_best_model and current_epoch_has_best_model_so_far:
                     log.info("saving best model")
-                    self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state)
+                    self._save_model(base_path / "best-model.pt", save_optimizer_state=save_optimizer_state)
 
             # - SWAPlugin -> restores SGD weights from SWA
             self.dispatch("after_training_loop")
 
             # if we do not use dev data for model selection, save final model
             if save_final_model:
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", save_optimizer_state=save_optimizer_state)
 
         except KeyboardInterrupt:
             log_line(log)
@@ -830,7 +843,7 @@ def wrapped_forward_loss(*args, **kwargs2):
 
             if save_final_model:
                 log.info("Saving model ...")
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", save_optimizer_state=save_optimizer_state)
                 log.info("Done.")
 
         except TrainingInterrupt as exc:
@@ -841,7 +854,7 @@ def wrapped_forward_loss(*args, **kwargs2):
 
             if save_final_model:
                 log.info("Saving model ...")
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", save_optimizer_state=save_optimizer_state)
                 log.info("Done.")
 
         except Exception:
@@ -989,9 +1002,19 @@ def _record(self, metric):
     def _load_model(self, model_file: Union[str, Path]) -> None:
         self.model.load_state_dict(self.model.load(model_file).state_dict())
 
-    def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
+    def _save_model(self, model_file: Union[str, Path], save_optimizer_state: bool = False) -> None:
         if is_main_process():
-            self.model.save(model_file, checkpoint)
+            if save_optimizer_state:
+                # Save optimizer state
+                self.model.optimizer_state_dict = self.optimizer.state_dict()
+
+                # Save scheduler state from active plugins
+                for plugin in self.plugins:
+                    if hasattr(plugin, "scheduler"):
+                        self.model.scheduler_state_dict = plugin.scheduler.state_dict()
+                        break  # Only save the first scheduler we find
+
+            self.model.save(model_file)
         if torch.distributed.is_initialized():
             torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
 
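Taken together, the diff makes `save_optimizer_state=True` persist the optimizer (and the first plugin scheduler) state inside the checkpoint, which a later training run on the reloaded model picks up through the new branch at the top of `train_custom`. Below is a minimal round-trip sketch, assuming Flair's public `ModelTrainer`/`SequenceTagger` API and that `train()` forwards `save_optimizer_state` to `train_custom`; the corpus choice, paths, and hyperparameters are illustrative only.

```python
from flair.datasets import UD_ENGLISH
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# small corpus and model just for the demo
corpus = UD_ENGLISH().downsample(0.01)
label_dict = corpus.make_label_dictionary(label_type="upos")
tagger = SequenceTagger(
    hidden_size=128,
    embeddings=WordEmbeddings("glove"),
    tag_dictionary=label_dict,
    tag_type="upos",
)

# first run: with save_optimizer_state=True, _save_model attaches the optimizer
# (and first scheduler) state to the model before writing the checkpoint
trainer = ModelTrainer(tagger, corpus)
trainer.train("resources/upos-run1", max_epochs=2, save_optimizer_state=True)

# resumed run: the reloaded model carries optimizer_state_dict, so train_custom
# restores it into the freshly created optimizer and logs
# "Optimizer state loaded from previous training!"
tagger = SequenceTagger.load("resources/upos-run1/final-model.pt")
trainer = ModelTrainer(tagger, corpus)
trainer.train("resources/upos-run1-continued", max_epochs=2)
```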