From c8f977ad28fa9575ce49d2ea1bd948cb25a9eaaf Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 10 Oct 2022 23:11:20 +0800
Subject: [PATCH] commit

---
 examples/cv.py                   |  2 +-
 examples/cv_ddp.py               |  2 +-
 examples/cv_ddp_spawn.py         |  2 +-
 examples/nlp.py                  |  2 +-
 examples/test_env.py             |  6 +++---
 mini_lightning/mini_lightning.py | 31 ++++++++++++++++++++++++++-----
 6 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/examples/cv.py b/examples/cv.py
index 7631407..208a885 100644
--- a/examples/cv.py
+++ b/examples/cv.py
@@ -127,7 +127,7 @@ def collect_res(seed: int) -> Dict[str, float]:
         lmodel = MyLModule(model, optimizer, loss_fn, lr_s, hparams)
         trainer = ml.Trainer(lmodel, device_ids, runs_dir=RUNS_DIR, **hparams["trainer_hparams"])
         res = trainer.fit(ldm.train_dataloader, ldm.val_dataloader)
-        res2 = trainer.test(ldm.test_dataloader)
+        res2 = trainer.test(ldm.test_dataloader, True, True)
         res.update(res2)
         return res
     res = ml.multi_runs(collect_res, 3, seed=42)
diff --git a/examples/cv_ddp.py b/examples/cv_ddp.py
index c7cdf90..1a1bc2f 100644
--- a/examples/cv_ddp.py
+++ b/examples/cv_ddp.py
@@ -142,7 +142,7 @@ def collect_res(seed: int) -> Dict[str, float]:
         lmodel = MyLModule(model, optimizer, loss_fn, lr_s, hparams)
         trainer = ml.Trainer(lmodel, device_ids, runs_dir=RUNS_DIR, **hparams["trainer_hparams"])
         res = trainer.fit(ldm.train_dataloader, ldm.val_dataloader)
-        res2 = trainer.test(ldm.test_dataloader)
+        res2 = trainer.test(ldm.test_dataloader, True, True)
         res.update(res2)
         return res
     res = ml.multi_runs(collect_res, 3, seed=42)
diff --git a/examples/cv_ddp_spawn.py b/examples/cv_ddp_spawn.py
index ed66467..cac8e5c 100644
--- a/examples/cv_ddp_spawn.py
+++ b/examples/cv_ddp_spawn.py
@@ -137,7 +137,7 @@ def collect_res(seed: int) -> Dict[str, float]:
         lmodel = MyLModule(model, optimizer, loss_fn, lr_s, hparams)
         trainer = ml.Trainer(lmodel, device_ids, runs_dir=RUNS_DIR, **hparams["trainer_hparams"])
         res = trainer.fit(ldm.train_dataloader, ldm.val_dataloader)
-        res2 = trainer.test(ldm.test_dataloader)
+        res2 = trainer.test(ldm.test_dataloader, True, True)
         res.update(res2)
         return res
     res = ml.multi_runs(collect_res, 3, seed=42)
diff --git a/examples/nlp.py b/examples/nlp.py
index f9e3f0f..f727d0d 100644
--- a/examples/nlp.py
+++ b/examples/nlp.py
@@ -122,4 +122,4 @@ def tokenize_function(example):
         logger.info("KeyboardInterrupt Detected...")
         raise
     finally:
-        logger.info(trainer.test(ldm.test_dataloader))
+        logger.info(trainer.test(ldm.test_dataloader, True, True))
diff --git a/examples/test_env.py b/examples/test_env.py
index 98d4705..abf1e80 100644
--- a/examples/test_env.py
+++ b/examples/test_env.py
@@ -100,7 +100,7 @@ def training_epoch_end(self) -> Dict[str, float]:
     lmodel = MyLModule(model, optimizer, loss_fn, metrics, "acc")
     ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
     trainer = ml.Trainer(lmodel, [], 40, RUNS_DIR, gradient_clip_norm=10, val_every_n_epoch=10, verbose=True)
-    logger.info(trainer.test(ldm.val_dataloader, False, True))
+    logger.info(trainer.test(ldm.val_dataloader, True, True))
     logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
     logger.info(trainer.test(ldm.test_dataloader, True, True))
 
@@ -113,7 +113,7 @@ def training_epoch_end(self) -> Dict[str, float]:
     ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
     trainer = ml.Trainer(lmodel, [0], 100, RUNS_DIR, gradient_clip_norm=10, val_every_n_epoch=10, verbose=True,
                          resume_from_ckpt=ckpt_path)
-    logger.info(trainer.test(ldm.val_dataloader, False, True))
+    logger.info(trainer.test(ldm.val_dataloader, True, True))
     logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
     logger.info(trainer.test(ldm.test_dataloader, True, True))
 
@@ -124,4 +124,4 @@ def training_epoch_end(self) -> Dict[str, float]:
     lmodel = MyLModule(None, None, loss_fn, metrics, "loss")
     ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
     trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, resume_from_ckpt=ckpt_path)
-    logger.info(trainer.test(ldm.test_dataloader, False, True))
+    logger.info(trainer.test(ldm.test_dataloader, True, True))
diff --git a/mini_lightning/mini_lightning.py b/mini_lightning/mini_lightning.py
index 9343f41..e7b1755 100644
--- a/mini_lightning/mini_lightning.py
+++ b/mini_lightning/mini_lightning.py
@@ -846,15 +846,36 @@ def fit(self, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader]
         cuda.empty_cache()
         return best_mes if self.rank in {-1, 0} else {}  # core_metrics is best
 
-    def test(self, dataloader: Optional[DataLoader], test_best: bool = True, test_last: bool = False) -> Dict[str, float]:
+    def _best_ckpt_is_last(self) -> bool:
+        if self.best_ckpt_path is None or self.last_ckpt_path is None:
+            return False
+
+        best_ckpt_fname = os.path.basename(self.best_ckpt_path)
+        m = re.match(r"best-epoch=(\d+)", best_ckpt_fname)
+        assert m is not None
+        best_epoch_idx = m.group(1)
+        last_ckpt_fname = os.path.basename(self.last_ckpt_path)
+        m = re.match(r"last-epoch=(\d+)", last_ckpt_fname)
+        assert m is not None
+        last_epoch_idx = m.group(1)
+        return best_epoch_idx == last_epoch_idx
+
+    def test(self, dataloader: Optional[DataLoader], test_best: bool = False, test_last: bool = True) -> Dict[str, float]:
         res_mes = {}
         if test_best:
             # If last first, last will be overridden in tensorboard. So best first.
-            m = self._test(dataloader, "best")
-            res_mes.update(m)
+            if self.best_ckpt_path is None:
+                logger.warning("Ignore test best: self.best_ckpt_path is None")
+                test_best = False
+            else:
+                m = self._test(dataloader, "best")
+                res_mes.update(m)
         #
         if test_last:  # just current model
-            m = self._test(dataloader, "last")
-            res_mes.update(m)
+            if self._best_ckpt_is_last() and test_best is True:
+                logger.info("Ignore test last: the best ckpt is the last ckpt")
+            else:
+                m = self._test(dataloader, "last")
+                res_mes.update(m)
         cuda.empty_cache()
         return res_mes if self.rank in {-1, 0} else {}
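
A note on the API change above: Trainer.test now defaults to test_best=False, test_last=True (previously test_best=True, test_last=False), which is why the examples are updated to pass both flags explicitly as trainer.test(dataloader, True, True). The new logic also skips the "best" pass when no best checkpoint exists, and skips the "last" pass when the best checkpoint is also the last one. A minimal sketch of the resulting decision flow, covering only the two flags and the checkpoint state (plan_test_passes is a hypothetical name for illustration, not part of mini-lightning):

    from typing import List

    def plan_test_passes(test_best: bool, test_last: bool,
                         have_best_ckpt: bool, best_is_last: bool) -> List[str]:
        # Mirrors the branch structure of the new Trainer.test (illustrative only).
        passes = []
        if test_best and not have_best_ckpt:
            test_best = False  # "Ignore test best: self.best_ckpt_path is None"
        if test_best:
            passes.append("best")  # best runs first so "last" overrides it in tensorboard
        if test_last:
            if best_is_last and test_best:
                pass  # "Ignore test last: the best ckpt is the last ckpt"
            else:
                passes.append("last")
        return passes

    print(plan_test_passes(True, True, have_best_ckpt=True, best_is_last=True))    # ['best']
    print(plan_test_passes(True, True, have_best_ckpt=True, best_is_last=False))   # ['best', 'last']
    print(plan_test_passes(True, True, have_best_ckpt=False, best_is_last=False))  # ['last']
    print(plan_test_passes(False, True, have_best_ckpt=True, best_is_last=True))   # ['last'] (new defaults)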
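
The "skip last" decision relies on the new _best_ckpt_is_last helper, which compares the epoch index encoded in the two checkpoint filenames. A self-contained mirror of that parsing, assuming filenames of the form "best-epoch=<n>-..." and "last-epoch=<n>-..." as implied by the regexes in the patch (best_ckpt_is_last below is an illustrative standalone function, not the library method):

    import os
    import re
    from typing import Optional

    def best_ckpt_is_last(best_ckpt_path: Optional[str], last_ckpt_path: Optional[str]) -> bool:
        # No checkpoints saved yet: nothing to compare.
        if best_ckpt_path is None or last_ckpt_path is None:
            return False
        best_m = re.match(r"best-epoch=(\d+)", os.path.basename(best_ckpt_path))
        last_m = re.match(r"last-epoch=(\d+)", os.path.basename(last_ckpt_path))
        assert best_m is not None and last_m is not None
        # Epoch indices are compared as strings; both filenames are assumed to
        # come from the same formatting code, so the comparison is consistent.
        return best_m.group(1) == last_m.group(1)

    # Same epoch: test(dataloader, True, True) would run only the "best" pass.
    print(best_ckpt_is_last("runs/best-epoch=40-acc=0.98.ckpt",
                            "runs/last-epoch=40-acc=0.98.ckpt"))  # True
    # Different epochs: both the "best" and "last" passes run.
    print(best_ckpt_is_last("runs/best-epoch=35-acc=0.98.ckpt",
                            "runs/last-epoch=40-acc=0.97.ckpt"))  # False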