
Commit
Jintao-Huang committed Oct 24, 2022
1 parent 02c5cf2 commit e12e5ed
Showing 6 changed files with 38 additions and 21 deletions.
8 changes: 1 addition & 7 deletions README.md
@@ -3,7 +3,7 @@

## Introduction
1. [Mini-Lightning](https://github.com/ustcml/mini-lightning/) is a lightweight machine learning training library: a mini version of [Pytorch-Lightning](https://www.pytorchlightning.ai/) with only 1k lines of code. It is faster, more concise, and more flexible.
2. Existing features: support for DDP (multi-node and multi-GPU), Sync-BN, DP, AMP, gradient accumulation, warmup and lr_scheduler, grad clip, tensorboard, model and result saving, beautiful console log, torchmetrics, etc.
2. Existing features: support for DDP (multi-node and multi-GPU), Sync-BN, DP, AMP, gradient accumulation, warmup and lr_scheduler, grad clip, tensorboard, model and result saving, beautiful console log, torchmetrics, resume from ckpt, etc.
3. Only minimal interfaces are exposed, keeping the library simple and easy to read, use, and extend.
4. Examples can be found in `examples/`.
5. If you have any problems or find a bug, please raise an issue. Thank you.
@@ -64,12 +64,6 @@ torchrun --nnodes 2 --node_rank 1 --master_addr xxx.xxx.xxx.xxx --nproc_per_node
```


## Environment
1. python>=3.8
2. torch>=1.12
3. torchmetrics==0.9.3


## TODO
1. GAN support
2. Automatic parameter adjustment
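
The feature added to the list above, resume from ckpt, boils down to persisting the model weights, optimizer state, and progress counters and restoring them later. Below is a minimal plain-PyTorch sketch of that idea; the names `save_checkpoint`, `load_checkpoint`, and `CKPT_PATH` are illustrative, not Mini-Lightning's API.

```python
# Minimal plain-PyTorch sketch of "resume from ckpt": persist model weights,
# optimizer state, and progress counters, then restore them later.
# save_checkpoint / load_checkpoint / CKPT_PATH are illustrative names only.
import torch
import torch.nn as nn
import torch.optim as optim

CKPT_PATH = "last.ckpt"


def save_checkpoint(model: nn.Module, optimizer: optim.Optimizer, epoch: int, step: int) -> None:
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "last_epoch": epoch,
        "global_step": step,
    }, CKPT_PATH)


def load_checkpoint(model: nn.Module, optimizer: optim.Optimizer):
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["last_epoch"], ckpt["global_step"]


if __name__ == "__main__":
    model = nn.Linear(2, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    save_checkpoint(model, optimizer, epoch=10, step=1000)
    epoch, step = load_checkpoint(model, optimizer)
    print(epoch, step)  # 10 1000
```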
1 change: 1 addition & 0 deletions examples/pre.py
@@ -26,6 +26,7 @@
from torchmetrics.classification.precision_recall import Precision, Recall
from torchmetrics.classification.f_beta import F1Score, FBetaScore
from torchmetrics.classification.auroc import AUROC
from torchmetrics.classification.average_precision import AveragePrecision
from torchmetrics.functional.classification.accuracy import accuracy
#
import torch
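
The new import above pulls in `AveragePrecision` alongside the other torchmetrics classes. A small usage sketch follows, assuming the torchmetrics 0.10.x classification API that this commit targets (torchmetrics 1.x later changed these constructors to take a `task` argument).

```python
# Sketch of the metrics imported above, assuming the torchmetrics 0.10.x API.
import torch
from torchmetrics.classification.auroc import AUROC
from torchmetrics.classification.average_precision import AveragePrecision

preds = torch.tensor([0.1, 0.8, 0.4, 0.9])   # predicted probabilities
target = torch.tensor([0, 1, 0, 1])          # binary labels

auroc = AUROC()
ap = AveragePrecision()
# Both should print tensor(1.) here: the toy data is perfectly separable.
print(auroc(preds, target))
print(ap(preds, target))
```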
18 changes: 13 additions & 5 deletions examples/test_env.py
@@ -96,31 +96,39 @@ def training_epoch_end(self) -> Dict[str, float]:
#
model = MLP_L2(2, 4, 1)
optimizer = optim.SGD(model.parameters(), 0.1, 0.9)
#
lmodel = MyLModule(model, optimizer, loss_fn, metrics, "acc")
ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
trainer = ml.Trainer(lmodel, [], 40, RUNS_DIR, gradient_clip_norm=10, val_every_n_epoch=10, verbose=True)
logger.info(trainer.test(ldm.val_dataloader, True, True))
logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))

# train from ckpt
# train from ckpt (model, optimizer state dict, global epoch, global step)
time.sleep(1)
ckpt_path = trainer.last_ckpt_path
optimizer = optim.SGD(model.parameters(), 0.1, 0.9)
#
lmodel = MyLModule(None, optimizer, loss_fn, metrics, "loss")
ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
trainer = ml.Trainer(lmodel, [0], 100, RUNS_DIR, gradient_clip_norm=10,
val_every_n_epoch=10, verbose=True, resume_from_ckpt=ckpt_path)
logger.info(trainer.test(ldm.val_dataloader, True, True))
logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))
# train from ckpt different optimizer (only model)
time.sleep(1)
ckpt_path = trainer.last_ckpt_path
optimizer = optim.Adam(model.parameters(), 0.001)
lmodel = MyLModule(None, optimizer, loss_fn, metrics, "loss")
ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
trainer = ml.Trainer(lmodel, [0], 20, RUNS_DIR, gradient_clip_norm=10,
val_every_n_epoch=10, verbose=True, resume_from_ckpt=ckpt_path)
logger.info(trainer.test(ldm.val_dataloader, True, True))
logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))

# only test from ckpt
# only test from ckpt (model, global epoch, global step)
time.sleep(1)
ckpt_path = trainer.last_ckpt_path
#
lmodel = MyLModule(None, None, loss_fn, metrics, "loss")
ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, resume_from_ckpt=ckpt_path)
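
The middle block above switches from SGD to Adam, so only the model weights can be carried over; SGD momentum buffers have no meaning for Adam. A plain-PyTorch sketch of that warm-restart pattern follows; the checkpoint layout shown is illustrative, not Mini-Lightning's on-disk format.

```python
# Warm restart with a different optimizer: reuse the trained weights, but start
# Adam from scratch, since the old SGD state is meaningless for Adam.
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
ckpt_path = "last.ckpt"
# Pretend a prior run saved this checkpoint (layout is illustrative).
torch.save({"model_state_dict": model.state_dict()}, ckpt_path)

ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])        # carried over
optimizer = optim.Adam(model.parameters(), lr=0.001)   # fresh optimizer state
```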
27 changes: 20 additions & 7 deletions mini_lightning/_mini_lightning.py
@@ -423,7 +423,7 @@ def __init__(
#
if resume_from_ckpt is not None:
logger.info(f"Using ckpt: {resume_from_ckpt}")
self._load_ckpt(resume_from_ckpt, self.device)
self._load_ckpt(resume_from_ckpt, self.device, True)
#
self.lmodel.trainer_init(self)
print_model_info(lmodel.model, None)
@@ -513,8 +513,9 @@ def _better_equal(metric: float, old_metric: Optional[float], higher_is_better:
def _save_ckpt(self, fpath: str) -> None:
if self.rank not in {-1, 0}:
return
kwargs = {
kwargs: Dict[str, Any] = {
"global_step": self.global_step,
"optimizer_name": self.lmodel.optimizer.__class__.__name__,
"core_metric": {
"name": self.lmodel.core_metric_name,
"higher_is_better": self.lmodel.higher_is_better,
@@ -523,12 +524,24 @@ def _save_ckpt(self, fpath: str) -> None:
}
save_ckpt(fpath, de_parallel(self.lmodel.model), self.lmodel.optimizer, self.global_epoch, **kwargs)

def _load_ckpt(self, fpath: str, map_location: Optional[Device] = None) -> None:
def _load_ckpt(self, fpath: str, map_location: Optional[Device] = None, verbose: bool = False) -> None:
new_model, optimizer_state_dict, mes = load_ckpt(fpath, map_location)
self.lmodel.model = new_model
self.lmodel.load_state_dict(None, optimizer_state_dict)
self.global_epoch = mes["last_epoch"]
self.global_step = mes["global_step"]
#
optimizer_name = self.lmodel.optimizer.__class__.__name__
tag = ["Ignore", "Ignore"]
if mes["optimizer_name"] == optimizer_name:
self.lmodel.load_state_dict(None, optimizer_state_dict)
tag[0] = "Success"

if mes["optimizer_name"] == optimizer_name or self.lmodel.optimizer is None:
self.global_epoch = mes["last_epoch"]
self.global_step = mes["global_step"]
tag[1] = "Success"

if verbose:
logger.info(
f"Using ckpt model: Success. optimizer state dict: {tag[0]}. global_epoch, global_step: {tag[1]}")

def _model_saving(self, core_metric: Optional[float]) -> bool:
best_saving = False
@@ -720,7 +733,7 @@ def _val_test(
lmodel.model = de_parallel(lmodel.model)
metrics_r: Dict[str, bool] = {k: m._to_sync for k, m in lmodel.metrics.items()}
for m in lmodel.metrics.values():
# torchmetrics ==0.9.3 private variable. I don't know whether it will be changed later.
# torchmetrics (>=0.9.3, <=0.10.0) private variable; it may change in later versions.
# You can raise an issue if you find an error.
# default: sync_on_compute = True
m._to_sync = False
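
The `_load_ckpt` change above restores the optimizer state dict, and the `global_epoch`/`global_step` counters, only when the checkpointed optimizer class matches the current one. Here is a standalone sketch of that guard, using an illustrative checkpoint dict rather than Mini-Lightning's real one.

```python
# Sketch of the optimizer-name guard when loading a checkpoint: optimizer state
# and progress counters are restored only if the optimizer class matches.
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
current_optimizer = optim.Adam(model.parameters(), lr=0.001)

# Illustrative checkpoint contents; in practice this would come from torch.load.
ckpt = {
    "optimizer_name": "SGD",
    "optimizer_state_dict": optim.SGD(model.parameters(), lr=0.1).state_dict(),
    "last_epoch": 40,
    "global_step": 5000,
}

tag = ["Ignore", "Ignore"]
if ckpt["optimizer_name"] == current_optimizer.__class__.__name__:
    current_optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    tag[0] = "Success"
# A missing optimizer (None) corresponds to test-only mode, where the counters
# can always be restored.
if ckpt["optimizer_name"] == current_optimizer.__class__.__name__ or current_optimizer is None:
    global_epoch, global_step = ckpt["last_epoch"], ckpt["global_step"]
    tag[1] = "Success"

print(f"optimizer state dict: {tag[0]}. global_epoch, global_step: {tag[1]}")
```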
3 changes: 2 additions & 1 deletion mini_lightning/_utils.py
@@ -346,7 +346,8 @@ def get_date_now(fmt: str = "%Y-%m-%d %H:%M:%S.%f") -> Tuple[str, Dict[str, int]
return date.strftime(fmt), mes


def save_ckpt(fpath: str, model: Module, optimizer: Optional[Optimizer], last_epoch: int, **kwargs) -> None:
def save_ckpt(fpath: str, model: Module, optimizer: Optional[Optimizer], last_epoch: int,
**kwargs: Dict[str, Any]) -> None:
ckpt: Dict[str, Any] = {
"model": model, # including model structure
"optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
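
`save_ckpt` stores the `nn.Module` object itself ("including model structure"), not just its `state_dict`. A short sketch of the trade-off, with illustrative file names:

```python
# Saving the whole module pickles the architecture along with the weights, so
# loading needs no rebuild step, but it depends on the class definitions staying
# importable and unchanged. Saving only the state_dict is smaller and more
# robust, but the model must be reconstructed before loading.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))

# Full object (what save_ckpt does). On torch >= 2.6 pass weights_only=False,
# since the default no longer allows unpickling arbitrary modules.
torch.save({"model": model}, "full_model.ckpt")
restored = torch.load("full_model.ckpt")["model"]

# state_dict only: rebuild the architecture, then load the weights.
torch.save(model.state_dict(), "weights.pt")
model2 = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
model2.load_state_dict(torch.load("weights.pt"))
```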
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
torch>=1.12.*
torchmetrics==0.9.3
torchmetrics>=0.10.0
tensorboard
numpy
PyYAML
