diff --git a/README.md b/README.md
index ede2bdf..ba44384 100644
--- a/README.md
+++ b/README.md
@@ -56,7 +56,7 @@
 pip install "torchvision>=0.13.*"
 python examples/gan.py
 ### contrastive_learning.py
-pip install "torchvision>=0.13.*" "torchmetrics>=0.10.2" "scikit-learn>=1.1.*"
+pip install "torchvision>=0.13.*" "scikit-learn>=1.1.*"
 python examples/contrastive_learning.py
 
 ### gnn.py gnn2.py
diff --git a/examples/test_env.py b/examples/test_env.py
index 183d2d9..600267b 100644
--- a/examples/test_env.py
+++ b/examples/test_env.py
@@ -140,7 +140,7 @@ def training_epoch_end(self) -> Dict[str, float]:
     ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
     lmodel = MyLModule(model, [optimizer], "loss")
     trainer = ml.Trainer(lmodel, [0], 20, RUNS_DIR, gradient_clip_norm=10,
-                         val_every_n_epoch=10, verbose=True, model_fpath=ckpt_path)
+                         val_every_n_epoch=10, verbose=True, ckpt_fpath=ckpt_path)
     logger.info(trainer.test(ldm.val_dataloader, True, True))
     logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
     logger.info(trainer.test(ldm.test_dataloader, True, True))
@@ -152,5 +152,5 @@ def training_epoch_end(self) -> Dict[str, float]:
     model = MLP_L2(2, 4, 1)
     ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
     lmodel = MyLModule(model, [], "loss")
-    trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, model_fpath=ckpt_path)
+    trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, ckpt_fpath=ckpt_path)
     logger.info(trainer.test(ldm.test_dataloader, True, True))
diff --git a/mini_lightning/_mini_lightning.py b/mini_lightning/_mini_lightning.py
index a688cb6..1fa8fd1 100644
--- a/mini_lightning/_mini_lightning.py
+++ b/mini_lightning/_mini_lightning.py
@@ -313,7 +313,7 @@ def __init__(
         gradient_clip_norm: Optional[float] = None,
         sync_bn: bool = False,
         replace_sampler_ddp: bool = True,
-        model_fpath: Optional[str] = None,
+        ckpt_fpath: Optional[str] = None,
         #
         val_every_n_epoch: int = 1,
         log_every_n_steps: int = 10,
@@ -360,7 +360,7 @@ def __init__(
             replace_sampler_ddp=False: each gpu will use the complete dataset.
             replace_sampler_ddp=True: It will slice the dataset into world_size chunks and distribute them to each gpu.
                 note: Replace train_dataloader only. Because DDP uses a single gpu for val/test.
-            model_fpath: only load model_state_dict.
+            ckpt_fpath: only load model_state_dict.
                 If you want to resume from ckpt. please see `save_optimizers_state_dict` and examples in `examples/test_env.py`
             *
             val_every_n_epoch: Frequency of validation and prog_bar_leave of training. (the last epoch will always be validated)
@@ -472,9 +472,10 @@ def __init__(
         hparams = lmodel.hparams
         self.save_hparams(hparams)
         #
-        if model_fpath is not None:
-            self._load_ckpt(model_fpath)
-            logger.info(f"Using ckpt: {model_fpath}")
+        self.ckpt_fpath = ckpt_fpath
+        if ckpt_fpath is not None:
+            self._load_ckpt(ckpt_fpath)
+            logger.info(f"Using ckpt: {ckpt_fpath}")
         lmodel.trainer_init(self)
         for s in lmodel._models:
             model: Module = getattr(lmodel, s)
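
For context, a minimal sketch of the renamed argument in use, mirroring the updated calls in examples/test_env.py above. The import alias and the names model, optimizer, train_dataset, val_dataset, test_dataset, MyLModule, RUNS_DIR, and ckpt_path are assumed placeholders taken from that example script, not defined here.

    import mini_lightning as ml  # assumed import alias, as used in examples/test_env.py

    # Placeholders from examples/test_env.py are assumed to be defined as in that script.
    ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
    lmodel = MyLModule(model, [optimizer], "loss")
    # The checkpoint path is now passed as `ckpt_fpath` (previously `model_fpath`);
    # per the docstring, it only loads model_state_dict.
    trainer = ml.Trainer(lmodel, [0], 20, RUNS_DIR, gradient_clip_norm=10,
                         val_every_n_epoch=10, verbose=True, ckpt_fpath=ckpt_path)
    trainer.fit(ldm.train_dataloader, ldm.val_dataloader)
    trainer.test(ldm.test_dataloader, True, True)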