Commit 64d3aee

[upd] simplify codes & add more logs

1 parent 8198567 · commit 64d3aee

5 files changed: +89 -88 lines (README.md, dist.py, train.py, trainer.py, utils/arg_util.py)

README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -105,7 +105,7 @@ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --mas
   --depth=30 --bs=1024 --ep=350 --fp16=1 --alng=1e-5 --wpe=0.01
 ```
 A folder named `local_output` will be created to save the checkpoints and logs.
-You can monitor the training process by checking the logs in `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
+You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
 
 If your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)).
 
````
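Note: the `log.txt` referenced above is introduced by this commit and written by `Args.dump_log` (see `utils/arg_util.py` below) as a multi-line JSON header followed by one Python-dict line per epoch. A minimal reader sketch, assuming exactly that format (the helper name `read_log` is ours, not part of the repo):

```python
import ast

def read_log(path: str = 'local_output/log.txt'):
    """Collect the per-epoch dict lines; the JSON header (double-quoted keys) is skipped."""
    records = []
    with open(path) as fp:
        for line in fp:
            line = line.strip()
            if line.startswith("{'"):  # dump_log writes each epoch as repr(dict)
                records.append(ast.literal_eval(line))
    return records

# e.g.: records = read_log(); print(records[-1]['ep'], records[-1]['L_mean'])
```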

dist.py

Lines changed: 1 addition & 5 deletions

```diff
@@ -83,10 +83,6 @@ def is_local_master():
     return __local_rank == 0
 
 
-def is_visualizer():
-    return __rank == 0
-
-
 def new_group(ranks: List[int]):
     if __initialized:
         return tdist.new_group(ranks=ranks)
@@ -201,7 +197,7 @@ def wrapper(*args, **kwargs):
 def for_visualize(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        if is_visualizer():
+        if is_master():
             # with torch.no_grad():
             ret = func(*args, **kwargs)
         else:
```
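The deleted `is_visualizer()` tested `__rank == 0`, which is exactly what `is_master()` already does, so `for_visualize` now reuses the surviving helper. Read in isolation, the decorator amounts to the minimal sketch below (the environment-based `is_master` stand-in and the `None` return for non-master ranks are our assumptions; the original's `else` branch is cut off in this hunk):

```python
import functools
import os

def is_master() -> bool:
    # Stand-in for dist.is_master(); the real helper checks the torch.distributed global rank.
    return int(os.environ.get('RANK', '0')) == 0

def for_visualize(func):
    """Run the wrapped visualization function on the master rank only."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) if is_master() else None  # non-master ranks skip the call
    return wrapper
```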

train.py

Lines changed: 41 additions & 56 deletions

```diff
@@ -1,28 +1,27 @@
 import gc
 import os
-import random
+import shutil
 import sys
 import time
 import warnings
 from functools import partial
 
-import numpy as np
 import torch
 from torch.utils.data import DataLoader
 
 import dist
-from utils.misc import auto_resume
 from utils import arg_util, misc
 from utils.data import build_dataset
 from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler
+from utils.misc import auto_resume
 
 
 def build_everything(args: arg_util.Args):
     # resume
     auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')
     # create tensorboard logger
     tb_lg: misc.TensorboardLogger
-    with_tb_lg = dist.is_visualizer()
+    with_tb_lg = dist.is_master()
     if with_tb_lg:
         os.makedirs(args.tb_log_dir_path, exist_ok=True)
         # noinspection PyTypeChecker
@@ -130,7 +129,7 @@ def build_everything(args: arg_util.Args):
 
     # build trainer
     trainer = VARTrainer(
-        is_visualizer=dist.is_visualizer(), device=args.device, patch_nums=args.patch_nums, resos=args.resos,
+        device=args.device, patch_nums=args.patch_nums, resos=args.resos,
         vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,
         var_opt=var_optim, label_smooth=args.ls,
     )
@@ -157,7 +156,7 @@ def build_everything(args: arg_util.Args):
         )
         print({k: meter.global_avg for k, meter in me.meters.items()})
 
-        tb_lg.flush(); tb_lg.close()
+        args.dump_log(); tb_lg.flush(); tb_lg.close()
         if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
             sys.stdout.close(), sys.stderr.close()
         exit(0)
@@ -169,7 +168,7 @@ def build_everything(args: arg_util.Args):
     )
 
 
-def main():
+def main_training():
     args: arg_util.Args = arg_util.init_dist_and_get_args()
     if args.local_debug:
         torch.autograd.set_detect_anomaly(True)
@@ -181,9 +180,9 @@ def main():
     ) = build_everything(args)
 
     # train
-    start_time, min_L_mean, min_L_tail, max_acc_mean, max_acc_tail = time.time(), 999., 999., -1., -1.
-    last_val_loss_mean, best_val_loss_mean, last_val_acc_mean, best_val_acc_mean = 999, 999, 0, 0
-    last_val_loss_tail, best_val_loss_tail, last_val_acc_tail, best_val_acc_tail = 999, 999, 0, 0
+    start_time = time.time()
+    best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.
+    best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1
 
     L_mean, L_tail = -1, -1
     for ep in range(start_ep, args.ep):
@@ -199,49 +198,46 @@ def main():
         )
 
         L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']
-        min_L_mean, max_acc_mean, max_acc_tail = min(min_L_mean, L_mean), max(max_acc_mean, acc_mean), max(max_acc_tail, acc_tail)
-        if L_tail != -1:
-            min_L_tail = min(min_L_tail, L_tail)
-        args.min_L_mean, args.min_L_tail, args.max_acc_mean, args.max_acc_tail, args.grad_norm = min_L_mean, min_L_tail, (None if max_acc_mean < 0 else max_acc_mean), (None if max_acc_tail < 0 else max_acc_tail), grad_norm
+        best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)
+        if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)
+        args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm
         args.cur_ep = f'{ep+1}/{args.ep}'
         args.remain_time, args.finish_time = remain_time, finish_time
 
-        AR_ep_loss = {}
+        AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)
        is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep
         if is_val_and_also_saving:
-            last_val_loss_mean, last_val_loss_tail, last_val_acc_mean, last_val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
-            best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, last_val_loss_mean), min(best_val_loss_tail, last_val_loss_tail)
-            best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, last_val_acc_mean), max(best_val_acc_tail, last_val_acc_tail)
-            AR_ep_loss['vL_mean'], AR_ep_loss['vL_tail'], AR_ep_loss['vacc_mean'], AR_ep_loss['vacc_tail'] = last_val_loss_mean, last_val_loss_tail, last_val_acc_mean, last_val_acc_tail
+            val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
+            best_updated = best_val_loss_tail > val_loss_tail
+            best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)
+            best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)
+            AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, vacc_tail=val_acc_tail)
+            args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail
             print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s')
+
+            if dist.is_local_master():
+                local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
+                local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')
+                print(f'[saving ckpt] ...', end='', flush=True)
+                torch.save({
+                    'epoch': ep+1,
+                    'iter': 0,
+                    'trainer': trainer.state_dict(),
+                    'args': args.state_dict(),
+                }, local_out_ckpt)
+                if best_updated:
+                    shutil.copy(local_out_ckpt, local_out_ckpt_best)
+                print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True)
+            dist.barrier()
 
-        print( f' [ep{ep}] (training ) Lm: {min_L_mean:.3f} ({L_mean:.3f}), Lt: {min_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {max_acc_mean:.2f} {max_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True)
-        if ep > args.ep // 20 and min_L_tail < 99:
-            tb_lg.update(head='AR_y_result', step=ep+1, min_L_mean=min_L_mean, min_L_tail=min_L_tail, max_acc_mean=max_acc_mean, max_acc_tail=max_acc_tail)
-
-        AR_ep_loss['L_mean'], AR_ep_loss['L_tail'], AR_ep_loss['acc_mean'], AR_ep_loss['acc_tail'] = L_mean, L_tail, acc_mean, acc_tail
+        print( f' [ep{ep}] (training ) Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True)
         tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)
         tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))
-
-        if is_val_and_also_saving and dist.is_master():
-            local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
-            torch.save({
-                'epoch': ep+1,
-                'iter': 0,
-                'trainer': trainer.state_dict(),
-                'args': args.state_dict(),
-            }, local_out_ckpt)
-
-        tb_lg.flush()
-        dist.barrier()
-
-    tb_lg.update(head='AR_y_result_final', step=start_ep, min_L_mean=min_L_mean, min_L_tail=min_L_tail, max_acc_mean=max_acc_mean, max_acc_tail=max_acc_tail)
-    tb_lg.update(head='AR_y_result_final', step=args.ep, min_L_mean=min_L_mean, min_L_tail=min_L_tail, max_acc_mean=max_acc_mean, max_acc_tail=max_acc_tail)
-    tb_lg.flush()
+        args.dump_log(); tb_lg.flush()
 
     total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'
     print('\n\n')
-    print(f' [*] [PT finished] Total Time: {total_time}, Lm: {min_L_mean:.3f} ({L_mean}), Lt: {min_L_tail:.3f} ({L_tail})')
+    print(f' [*] [PT finished] Total cost: {total_time}, Lm: {best_L_mean:.3f} ({L_mean}), Lt: {best_L_tail:.3f} ({L_tail})')
     print('\n\n')
 
     del stats
@@ -250,7 +246,7 @@ def main():
 
     args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60))
     print(f'final args:\n\n{str(args)}')
-    tb_lg.flush(); tb_lg.close()
+    args.dump_log(); tb_lg.flush(); tb_lg.close()
     dist.barrier()
 
 
@@ -285,6 +281,7 @@ def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args,
 
     wp_it = args.wp * iters_train
     min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)
+    args.cur_lr, args.cur_wd = max_tlr, max_twd
 
     if args.pg: # default: 0.0, no progressive training, won't get into this
         if g_it <= wp_it: prog_si = args.pg0
@@ -310,8 +307,7 @@ def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args,
            tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)
            tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)
            tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)
-           if scale_log2 is not None:
-               tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)
+           tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)
 
            if args.tclip > 0:
                tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)
@@ -335,18 +331,7 @@ def forward(self, *args, **kwargs):
 
 
 if __name__ == '__main__':
-    try:
-        main()
-    except Exception as err:
-        time.sleep(dist.get_rank() * 1 + random.random() * 0.5)
-        try:
-            # noinspection PyArgumentList
-            print(f'[rk{dist.get_rank():2d}] {type(err).__name__}', flush=True, force=True)
-        except:
-            try: print(f'[rk{dist.get_rank():2d}] {type(err).__name__}', flush=True)
-            except: pass
-        if dist.is_master(): print(f'[err]:\n{err}')
-        raise err
+    try: main_training()
     finally:
         dist.finalize()
         if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
```
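The checkpoint logic is the substantive change in this file: saving moves from the global master into `dist.is_local_master()` (one write per node), and a second file `ar-ckpt-best.pth` now tracks the best tail validation loss, with `best_updated` computed against the running best before the `min(...)` refresh. A distilled sketch of that keep-last-plus-copy-best pattern (the function name and `state` payload are placeholders, not repo code):

```python
import os
import shutil
import torch

def save_last_and_best(out_dir: str, state: dict, val_loss_tail: float, best_val_loss_tail: float) -> float:
    """Overwrite the 'last' checkpoint every time; copy it to 'best' when the tail val loss improves."""
    last_ckpt = os.path.join(out_dir, 'ar-ckpt-last.pth')
    best_ckpt = os.path.join(out_dir, 'ar-ckpt-best.pth')
    best_updated = val_loss_tail < best_val_loss_tail  # compare before refreshing the running best
    torch.save(state, last_ckpt)
    if best_updated:
        shutil.copy(last_ckpt, best_ckpt)  # cheap: duplicates the file just written
    return min(best_val_loss_tail, val_loss_tail)
```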

trainer.py

Lines changed: 14 additions & 22 deletions

```diff
@@ -19,7 +19,7 @@
 
 class VARTrainer(object):
     def __init__(
-        self, is_visualizer: bool, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],
+        self, is_master: bool, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],
         vae_local: VQVAE, var_wo_ddp: VAR, var: DDP,
         var_opt: AmpOptimizer, label_smooth: float,
     ):
@@ -30,8 +30,6 @@ def __init__(
         self.var_wo_ddp: VAR = var_wo_ddp   # after torch.compile
         self.var_opt = var_opt
 
-        self.is_visualizer = is_visualizer
-
         del self.var_wo_ddp.rng
         self.var_wo_ddp.rng = torch.Generator(device=device)
 
@@ -112,12 +110,12 @@ def train_step(
             self.var_wo_ddp.forward
             logits_BLV = self.var(label_B, x_BLCv_wo_first_l)
             loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1)
-            if prog_si >= 0:
+            if prog_si >= 0:    # in progressive training
                 bg, ed = self.begin_ends[prog_si]
                 assert logits_BLV.shape[1] == gt_BL.shape[1] == ed
                 lw = self.loss_weight[:, :ed].clone()
                 lw[:, bg:ed] *= min(max(prog_wp, 0), 1)
-            else:
+            else:               # not in progressive training
                 lw = self.loss_weight
             loss = loss.mul(lw).sum(dim=-1).mean()
 
@@ -126,33 +124,27 @@ def train_step(
 
         # log
         pred_BL = logits_BLV.data.argmax(dim=-1)
-        if it in metric_lg.log_iters:
+        if it == 0 or it in metric_lg.log_iters:
             Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item()
             acc_mean = (pred_BL == gt_BL).float().mean().item() * 100
-            if prog_si < 0:
+            if prog_si >= 0:    # in progressive training
+                Ltail = acc_tail = -1
+            else:               # not in progressive training
                 Ltail = self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item()
                 acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100
-            else:
-                Ltail = acc_tail = -1
             grad_norm = grad_norm.item()
             metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm)
 
+        # log to tensorboard
         if g_it == 0 or (g_it + 1) % 500 == 0:
-            if g_it == 0:
-                prob_per_class_is_chosen = gt_BL.view(-1).bincount(minlength=V).float()
-                dist.allreduce(prob_per_class_is_chosen)
-                if self.is_visualizer:
-                    prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
-                    cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
-                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)
-                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)
-
             prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float()
             dist.allreduce(prob_per_class_is_chosen)
-
-            if self.is_visualizer:
-                prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
-                cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
+            prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
+            cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
+            if dist.is_master():
+                if g_it == 0:
+                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)
+                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)
                 kw = dict(z_voc_usage=cluster_usage)
                 for si, (bg, ed) in enumerate(self.begin_ends):
                     if 0 <= prog_si < si: break
```
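The rewritten logging block computes `cluster_usage` — the share of the V-entry codebook whose predicted-token frequency exceeds 0.001/V — on every rank (the bincount is still allreduced first), and only gates the tensorboard writes behind `dist.is_master()`. A single-process sketch of the metric itself:

```python
import torch

def vocab_usage_percent(pred_tokens: torch.Tensor, V: int) -> float:
    """Percent of the V-way codebook used more often than 0.001/V of the time.

    Single-process sketch of the trainer's z_voc_usage metric; the real code
    allreduces the bincount across ranks before normalizing.
    """
    counts = pred_tokens.view(-1).bincount(minlength=V).float()
    freq = counts / counts.sum()
    return (freq > 0.001 / V).float().mean().item() * 100

# Example: uniformly random tokens over a 4096-entry codebook give near-100% usage.
print(vocab_usage_percent(torch.randint(0, 4096, (64, 680)), V=4096))
```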

utils/arg_util.py

Lines changed: 32 additions & 4 deletions

```diff
@@ -1,3 +1,4 @@
+import json
 import os
 import random
 import re
@@ -85,11 +86,17 @@ class Args(Tap):
     branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
     commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
     commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()  # [automatically set; don't specify this]
-    max_acc_mean: float = None  # [automatically set; don't specify this]
-    max_acc_tail: float = None  # [automatically set; don't specify this]
-    min_L_mean: float = None    # [automatically set; don't specify this]
-    min_L_tail: float = None    # [automatically set; don't specify this]
+    acc_mean: float = None      # [automatically set; don't specify this]
+    acc_tail: float = None      # [automatically set; don't specify this]
+    L_mean: float = None        # [automatically set; don't specify this]
+    L_tail: float = None        # [automatically set; don't specify this]
+    vacc_mean: float = None     # [automatically set; don't specify this]
+    vacc_tail: float = None     # [automatically set; don't specify this]
+    vL_mean: float = None       # [automatically set; don't specify this]
+    vL_tail: float = None       # [automatically set; don't specify this]
     grad_norm: float = None     # [automatically set; don't specify this]
+    cur_lr: float = None        # [automatically set; don't specify this]
+    cur_wd: float = None        # [automatically set; don't specify this]
     cur_it: str = ''            # [automatically set; don't specify this]
     cur_ep: str = ''            # [automatically set; don't specify this]
     remain_time: str = ''       # [automatically set; don't specify this]
@@ -168,6 +175,27 @@ def set_tf32(tf32: bool):
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
 
+    def dump_log(self):
+        if not dist.is_local_master():
+            return
+        if '1/' in self.cur_ep:  # first time to dump log
+            with open(self.log_txt_path, 'w') as fp:
+                json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
+                fp.write('\n')
+
+        log_dict = {}
+        for k, v in {
+            'it': self.cur_it, 'ep': self.cur_ep,
+            'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
+            'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
+            'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
+            'remain_time': self.remain_time, 'finish_time': self.finish_time,
+        }.items():
+            if hasattr(v, 'item'): v = v.item()
+            log_dict[k] = v
+        with open(self.log_txt_path, 'a') as fp:
+            fp.write(f'{log_dict}\n')
+
     def __str__(self):
         s = []
         for k in self.class_variables.keys():
```
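`dump_log` writes a one-time header with `json.dump`, then appends one `repr`'d dict per call; the `hasattr(v, 'item')` check unwraps 0-dim tensors into plain Python numbers so each line stays `ast.literal_eval`-friendly. One caveat worth noting: `'1/' in self.cur_ep` is also true for `'11/350'`, `'21/350'`, and so on, so the header branch reopens `log.txt` in `'w'` mode (truncating it) on every epoch whose 1-based index ends in 1. A stricter first-epoch test would be (hypothetical helper, not in the commit):

```python
def is_first_epoch(cur_ep: str) -> bool:
    """cur_ep is formatted as f'{ep+1}/{args.ep}', so only '1/...' marks the first dump."""
    return cur_ep.split('/')[0] == '1'

assert is_first_epoch('1/350')
assert not is_first_epoch('11/350')  # whereas '1/' in '11/350' is True
```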
