@@ -34,6 +34,7 @@ class CheckpointManager:
34
34
def __init__ (self ):
35
35
self .model_name = None
36
36
self .checkpoint_path = None
37
+ self .best_metric = None
37
38
38
39
def set_model_name (self , model_name ):
39
40
self .model_name = model_name
@@ -69,16 +70,33 @@ def init_checkpoint_dir(self):
69
70
print (f'checkpoint path : { self .checkpoint_path } ' )
70
71
71
72
def remove_last_model (self ):
72
- for last_model_path in glob (f'{ self .checkpoint_path } /model_*_iter .h5' ):
73
+ for last_model_path in glob (f'{ self .checkpoint_path } /last_* .h5' ):
73
74
os .remove (last_model_path )
74
75
75
- def save_last_model (self , model , iteration_count ):
76
+ def save_last_model (self , model , iteration_count , content = '' ):
76
77
self .make_checkpoint_dir ()
77
- save_path = f'{ self .checkpoint_path } /model_ { iteration_count } _iter.h5'
78
+ save_path = f'{ self .checkpoint_path } /last_ { iteration_count } _iter{ content } .h5'
78
79
model .save (save_path , include_optimizer = False )
79
80
backup_path = f'{ save_path } .bak'
80
81
sh .move (save_path , backup_path )
81
82
self .remove_last_model ()
82
83
sh .move (backup_path , save_path )
83
84
return save_path
84
85
86
+ def remove_best_model (self ):
87
+ for best_model_path in glob (f'{ self .checkpoint_path } /best_*.h5' ):
88
+ os .remove (best_model_path )
89
+
90
+ def save_best_model (self , model , iteration_count , metric , content = '' ):
91
+ save_path = None
92
+ if self .best_metric is None or metric > self .best_metric :
93
+ self .best_metric = metric
94
+ self .make_checkpoint_dir ()
95
+ save_path = f'{ self .checkpoint_path } /best_{ iteration_count } _iter{ content } .h5'
96
+ model .save (save_path , include_optimizer = False )
97
+ backup_path = f'{ save_path } .bak'
98
+ sh .move (save_path , backup_path )
99
+ self .remove_best_model ()
100
+ sh .move (backup_path , save_path )
101
+ return save_path
102
+
0 commit comments