 
 import torch
 from torch.utils.data import DataLoader
-from torch.autograd import Variable
-from torch.nn.parameter import Parameter
 
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
 import dllogger as DLLogger
 from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
 
-from scipy.io.wavfile import write as write_wav
-
 
 def parse_args(parser):
     """
@@ -161,11 +157,11 @@ def parse_args(parser):
 
 def reduce_tensor(tensor, num_gpus):
     rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.reduce_op.SUM)
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
     if rt.is_floating_point():
         rt = rt / num_gpus
     else:
-        rt = rt // num_gpus
+        rt = torch.div(rt, num_gpus, rounding_mode='floor')
     return rt
 
 
@@ -184,8 +180,8 @@ def init_distributed(args, world_size, rank, group_name):
     print("Done initializing distributed")
 
 
-def save_checkpoint(model, optimizer, epoch, config, amp_run, output_dir, model_name,
-                    local_rank, world_size):
+def save_checkpoint(model, optimizer, scaler, epoch, config, output_dir,
+                    model_name, local_rank, world_size):
 
     random_rng_state = torch.random.get_rng_state().cuda()
     cuda_rng_state = torch.cuda.get_rng_state(local_rank).cuda()
@@ -209,7 +205,8 @@ def save_checkpoint(model, optimizer, epoch, config, amp_run, output_dir, model_
                   'random_rng_states_all': random_rng_states_all,
                   'config': config,
                   'state_dict': model.state_dict(),
-                  'optimizer': optimizer.state_dict()}
+                  'optimizer': optimizer.state_dict(),
+                  'scaler': scaler.state_dict()}
 
     checkpoint_filename = "checkpoint_{}_{}.pt".format(model_name, epoch)
     checkpoint_path = os.path.join(output_dir, checkpoint_filename)
@@ -237,7 +234,7 @@ def get_last_checkpoint_filename(output_dir, model_name):
         return ""
 
 
-def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_rank):
+def load_checkpoint(model, optimizer, scaler, epoch, filepath, local_rank):
 
     checkpoint = torch.load(filepath, map_location='cpu')
 
@@ -250,9 +247,10 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
         torch.random.set_rng_state(checkpoint['random_rng_state'])
     else:
         raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")
-    config = checkpoint['config']
     model.load_state_dict(checkpoint['state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer'])
+    scaler.load_state_dict(checkpoint['scaler'])
+    return checkpoint['config']
 
 
 # adapted from: https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
@@ -271,7 +269,7 @@ def evaluating(model):
 
 
 def validate(model, criterion, valset, epoch, batch_iter, batch_size,
-             world_size, collate_fn, distributed_run, rank, batch_to_gpu):
+             world_size, collate_fn, distributed_run, rank, batch_to_gpu, amp_run):
     """Handles all the validation scoring and printing"""
     with evaluating(model), torch.no_grad():
         val_sampler = DistributedSampler(valset) if distributed_run else None
@@ -288,8 +286,11 @@ def validate(model, criterion, valset, epoch, batch_iter, batch_size,
             iter_start_time = time.perf_counter()
 
             x, y, num_items = batch_to_gpu(batch)
-            y_pred = model(x)
-            loss = criterion(y_pred, y)
+            #AMP upstream autocast
+            with torch.cuda.amp.autocast(enabled=amp_run):
+                y_pred = model(x)
+                loss = criterion(y_pred, y)
+
             if distributed_run:
                 reduced_val_loss = reduce_tensor(loss.data, world_size).item()
                 reduced_num_items = reduce_tensor(num_items.data, 1).item()
@@ -398,9 +399,9 @@ def main():
     if args.resume_from_last:
         args.checkpoint_path = get_last_checkpoint_filename(args.output, model_name)
 
-    if args.checkpoint_path is not "":
-        load_checkpoint(model, optimizer, start_epoch, model_config,
-                        args.amp, args.checkpoint_path, local_rank)
+    if args.checkpoint_path != "":
+        model_config = load_checkpoint(model, optimizer, scaler, start_epoch,
+                                       args.checkpoint_path, local_rank)
 
     start_epoch = start_epoch[0]
 
@@ -450,9 +451,6 @@ def main():
         num_iters = 0
         reduced_loss = 0
 
-        # if overflow at the last iteration then do not save checkpoint
-        overflow = False
-
         if distributed_run:
             train_loader.sampler.set_epoch(epoch)
 
@@ -492,13 +490,13 @@ def main():
             if args.amp:
                 scaler.scale(loss).backward()
                 scaler.unscale_(optimizer)
-                grad_norm = torch.nn.utils.clip_grad_norm_(
+                torch.nn.utils.clip_grad_norm_(
                     model.parameters(), args.grad_clip_thresh)
                 scaler.step(optimizer)
                 scaler.update()
             else:
                 loss.backward()
-                grad_norm = torch.nn.utils.clip_grad_norm_(
+                torch.nn.utils.clip_grad_norm_(
                     model.parameters(), args.grad_clip_thresh)
                 optimizer.step()
 
@@ -527,12 +525,12 @@ def main():
                  iteration, args.batch_size,
                  world_size, collate_fn,
                  distributed_run, local_rank,
-                 batch_to_gpu)
+                 batch_to_gpu,
+                 args.amp)
 
         if (epoch % args.epochs_per_checkpoint == 0) and args.bench_class == "":
-            save_checkpoint(model, optimizer, epoch, model_config,
-                            args.amp, args.output, args.model_name,
-                            local_rank, world_size)
+            save_checkpoint(model, optimizer, scaler, epoch, model_config,
+                            args.output, args.model_name, local_rank, world_size)
         if local_rank == 0:
             DLLogger.flush()
 
@@ -548,5 +546,6 @@ def main():
     if local_rank == 0:
         DLLogger.flush()
 
+
 if __name__ == '__main__':
     main()
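
For reference, the diff replaces the old apex-style AMP plumbing (the `amp_run` checkpoint arguments) with native `torch.cuda.amp`: the forward pass runs under `autocast`, gradients are handled by a `GradScaler`, and the scaler state is saved and restored alongside the model and optimizer. Below is a minimal, self-contained sketch of that pattern; the linear model, SGD optimizer, loss, and `checkpoint.pt` file name are placeholders for illustration, not pieces of this script.

import torch

# Placeholder model/optimizer, only to show the AMP train/checkpoint pattern.
model = torch.nn.Linear(80, 80).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)

def train_step(x, y):
    optimizer.zero_grad()
    # Forward pass under autocast, as in validate() and the training loop above.
    with torch.cuda.amp.autocast(enabled=True):
        y_pred = model(x)
        loss = torch.nn.functional.mse_loss(y_pred, y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                       # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)                           # skips the step on inf/NaN gradients
    scaler.update()

# Checkpointing: the scaler state travels with the model and optimizer state,
# mirroring save_checkpoint()/load_checkpoint() in the diff above.
torch.save({'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict()}, 'checkpoint.pt')

ckpt = torch.load('checkpoint.pt', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer'])
scaler.load_state_dict(ckpt['scaler'])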