@@ -1,10 +1,12 @@
 import logging
+import warnings
 from typing import NamedTuple

 import torch
 import torch.distributed as dist
 from torch.distributed.fsdp import FSDPModule
 from torch.utils.data import DataLoader
+from tqdm import TqdmExperimentalWarning
 from tqdm.rich import tqdm
 from transformers import PreTrainedModel

@@ -17,6 +19,8 @@
 root_logger = logging.getLogger("speculators")
 metric_logger = logging.getLogger("speculators.metrics")

+warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
+

 class TrainerConfig(NamedTuple):
     lr: float
@@ -47,19 +51,27 @@ def __init__(
         checkpointer_class = (
             DistributedCheckpointer if self.is_distributed else SingleGPUCheckpointer
         )
-        self.checkpointer = checkpointer_class(
-            self.config.save_path,
-            try_load_last_checkpoint=self.config.resume_from_checkpoint,
-        )
+        self.checkpointer = checkpointer_class(self.config.save_path)

         self.setup_trainer()
         self.setup_model()
         self.setup_optimizer()

     def setup_trainer(self):
-        if self.resume_from_checkpoint:
+        if self.checkpointer.previous_epoch != -1:
+            root_logger.info(f"Found checkpoint at {self.checkpointer.prev_path}.")
             self.current_epoch = self.checkpointer.previous_epoch + 1
+            if self.resume_from_checkpoint:
+                root_logger.info(f"Resuming training from epoch {self.current_epoch}.")
+            else:
+                root_logger.warning(
+                    "`resume_from_checkpoint` is False, starting "
+                    "training from scratch. This will overwrite the "
+                    f"existing checkpoints in {self.checkpointer.path}."
+                )
+                self.current_epoch = 0
         else:
+            root_logger.info("No previous checkpoint found. Starting from scratch.")
             self.current_epoch = 0
         self.global_step = 0

@@ -99,7 +111,6 @@ def train_epoch(self, epoch: int):
         train_loader = self.train_loader
         if self.local_rank == 0:
             train_loader = tqdm(train_loader, desc=f"Epoch {epoch}")  # type: ignore[assignment]
-        root_logger.info(f"Training Epoch {epoch} started")

         for batch in train_loader:
             gpu_batch = {
@@ -131,17 +142,13 @@ def train_epoch(self, epoch: int):
            )
            self.global_step += 1

-        root_logger.info(f"Training Epoch {epoch} completed")
-
     @torch.no_grad()
     def val_epoch(self, epoch: int):
         if self.val_loader is None:
-            root_logger.warning("No val loader, skipping validation")
             return
         self.model.eval()
         if hasattr(self.val_loader.batch_sampler, "set_epoch"):
             self.val_loader.batch_sampler.set_epoch(epoch)  # type: ignore[union-attr]
-        root_logger.info(f"Validation Epoch {epoch} started")
         val_loader = self.val_loader
         if self.local_rank == 0:
             val_loader = tqdm(val_loader, desc=f"Epoch {epoch}")  # type: ignore[assignment]
@@ -176,16 +183,34 @@ def val_epoch(self, epoch: int):
             {"val": {"loss_epoch": val_loss.item(), **acc_values}, "epoch": epoch},
             extra={"step": self.global_step},
         )
-        root_logger.info(f"Validation Epoch {epoch} completed")

     def save_checkpoint(self, epoch: int):
         self.checkpointer.save_checkpoint(self.model, self.opt, epoch)
-        root_logger.info(f"Checkpoint saved to {self.checkpointer.path / str(epoch)}")

     def run_training(self):
-        for epoch in range(self.current_epoch, self.config.num_epochs):
+        n_epochs = self.config.num_epochs
+        for epoch in range(self.current_epoch, n_epochs):
+            root_logger.info(f"Training epoch {epoch + 1}/{n_epochs} started")
             self.train_epoch(epoch)
+            root_logger.info(f"Training epoch {epoch + 1}/{n_epochs} completed")
+
             if self.is_distributed:
                 dist.barrier()
-            self.val_epoch(epoch)
+
+            if self.val_loader is None:
+                root_logger.warning("No val loader, skipping validation epoch")
+            else:
+                root_logger.info(f"Validation epoch {epoch + 1}/{n_epochs} started")
+                self.val_epoch(epoch)
+                root_logger.info(f"Validation epoch {epoch + 1}/{n_epochs} completed")
+
+            if self.is_distributed:
+                dist.barrier()
+
+            root_logger.info(
+                f"Started saving checkpoint to {self.checkpointer.path / str(epoch)}"
+            )
             self.save_checkpoint(epoch)
+            root_logger.info(
+                f"Finished saving checkpoint to {self.checkpointer.path / str(epoch)}"
+            )