import gc
import os
-import random
+import shutil
import sys
import time
import warnings
from functools import partial

-import numpy as np
import torch
from torch.utils.data import DataLoader

import dist
-from utils.misc import auto_resume
from utils import arg_util, misc
from utils.data import build_dataset
from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler
+from utils.misc import auto_resume


def build_everything(args: arg_util.Args):
    # resume
    auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')
    # create tensorboard logger
    tb_lg: misc.TensorboardLogger
-    with_tb_lg = dist.is_visualizer()
+    with_tb_lg = dist.is_master()
    if with_tb_lg:
        os.makedirs(args.tb_log_dir_path, exist_ok=True)
        # noinspection PyTypeChecker
@@ -130,7 +129,7 @@ def build_everything(args: arg_util.Args):

    # build trainer
    trainer = VARTrainer(
-        is_visualizer=dist.is_visualizer(), device=args.device, patch_nums=args.patch_nums, resos=args.resos,
+        device=args.device, patch_nums=args.patch_nums, resos=args.resos,
        vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,
        var_opt=var_optim, label_smooth=args.ls,
    )
@@ -157,7 +156,7 @@ def build_everything(args: arg_util.Args):
        )
        print({k: meter.global_avg for k, meter in me.meters.items()})

-        tb_lg.flush(); tb_lg.close()
+        args.dump_log(); tb_lg.flush(); tb_lg.close()
        if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
            sys.stdout.close(), sys.stderr.close()
        exit(0)
@@ -169,7 +168,7 @@ def build_everything(args: arg_util.Args):
    )


-def main():
+def main_training():
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    if args.local_debug:
        torch.autograd.set_detect_anomaly(True)
@@ -181,9 +180,9 @@ def main():
    ) = build_everything(args)

    # train
-    start_time, min_L_mean, min_L_tail, max_acc_mean, max_acc_tail = time.time(), 999., 999., -1., -1.
-    last_val_loss_mean, best_val_loss_mean, last_val_acc_mean, best_val_acc_mean = 999, 999, 0, 0
-    last_val_loss_tail, best_val_loss_tail, last_val_acc_tail, best_val_acc_tail = 999, 999, 0, 0
+    start_time = time.time()
+    best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.
+    best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1

    L_mean, L_tail = -1, -1
    for ep in range(start_ep, args.ep):
@@ -199,49 +198,46 @@ def main():
        )

        L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']
-        min_L_mean, max_acc_mean, max_acc_tail = min(min_L_mean, L_mean), max(max_acc_mean, acc_mean), max(max_acc_tail, acc_tail)
-        if L_tail != -1:
-            min_L_tail = min(min_L_tail, L_tail)
-        args.min_L_mean, args.min_L_tail, args.max_acc_mean, args.max_acc_tail, args.grad_norm = min_L_mean, min_L_tail, (None if max_acc_mean < 0 else max_acc_mean), (None if max_acc_tail < 0 else max_acc_tail), grad_norm
+        best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)
+        if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)
+        args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm
        args.cur_ep = f'{ep+1}/{args.ep}'
        args.remain_time, args.finish_time = remain_time, finish_time

-        AR_ep_loss = {}
+        AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)
        is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep
        if is_val_and_also_saving:
-            last_val_loss_mean, last_val_loss_tail, last_val_acc_mean, last_val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
-            best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, last_val_loss_mean), min(best_val_loss_tail, last_val_loss_tail)
-            best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, last_val_acc_mean), max(best_val_acc_tail, last_val_acc_tail)
-            AR_ep_loss['vL_mean'], AR_ep_loss['vL_tail'], AR_ep_loss['vacc_mean'], AR_ep_loss['vacc_tail'] = last_val_loss_mean, last_val_loss_tail, last_val_acc_mean, last_val_acc_tail
+            val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
+            best_updated = best_val_loss_tail > val_loss_tail
+            best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)
+            best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)
+            AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, vacc_tail=val_acc_tail)
+            args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail
            print(f' [*] [ep{ep}]  (val {tot})  Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f},  Val cost: {cost:.2f}s')
+
+            if dist.is_local_master():
+                local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
+                local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')
+                print(f'[saving ckpt] ...', end='', flush=True)
+                torch.save({
+                    'epoch': ep + 1,
+                    'iter': 0,
+                    'trainer': trainer.state_dict(),
+                    'args': args.state_dict(),
+                }, local_out_ckpt)
+                if best_updated:
+                    shutil.copy(local_out_ckpt, local_out_ckpt_best)
+                print(f'     [saving ckpt](*) finished!  @ {local_out_ckpt}', flush=True, clean=True)
+            dist.barrier()

-        print(f' [ep{ep}]  (training )  Lm: {min_L_mean:.3f} ({L_mean:.3f}), Lt: {min_L_tail:.3f} ({L_tail:.3f}),  Acc m&t: {max_acc_mean:.2f} {max_acc_tail:.2f},  Remain: {remain_time},  Finish: {finish_time}', flush=True)
-        if ep > args.ep // 20 and min_L_tail < 99:
-            tb_lg.update(head='AR_y_result', step=ep+1, min_L_mean=min_L_mean, min_L_tail=min_L_tail, max_acc_mean=max_acc_mean, max_acc_tail=max_acc_tail)
-
-        AR_ep_loss['L_mean'], AR_ep_loss['L_tail'], AR_ep_loss['acc_mean'], AR_ep_loss['acc_tail'] = L_mean, L_tail, acc_mean, acc_tail
+        print(f' [ep{ep}]  (training )  Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}),  Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f},  Remain: {remain_time},  Finish: {finish_time}', flush=True)
        tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)
        tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))
-
-        if is_val_and_also_saving and dist.is_master():
-            local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
-            torch.save({
-                'epoch': ep + 1,
-                'iter': 0,
-                'trainer': trainer.state_dict(),
-                'args': args.state_dict(),
-            }, local_out_ckpt)
-
-            tb_lg.flush()
-            dist.barrier()
-
-    tb_lg.update(head='AR_y_result_final', step=start_ep, min_L_mean=min_L_mean, min_L_tail=min_L_tail, max_acc_mean=max_acc_mean, max_acc_tail=max_acc_tail)
-    tb_lg.update(head='AR_y_result_final', step=args.ep, min_L_mean=min_L_mean, min_L_tail=min_L_tail, max_acc_mean=max_acc_mean, max_acc_tail=max_acc_tail)
-    tb_lg.flush()
+        args.dump_log(); tb_lg.flush()

    total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'
    print('\n\n')
-    print(f'  [*] [PT finished]  Total Time: {total_time},   Lm: {min_L_mean:.3f} ({L_mean}),   Lt: {min_L_tail:.3f} ({L_tail})')
+    print(f'  [*] [PT finished]  Total cost: {total_time},   Lm: {best_L_mean:.3f} ({L_mean}),   Lt: {best_L_tail:.3f} ({L_tail})')
    print('\n\n')

    del stats
@@ -250,7 +246,7 @@ def main():

    args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60))
    print(f'final args:\n\n{str(args)}')
-    tb_lg.flush(); tb_lg.close()
+    args.dump_log(); tb_lg.flush(); tb_lg.close()
    dist.barrier()

@@ -285,6 +281,7 @@ def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args,

        wp_it = args.wp * iters_train
        min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)
+        args.cur_lr, args.cur_wd = max_tlr, max_twd

        if args.pg: # default: 0.0, no progressive training, won't get into this
            if g_it <= wp_it: prog_si = args.pg0
@@ -310,8 +307,7 @@ def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args,
            tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)
            tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)
            tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)
-            if scale_log2 is not None:
-                tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)
+            tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)

            if args.tclip > 0:
                tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)
@@ -335,18 +331,7 @@ def forward(self, *args, **kwargs):


if __name__ == '__main__':
-    try:
-        main()
-    except Exception as err:
-        time.sleep(dist.get_rank() * 1 + random.random() * 0.5)
-        try:
-            # noinspection PyArgumentList
-            print(f'[rk{dist.get_rank():2d}] {type(err).__name__}', flush=True, force=True)
-        except:
-            try: print(f'[rk{dist.get_rank():2d}] {type(err).__name__}', flush=True)
-            except: pass
-        if dist.is_master(): print(f'[err]:\n{err}')
-        raise err
+    try: main_training()
    finally:
        dist.finalize()
        if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
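
Note on the checkpointing change above: each validation epoch, the commit now writes the latest state to ar-ckpt-last.pth on the local master and, whenever the validation tail loss improves, copies that file to ar-ckpt-best.pth with shutil.copy, followed by a dist.barrier() so all ranks stay in sync. Below is a minimal standalone sketch of this save-last / copy-best pattern; the save_epoch_ckpt helper name and the exact state-dict layout are illustrative, not code from the repo.

import os
import shutil
import torch


def save_epoch_ckpt(out_dir: str, ep: int, trainer_state: dict, args_state: dict, best_updated: bool) -> str:
    # Always overwrite the rolling "last" checkpoint for resuming.
    last_path = os.path.join(out_dir, 'ar-ckpt-last.pth')
    best_path = os.path.join(out_dir, 'ar-ckpt-best.pth')
    torch.save({'epoch': ep + 1, 'iter': 0, 'trainer': trainer_state, 'args': args_state}, last_path)
    if best_updated:
        # Copy the file instead of serializing twice; cheap, and "best" stays byte-identical to "last".
        shutil.copy(last_path, best_path)
    return last_path

In the diff itself this logic runs only when dist.is_local_master() is true, so each node writes one checkpoint rather than one per process.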