step_checkpoint.py
import os

import pytorch_lightning as pl


class StepCheckpoint(pl.Callback):
    """
    Save a checkpoint every N training steps, instead of Lightning's default
    behavior of checkpointing based on validation loss.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="step_checkpoint",
        use_modelcheckpoint_filename=False,
    ):
        """
        Args:
            save_step_frequency: how often to save, in training steps
            prefix: prefix added to the checkpoint filename; only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: use the ModelCheckpoint callback's
                default filename instead of ours
        """
        super().__init__()
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename
    def on_train_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx):
        """Check after every train batch whether a checkpoint should be saved.

        Note: this was originally written against the old ``on_batch_end``
        hook, which newer Lightning releases removed in favor of
        ``on_train_batch_end``.
        """
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        if global_step % self.save_step_frequency == 0:
            if self.use_modelcheckpoint_filename:
                filename = trainer.checkpoint_callback.filename
            else:
                # e.g. "step_checkpoint_epoch=0_global_step=100.ckpt"
                filename = f"{self.prefix}_{epoch=}_{global_step=}.ckpt"
            # Save alongside ModelCheckpoint's checkpoints; assumes a
            # checkpoint callback (Lightning adds one by default) is configured.
            ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)
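

# --------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file): wire
# the callback into a Trainer via its `callbacks` argument. `BoringModel`
# and the synthetic dataset below are hypothetical stand-ins for a real
# LightningModule and DataLoader.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class BoringModel(pl.LightningModule):
        """Tiny regression model, just enough to exercise the callback."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

    dataset = TensorDataset(torch.randn(256, 4), torch.randn(256, 1))
    trainer = pl.Trainer(
        max_epochs=2,
        # Save a checkpoint every 10 optimizer steps.
        callbacks=[StepCheckpoint(save_step_frequency=10)],
    )
    trainer.fit(BoringModel(), DataLoader(dataset, batch_size=32))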