import torch
from torch.utils.data import DataLoader
-from torch.autograd import Variable
-from torch.nn.parameter import Parameter
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity

-from scipy.io.wavfile import write as write_wav
-

def parse_args(parser):
    """
@@ -161,11 +157,11 @@ def parse_args(parser):

def reduce_tensor(tensor, num_gpus):
    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.reduce_op.SUM)
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    if rt.is_floating_point():
        rt = rt / num_gpus
    else:
-        rt = rt // num_gpus
+        rt = torch.div(rt, num_gpus, rounding_mode='floor')
    return rt

@@ -184,8 +180,8 @@ def init_distributed(args, world_size, rank, group_name):
    print("Done initializing distributed")


-def save_checkpoint(model, optimizer, epoch, config, amp_run, output_dir, model_name,
-                    local_rank, world_size):
+def save_checkpoint(model, optimizer, scaler, epoch, config, output_dir,
+                    model_name, local_rank, world_size):

    random_rng_state = torch.random.get_rng_state().cuda()
    cuda_rng_state = torch.cuda.get_rng_state(local_rank).cuda()
@@ -209,7 +205,8 @@ def save_checkpoint(model, optimizer, epoch, config, amp_run, output_dir, model_
                  'random_rng_states_all': random_rng_states_all,
                  'config': config,
                  'state_dict': model.state_dict(),
-                  'optimizer': optimizer.state_dict()}
+                  'optimizer': optimizer.state_dict(),
+                  'scaler': scaler.state_dict()}

    checkpoint_filename = "checkpoint_{}_{}.pt".format(model_name, epoch)
    checkpoint_path = os.path.join(output_dir, checkpoint_filename)
@@ -237,7 +234,7 @@ def get_last_checkpoint_filename(output_dir, model_name):
        return ""


-def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_rank):
+def load_checkpoint(model, optimizer, scaler, epoch, filepath, local_rank):

    checkpoint = torch.load(filepath, map_location='cpu')

@@ -250,9 +247,10 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
        torch.random.set_rng_state(checkpoint['random_rng_state'])
    else:
        raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")
-    config = checkpoint['config']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
+    scaler.load_state_dict(checkpoint['scaler'])
+    return checkpoint['config']


# adapted from: https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
@@ -271,7 +269,7 @@ def evaluating(model):


def validate(model, criterion, valset, epoch, batch_iter, batch_size,
-             world_size, collate_fn, distributed_run, rank, batch_to_gpu):
+             world_size, collate_fn, distributed_run, rank, batch_to_gpu, amp_run):
    """Handles all the validation scoring and printing"""
    with evaluating(model), torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
@@ -288,8 +286,11 @@ def validate(model, criterion, valset, epoch, batch_iter, batch_size,
            iter_start_time = time.perf_counter()

            x, y, num_items = batch_to_gpu(batch)
-            y_pred = model(x)
-            loss = criterion(y_pred, y)
+            #AMP upstream autocast
+            with torch.cuda.amp.autocast(enabled=amp_run):
+                y_pred = model(x)
+                loss = criterion(y_pred, y)
+
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
@@ -398,9 +399,9 @@ def main():
    if args.resume_from_last:
        args.checkpoint_path = get_last_checkpoint_filename(args.output, model_name)

-    if args.checkpoint_path is not "":
-        load_checkpoint(model, optimizer, start_epoch, model_config,
-                        args.amp, args.checkpoint_path, local_rank)
+    if args.checkpoint_path != "":
+        model_config = load_checkpoint(model, optimizer, scaler, start_epoch,
+                                       args.checkpoint_path, local_rank)

    start_epoch = start_epoch[0]

@@ -450,9 +451,6 @@ def main():
        num_iters = 0
        reduced_loss = 0

-        # if overflow at the last iteration then do not save checkpoint
-        overflow = False
-
        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

@@ -492,13 +490,13 @@ def main():
            if args.amp:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
-                grad_norm = torch.nn.utils.clip_grad_norm_(
+                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
-                grad_norm = torch.nn.utils.clip_grad_norm_(
+                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)
                optimizer.step()

@@ -527,12 +525,12 @@ def main():
                                               iteration, args.batch_size,
                                               world_size, collate_fn,
                                               distributed_run, local_rank,
-                                               batch_to_gpu)
+                                               batch_to_gpu,
+                                               args.amp)

        if (epoch % args.epochs_per_checkpoint == 0) and args.bench_class == "":
-            save_checkpoint(model, optimizer, epoch, model_config,
-                            args.amp, args.output, args.model_name,
-                            local_rank, world_size)
+            save_checkpoint(model, optimizer, scaler, epoch, model_config,
+                            args.output, args.model_name, local_rank, world_size)
        if local_rank == 0:
            DLLogger.flush()
@@ -548,5 +546,6 @@ def main():
    if local_rank == 0:
        DLLogger.flush()

+
if __name__ == '__main__':
    main()
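
For reference, a minimal, self-contained sketch of the native `torch.cuda.amp` training step this diff adopts (a stand-in linear model, optimizer, and loss rather than the actual Tacotron2/WaveGlow code; it assumes a CUDA device is available):

```python
import torch

# Stand-in model, optimizer, loss, and data (illustrative only).
model = torch.nn.Linear(80, 80).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler(enabled=True)

x = torch.randn(16, 80, device='cuda')
y = torch.randn(16, 80, device='cuda')

optimizer.zero_grad()
# Forward pass runs in mixed precision when autocast is enabled.
with torch.cuda.amp.autocast(enabled=True):
    loss = criterion(model(x), y)

# Backward on the scaled loss, then unscale before clipping so the
# threshold is applied to the true gradient magnitudes.
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)   # skips the optimizer step if inf/nan gradients were found
scaler.update()
```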
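The diff also threads `scaler` through `save_checkpoint`/`load_checkpoint` because `GradScaler` carries loss-scaling state that should survive a restart. A small sketch of that round trip, reusing the `model`, `optimizer`, and `scaler` objects from the sketch above and an illustrative filename:

```python
# Save model, optimizer, and loss-scaler state together (illustrative keys).
checkpoint = {'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict(),
              'scaler': scaler.state_dict()}
torch.save(checkpoint, 'checkpoint_example.pt')

# Restore all three so training resumes with the same loss scale.
restored = torch.load('checkpoint_example.pt', map_location='cpu')
model.load_state_dict(restored['state_dict'])
optimizer.load_state_dict(restored['optimizer'])
scaler.load_state_dict(restored['scaler'])
```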