From 4242ad4ecefeda3d2500d6fc34cfb4dfa9120deb Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 25 Oct 2022 16:02:57 +0800
Subject: [PATCH] commit

---
 examples/test_env.py              | 12 +++++++-----
 mini_lightning/_mini_lightning.py | 21 ++++++---------------
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/examples/test_env.py b/examples/test_env.py
index 1897fdb..411af49 100644
--- a/examples/test_env.py
+++ b/examples/test_env.py
@@ -45,6 +45,7 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
     def __len__(self) -> int:
         return self.n_samples
 
+
 if __name__ == "__main__":
     ml.select_device([0])
     ml.seed_everything(2, gpu_dtm=True)
@@ -103,7 +104,7 @@ def training_epoch_end(self) -> Dict[str, float]:
     logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
     logger.info(trainer.test(ldm.test_dataloader, True, True))
 
-    # train from ckpt (model, optimizer state dict, global epoch, global step)
+    # train from ckpt
     time.sleep(1)
     ckpt_path = trainer.last_ckpt_path
     optimizer = optim.SGD(model.parameters(), 0.1, 0.9)
@@ -114,19 +115,20 @@ def training_epoch_end(self) -> Dict[str, float]:
     logger.info(trainer.test(ldm.val_dataloader, True, True))
     logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
     logger.info(trainer.test(ldm.test_dataloader, True, True))
 
-    # train from ckpt different optimizer (only model)
+    # train from ckpt (only model)
     time.sleep(1)
     ckpt_path = trainer.last_ckpt_path
+    model, _, _ = ml.load_ckpt(ckpt_path, Device(0))
     optimizer = optim.Adam(model.parameters(), 0.001)
-    lmodel = MyLModule(None, optimizer, loss_fn, metrics, "loss")
+    lmodel = MyLModule(model, optimizer, loss_fn, metrics, "loss")
     ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
     trainer = ml.Trainer(lmodel, [0], 20, RUNS_DIR, gradient_clip_norm=10,
-                         val_every_n_epoch=10, verbose=True, resume_from_ckpt=ckpt_path)
+                         val_every_n_epoch=10, verbose=True)
     logger.info(trainer.test(ldm.val_dataloader, True, True))
     logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
     logger.info(trainer.test(ldm.test_dataloader, True, True))
-    # only test from ckpt (model, global epoch, global step)
+    # only test from ckpt
     time.sleep(1)
     ckpt_path = trainer.last_ckpt_path
     lmodel = MyLModule(None, None, loss_fn, metrics, "loss")
diff --git a/mini_lightning/_mini_lightning.py b/mini_lightning/_mini_lightning.py
index d16e66a..802abc2 100644
--- a/mini_lightning/_mini_lightning.py
+++ b/mini_lightning/_mini_lightning.py
@@ -422,8 +422,8 @@ def __init__(
         self.save_hparams(hparams)
         #
         if resume_from_ckpt is not None:
+            self._load_ckpt(resume_from_ckpt, self.device)
             logger.info(f"Using ckpt: {resume_from_ckpt}")
-            self._load_ckpt(resume_from_ckpt, self.device, True)
         #
         self.lmodel.trainer_init(self)
         print_model_info(lmodel.model, None)
@@ -524,24 +524,15 @@ def _save_ckpt(self, fpath: str) -> None:
         }
         save_ckpt(fpath, de_parallel(self.lmodel.model), self.lmodel.optimizer, self.global_epoch, **kwargs)
 
-    def _load_ckpt(self, fpath: str, map_location: Optional[Device] = None, verbose: bool = False) -> None:
+    def _load_ckpt(self, fpath: str, map_location: Optional[Device] = None) -> None:
         new_model, optimizer_state_dict, mes = load_ckpt(fpath, map_location)
         self.lmodel.model = new_model
         #
         optimizer_name = self.lmodel.optimizer.__class__.__name__
-        tag = ["Ignore", "Ignore"]
-        if mes["optimizer_name"] == optimizer_name:
-            self.lmodel.load_state_dict(None, optimizer_state_dict)
-            tag[0] = "Success"
-
-        if mes["optimizer_name"] == optimizer_name or self.lmodel.optimizer is None:
-            self.global_epoch = mes["last_epoch"]
-            self.global_step = mes["global_step"]
-            tag[1] = "Success"
-
-        if verbose:
-            logger.info(
-                f"Using ckpt model: Success. optimizer state dict: {tag[0]}. global_epoch, global_step: {tag[1]}")
+        assert self.lmodel.optimizer is None or optimizer_name == mes["optimizer_name"]
+        self.lmodel.load_state_dict(None, optimizer_state_dict)
+        self.global_epoch = mes["last_epoch"]
+        self.global_step = mes["global_step"]
 
     def _model_saving(self, core_metric: Optional[float]) -> bool:
         best_saving = False