-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathtrain.py
94 lines (78 loc) · 2.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import glob
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from mGPT.callback import build_callbacks
from mGPT.config import parse_args, instantiate_from_config
from mGPT.data.build_data import build_data
from mGPT.models.build_model import build_model
from mGPT.utils.logger import create_logger
from mGPT.utils.load_checkpoint import load_pretrained, load_pretrained_vae
def main():
    """Run one training session driven by the parsed configuration.

    Steps, in order: parse the training config, create the text logger and
    dump the config to it, seed via ``pl.seed_everything``, build the metric
    loggers, callbacks, datamodule and model, construct the Lightning
    ``Trainer``, optionally load pretrained weights, then call
    ``trainer.fit``.
    """
    # Configs
    cfg = parse_args(phase="train")  # parse config file

    # Logger
    logger = create_logger(cfg, phase="train")  # create logger
    logger.info(OmegaConf.to_yaml(cfg))  # print config file

    # Seed
    pl.seed_everything(cfg.SEED_VALUE)

    # Environment Variables
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Metric Logger
    pl_loggers = []
    for logger_name in cfg.LOGGER.TYPE:
        # The TensorBoard logger is always created; any other configured
        # logger is created only when a W&B project is set. The original
        # compared against the misspelled 'tenosrboard', so a configured
        # tensorboard logger was silently skipped unless a W&B project
        # happened to be set as well.
        if logger_name == 'tensorboard' or cfg.LOGGER.WANDB.params.project:
            # Plain attribute lookup instead of eval() on a config-derived
            # string: same result, no arbitrary code execution.
            logger_cfg = getattr(cfg.LOGGER, logger_name.upper())
            pl_loggers.append(instantiate_from_config(logger_cfg))

    # Callbacks
    callbacks = build_callbacks(cfg, logger=logger, phase='train')
    logger.info("Callbacks initialized")

    # Dataset
    datamodule = build_data(cfg)
    # Second-to-last component of the dotted target path.
    dataset_module = cfg.DATASET.target.split('.')[-2]
    logger.info("datasets module {} initialized".format(dataset_module))

    # Model
    model = build_model(cfg, datamodule)
    logger.info("model {} loaded".format(cfg.model.target))

    # Lightning Trainer
    # NOTE(review): len(cfg.DEVICE) assumes DEVICE is a sequence of device
    # ids, not a bare int -- confirm against the config schema.
    multi_device = len(cfg.DEVICE) > 1
    trainer = pl.Trainer(
        default_root_dir=cfg.FOLDER_EXP,
        max_epochs=cfg.TRAIN.END_EPOCH,
        # precision='16',
        logger=pl_loggers,
        callbacks=callbacks,
        # NOTE(review): the config key is named VAL_EVERY_STEPS but is
        # passed as an epoch interval -- confirm the intended unit.
        check_val_every_n_epoch=cfg.LOGGER.VAL_EVERY_STEPS,
        accelerator=cfg.ACCELERATOR,
        devices=cfg.DEVICE,
        num_nodes=cfg.NUM_NODES,
        strategy="ddp_find_unused_parameters_true" if multi_device else 'auto',
        benchmark=False,
        deterministic=False,
    )
    logger.info("Trainer initialized")

    # Strict load pretrained model
    if cfg.TRAIN.PRETRAINED:
        load_pretrained(cfg, model, logger)

    # Strict load vae model
    if cfg.TRAIN.PRETRAINED_VAE:
        load_pretrained_vae(cfg, model, logger)

    # Pytorch 2.0 Compile
    # if torch.__version__ >= "2.0.0":
    #     model = torch.compile(model, mode="reduce-overhead")
    # model = torch.compile(model)

    # Lightning Fitting
    if cfg.TRAIN.RESUME:
        # When resuming, the checkpoint path is taken from TRAIN.PRETRAINED.
        trainer.fit(model,
                    datamodule=datamodule,
                    ckpt_path=cfg.TRAIN.PRETRAINED)
    else:
        trainer.fit(model, datamodule=datamodule)

    # Training ends
    logger.info(
        f"The outputs of this experiment are stored in {cfg.FOLDER_EXP}")
    logger.info("Training ends!")


if __name__ == "__main__":
    main()