December 2021 update
liuqiuhui2015 committed Dec 23, 2021
1 parent: e844c5a, commit: eb77366
Showing 39 changed files with 534 additions and 468 deletions.
adv/train/mulang/train_m2o.py (82 changes: 40 additions & 42 deletions)
@@ -11,8 +11,10 @@
 from utils.base import *
 from utils.init import init_model_params
 from utils.contpara import get_model_parameters
-from utils.h5serial import h5save, h5load
-from utils.fmt.base import tostr, save_states, load_states, pad_id
+from utils.state.holder import Holder
+from utils.state.pyrand import PyRandomState
+from utils.state.thrand import THRandomState
+from utils.fmt.base import tostr, pad_id
 from utils.fmt.base4torch import parse_cuda, load_emb
 from utils.mulang import data_sampler
 
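This hunk swaps the direct HDF5 dumps of the optimizer (h5save/h5load) for a unified training-state holder. utils/state/holder.py itself is not part of this file's diff, so the following is only a sketch of the interface the new call sites imply; every internal detail here is an assumption:

# Hypothetical sketch of the Holder interface implied by the call sites in
# this diff; the real implementation (utils/state/holder.py) is not shown.
class Holder:

	def __init__(self, **kwargs):

		# each tracked object must expose state_dict()/load_state_dict(),
		# e.g. the optimizer, the LR scheduler and the RNG wrappers
		self.holders = kwargs

	def state_dict(self, update=True, **extra):

		# collect sub-states, then merge in scalar bookkeeping such as
		# remain_steps/checkpoint_id; the update flag presumably refreshes
		# a cached copy and is simply ignored in this sketch
		rsd = {k: v.state_dict() for k, v in self.holders.items()}
		rsd.update(extra)

		return rsd

	def load_state_dict(self, din):

		# restore the known sub-states, hand any leftover keys back to the
		# caller (the train script reads remain_steps/checkpoint_id there)
		remain = {}
		for k, v in din.items():
			if k in self.holders:
				self.holders[k].load_state_dict(v)
			else:
				remain[k] = v

		return remain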
@@ -30,7 +32,7 @@
 
 from transformer.NMT import NMT
 
-def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
+def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
 
 	sum_loss = part_loss = 0.0
 	sum_wd = part_wd = 0
@@ -77,17 +79,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok
 				if num_checkpoint > 1:
 					_fend = "_%d.h5" % (_cur_checkid)
 					_chkpf = chkpf[:-3] + _fend
-					if chkpof is not None:
-						_chkpof = chkpof[:-3] + _fend
 					_cur_checkid = (_cur_checkid + 1) % num_checkpoint
 				else:
 					_chkpf = chkpf
-					_chkpof = chkpof
 				save_model(model, _chkpf, multi_gpu, print_func=logger.info)
-				if chkpof is not None:
-					h5save(optm.state_dict(), _chkpof)
 				if statesf is not None:
-					save_states(statesf, tl[cur_b - 1:])
+					save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
 				_cur_rstep -= 1
 				if _cur_rstep <= 0:
 					break
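Note that save_states is still called here although it is no longer imported from utils.fmt.base above; presumably it now arrives through the wildcard utils.base import. Its new signature takes the collected state dict first, plus a print_func. Given the train.states.t7 file name set below and the torch.load on the resume path, a plausible stand-in:

# Plausible stand-in for the relocated save_states; the .t7 suffix and the
# torch.load on the resume path suggest plain torch serialization.
import torch

def save_states(state, fname, print_func=None):

	torch.save(state, fname)
	if print_func is not None:
		print_func("training states saved to %s" % fname)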
@@ -111,17 +108,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok
 			if num_checkpoint > 1:
 				_fend = "_%d.h5" % (_cur_checkid)
 				_chkpf = chkpf[:-3] + _fend
-				if chkpof is not None:
-					_chkpof = chkpof[:-3] + _fend
 				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
 			else:
 				_chkpf = chkpf
-				_chkpof = chkpof
 			save_model(model, _chkpf, multi_gpu, print_func=logger.info)
-			if chkpof is not None:
-				h5save(optm.state_dict(), _chkpof)
 			if statesf is not None:
-				save_states(statesf, tl[cur_b - 1:])
+				save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
 		cur_b += 1
 	if part_wd != 0.0:
 		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,))
@@ -181,7 +173,7 @@ def load_fixing(module):
 batch_report = cnfg.batch_report
 report_eva = cnfg.report_eva
 use_ams = cnfg.use_ams
-save_optm_state = cnfg.save_optm_state
+cnt_states = cnfg.train_statesf
 save_auto_clean = cnfg.save_auto_clean
 overwrite_eva = cnfg.overwrite_eva
 save_every = cnfg.save_every
@@ -193,14 +185,11 @@ def load_fixing(module):
 mkdir(wkdir)
 
 chkpf = None
-chkpof = None
 statesf = None
 if save_every is not None:
 	chkpf = wkdir + "checkpoint.h5"
-	if save_optm_state:
-		chkpof = wkdir + "checkpoint.optm.h5"
-	if cnfg.save_train_state:
-		statesf = wkdir + "checkpoint.states"
+if cnfg.save_train_state:
+	statesf = wkdir + "train.states.t7"
 
 logger = get_logger(wkdir + "train.log")
 
@@ -217,10 +206,6 @@ def load_fixing(module):
 nword = td["nword"][:].tolist()
 nwordi, ntask, nwordt = nword[0], nword[1], nword[-1]
 
-logger.info("Design models with seed: %d" % torch.initial_seed())
-mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
-
-fine_tune_m = cnfg.fine_tune_m
 task_weight, task_weight_T = cnfg.task_weight, cnfg.task_weight_T
 if task_weight_T is None or task_weight_T == 1.0:
 	tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][:].tolist()) for i in range(_nd)]
@@ -234,6 +219,11 @@ def load_fixing(module):
 train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain))
 nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][:].tolist()) for i in range(_nd)]
 
+logger.info("Design models with seed: %d" % torch.initial_seed())
+mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
+
+fine_tune_m = cnfg.fine_tune_m
+
 mymodel = init_model_params(mymodel)
 mymodel.apply(init_fixing)
 if fine_tune_m is not None:
@@ -267,13 +257,10 @@ def load_fixing(module):
 optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams)
 optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none)
 
-fine_tune_state = cnfg.fine_tune_state
-if fine_tune_state is not None:
-	logger.info("Load optimizer state from: " + fine_tune_state)
-	optimizer.load_state_dict(h5load(fine_tune_state))
-
 lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
 
+state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)})
+
 num_checkpoint = cnfg.num_checkpoint
 cur_checkid = 0
 
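The Holder built here bundles the optimizer, the LR scheduler and two RNG wrappers, so one file captures everything needed to resume a run. PyRandomState and THRandomState are likewise outside this diff; sketches of what they plausibly wrap:

# Assumed shape of the two RNG wrappers registered with the Holder; the real
# classes (utils/state/pyrand.py, utils/state/thrand.py) are not shown here.
import random

import torch

class PyRandomState:

	def state_dict(self):

		return {"state": random.getstate()}

	def load_state_dict(self, din):

		random.setstate(din["state"])

class THRandomState:

	def __init__(self, use_cuda=False):

		self.use_cuda = use_cuda

	def state_dict(self):

		rsd = {"cpu": torch.get_rng_state()}
		if self.use_cuda:
			rsd["cuda"] = torch.cuda.get_rng_state_all()

		return rsd

	def load_state_dict(self, din):

		torch.set_rng_state(din["cpu"])
		if self.use_cuda and ("cuda" in din):
			torch.cuda.set_rng_state_all(din["cuda"])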
@@ -286,15 +273,22 @@ def load_fixing(module):
 	save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info)
 	logger.info("Initial model saved")
 else:
-	cnt_states = cnfg.train_statesf
 	if cnt_states is not None:
-		logger.info("Continue last epoch")
-		tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler)
+		logger.info("Loading training states")
+		_remain_states = state_holder.load_state_dict(torch.load(cnt_states))
+		remain_steps, cur_checkid = _remain_states["remain_steps"], _remain_states["checkpoint_id"]
+		if "training_list" in _remain_states:
+			_ctl = _remain_states["training_list"]
+		else:
+			shuffle(tl)
+			_ctl = tl
+		tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, _ctl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler)
+		_ctl = _remain_states = None
 	vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
 	logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,))
 	save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None)
-	if save_optm_state:
-		h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec,))
+	if statesf is not None:
+		save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 	logger.info("New best model saved")
 
 if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0:
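With that, the resume path above reduces to one torch.load plus one Holder.load_state_dict, and the leftover keys carry the step counters and, for a mid-epoch save, the remaining batch list. A self-contained round trip of the cycle, using the sketches above with toy torch stand-ins for the repo's Optimizer and LRScheduler:

# Round trip of the assumed save/restore cycle; SGD/StepLR stand in for the
# repo's wrapped Optimizer and LRScheduler classes.
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lrsch = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

state_holder = Holder(optm=optimizer, lrsch=lrsch, pyrand=PyRandomState(), thrand=THRandomState(use_cuda=False))

# save: sub-states plus the scalar bookkeeping threaded through the script
save_states(state_holder.state_dict(update=False, remain_steps=500, checkpoint_id=1), "train.states.t7")

# resume: restore the sub-states, then read the leftovers back out
_remain_states = state_holder.load_state_dict(torch.load("train.states.t7"))
assert _remain_states["remain_steps"] == 500
assert _remain_states["checkpoint_id"] == 1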
@@ -319,14 +313,14 @@ def load_fixing(module):
 	else:
 		tl = train_sampler.generate()
 	free_cache(use_cuda)
-	terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler)
+	terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler)
 	vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
 	logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,))
 
 	if (vprec <= minerr) or (vloss <= minloss):
 		save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None)
-		if save_optm_state:
-			h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,))
+		if statesf is not None:
+			save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 		logger.info("New best model saved")
 
 		namin = 0
@@ -340,15 +334,18 @@ def load_fixing(module):
 	if terr < tminerr:
 		tminerr = terr
 		save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None)
-		if save_optm_state:
-			h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,))
+		if statesf is not None:
+			save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 	elif epoch_save:
 		save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info)
+		if statesf is not None:
+			save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 
 	namin += 1
 	if namin >= earlystop:
 		if done_tokens > 0:
 			optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer)
 			lrsch.step()
 			done_tokens = 0
 		logger.info("early stop")
 		break
@@ -368,10 +365,11 @@ def load_fixing(module):
 
 if done_tokens > 0:
 	optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer)
 	lrsch.step()
 
 save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info)
-if save_optm_state:
-	h5save(optimizer.state_dict(), wkdir + "last.optm.h5")
+if statesf is not None:
+	save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 logger.info("model saved")
 
 td.close()