Commit 583d5c9

Improve trainer status logging
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 246dc21 commit 583d5c9

2 files changed (+46, -18 lines)

src/speculators/train/checkpointer.py

Lines changed: 7 additions & 4 deletions
@@ -28,11 +28,14 @@ class BaseCheckpointer:
     ...
     """
 
-    def __init__(self, path: Path | str, try_load_last_checkpoint: bool = True):
+    def __init__(self, path: Path | str):
         self.path = Path(path)
-        self.previous_epoch = (
-            self._get_previous_epoch() if try_load_last_checkpoint else -1
-        )
+        self.previous_epoch = self._get_previous_epoch()
+
+        if self.previous_epoch != -1:
+            self.prev_path: Path | None = self.path / str(self.previous_epoch)
+        else:
+            self.prev_path = None
 
     @abstractmethod
     def load_model_state_dict(self, model: PreTrainedModel):
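
With `try_load_last_checkpoint` removed, `BaseCheckpointer` now always probes the save directory: `previous_epoch` is `-1` when nothing is found, and the new `prev_path` points at the latest checkpoint directory otherwise. The commit does not show `_get_previous_epoch`; a minimal sketch of what such a lookup could look like, assuming checkpoints live in integer-named subdirectories of `path` (this helper is an assumption, not the project's actual implementation):

from pathlib import Path

def _get_previous_epoch(path: Path) -> int:
    # Hypothetical sketch: return the highest integer-named checkpoint
    # directory under `path`, or -1 when none exist.
    if not path.exists():
        return -1
    epochs = [int(p.name) for p in path.iterdir() if p.is_dir() and p.name.isdigit()]
    return max(epochs, default=-1)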

src/speculators/train/trainer.py

Lines changed: 39 additions & 14 deletions
@@ -1,10 +1,12 @@
 import logging
+import warnings
 from typing import NamedTuple
 
 import torch
 import torch.distributed as dist
 from torch.distributed.fsdp import FSDPModule
 from torch.utils.data import DataLoader
+from tqdm import TqdmExperimentalWarning
 from tqdm.rich import tqdm
 from transformers import PreTrainedModel

@@ -17,6 +19,8 @@
 root_logger = logging.getLogger("speculators")
 metric_logger = logging.getLogger("speculators.metrics")
 
+warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
+
 
 class TrainerConfig(NamedTuple):
     lr: float
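
The `tqdm.rich` progress bar warns that its rich backend is experimental every time a bar is constructed; filtering `TqdmExperimentalWarning` at module level keeps that noise out of the training output. A minimal standalone demonstration, assuming `tqdm` and `rich` are installed:

import warnings

from tqdm import TqdmExperimentalWarning
from tqdm.rich import tqdm

# Without this filter, constructing a tqdm.rich bar emits a
# TqdmExperimentalWarning on every call.
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)

for _ in tqdm(range(3), desc="demo"):
    pass
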
@@ -47,19 +51,27 @@ def __init__(
         checkpointer_class = (
             DistributedCheckpointer if self.is_distributed else SingleGPUCheckpointer
         )
-        self.checkpointer = checkpointer_class(
-            self.config.save_path,
-            try_load_last_checkpoint=self.config.resume_from_checkpoint,
-        )
+        self.checkpointer = checkpointer_class(self.config.save_path)
 
         self.setup_trainer()
         self.setup_model()
         self.setup_optimizer()
 
     def setup_trainer(self):
-        if self.resume_from_checkpoint:
+        if self.checkpointer.previous_epoch != -1:
+            root_logger.info(f"Found checkpoint at {self.checkpointer.prev_path}.")
             self.current_epoch = self.checkpointer.previous_epoch + 1
+            if self.resume_from_checkpoint:
+                root_logger.info(f"Resuming training on {self.current_epoch} epoch.")
+            else:
+                root_logger.warning(
+                    "`resume_from_checkpoint` is False, starting "
+                    "training from scratch. This will overwrite the "
+                    f"existing checkpoints in {self.checkpointer.path}."
+                )
+                self.current_epoch = 0
         else:
+            root_logger.info("No previous checkpoint found. Starting from scratch.")
             self.current_epoch = 0
         self.global_step = 0
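
The new `setup_trainer` distinguishes three cases: a checkpoint exists and `resume_from_checkpoint` is set (continue from the next epoch), a checkpoint exists but resuming is disabled (warn that it will be overwritten, start at epoch 0), and no checkpoint at all (start at epoch 0). A hypothetical standalone restatement of just the epoch-selection logic, for illustration only:

def resolve_start_epoch(previous_epoch: int, resume_from_checkpoint: bool) -> int:
    # previous_epoch == -1 means the checkpointer found nothing to load.
    if previous_epoch != -1 and resume_from_checkpoint:
        return previous_epoch + 1  # continue after the last saved epoch
    return 0  # fresh run; any existing checkpoints will be overwritten

assert resolve_start_epoch(4, True) == 5
assert resolve_start_epoch(4, False) == 0
assert resolve_start_epoch(-1, True) == 0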

@@ -99,7 +111,6 @@ def train_epoch(self, epoch: int):
         train_loader = self.train_loader
         if self.local_rank == 0:
             train_loader = tqdm(train_loader, desc=f"Epoch {epoch}")  # type: ignore[assignment]
-            root_logger.info(f"Training Epoch {epoch} started")
 
         for batch in train_loader:
             gpu_batch = {
@@ -131,17 +142,13 @@ def train_epoch(self, epoch: int):
             )
             self.global_step += 1
 
-        root_logger.info(f"Training Epoch {epoch} completed")
-
     @torch.no_grad()
     def val_epoch(self, epoch: int):
         if self.val_loader is None:
-            root_logger.warning("No val loader, skipping validation")
             return
         self.model.eval()
         if hasattr(self.val_loader.batch_sampler, "set_epoch"):
             self.val_loader.batch_sampler.set_epoch(epoch)  # type: ignore[union-attr]
-        root_logger.info(f"Validation Epoch {epoch} started")
         val_loader = self.val_loader
         if self.local_rank == 0:
             val_loader = tqdm(val_loader, desc=f"Epoch {epoch}")  # type: ignore[assignment]
@@ -176,16 +183,34 @@ def val_epoch(self, epoch: int):
             {"val": {"loss_epoch": val_loss.item(), **acc_values}, "epoch": epoch},
             extra={"step": self.global_step},
         )
-        root_logger.info(f"Validation Epoch {epoch} completed")
 
     def save_checkpoint(self, epoch: int):
         self.checkpointer.save_checkpoint(self.model, self.opt, epoch)
-        root_logger.info(f"Checkpoint saved to {self.checkpointer.path / str(epoch)}")
 
     def run_training(self):
-        for epoch in range(self.current_epoch, self.config.num_epochs):
+        n_epochs = self.config.num_epochs
+        for epoch in range(self.current_epoch, n_epochs):
+            root_logger.info(f"Training epoch {epoch + 1}/{n_epochs} started")
             self.train_epoch(epoch)
+            root_logger.info(f"Training epoch {epoch + 1}/{n_epochs} completed")
+
             if self.is_distributed:
                 dist.barrier()
-            self.val_epoch(epoch)
+
+            if self.val_loader is None:
+                root_logger.warning("No val loader, skipping validation epoch")
+            else:
+                root_logger.info(f"Validation epoch {epoch + 1}/{n_epochs} started")
+                self.val_epoch(epoch)
+                root_logger.info(f"Validation epoch {epoch + 1}/{n_epochs} completed")
+
+            if self.is_distributed:
+                dist.barrier()
+
+            root_logger.info(
+                f"Started saving checkpoint to {self.checkpointer.path / str(epoch)}"
+            )
             self.save_checkpoint(epoch)
+            root_logger.info(
+                f"Finished saving checkpoint to {self.checkpointer.path / str(epoch)}"
+            )
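
All of the new status messages go through the `speculators` root logger at INFO or WARNING level, so they only show up if the application attaches a handler. A minimal sketch of one way a caller could surface them; this handler setup is an assumption, not part of the commit:

import logging

# Attach a stream handler so the "speculators" logger's INFO-level
# status messages (epoch start/end, checkpoint saves) become visible.
logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logging.getLogger("speculators").setLevel(logging.INFO)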
