Skip to content

Commit 7ce1754

Browse files
alancucki authored and nv-kkudrynski committed
[Tacotron2/PyT] Fix warnings, AMP state loading
1 parent eda18ea commit 7ce1754

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

PyTorch/SpeechSynthesis/Tacotron2/train.py

+25 -26
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,6 @@
3333

3434
import torch
3535
from torch.utils.data import DataLoader
36-
from torch.autograd import Variable
37-
from torch.nn.parameter import Parameter
3836

3937
import torch.distributed as dist
4038
from torch.utils.data.distributed import DistributedSampler
@@ -49,8 +47,6 @@
4947
import dllogger as DLLogger
5048
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
5149

52-
from scipy.io.wavfile import write as write_wav
53-
5450

5551
def parse_args(parser):
5652
"""
@@ -161,11 +157,11 @@ def parse_args(parser):
161157

162158
def reduce_tensor(tensor, num_gpus):
163159
rt = tensor.clone()
164-
dist.all_reduce(rt, op=dist.reduce_op.SUM)
160+
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
165161
if rt.is_floating_point():
166162
rt = rt/num_gpus
167163
else:
168-
rt = rt//num_gpus
164+
rt = torch.div(rt, num_gpus, rounding_mode='floor')
169165
return rt
170166

171167

@@ -184,8 +180,8 @@ def init_distributed(args, world_size, rank, group_name):
184180
print("Done initializing distributed")
185181

186182

187-
def save_checkpoint(model, optimizer, epoch, config, amp_run, output_dir, model_name,
188-
local_rank, world_size):
183+
def save_checkpoint(model, optimizer, scaler, epoch, config, output_dir,
184+
model_name, local_rank, world_size):
189185

190186
random_rng_state = torch.random.get_rng_state().cuda()
191187
cuda_rng_state = torch.cuda.get_rng_state(local_rank).cuda()
@@ -209,7 +205,8 @@ def save_checkpoint(model, optimizer, epoch, config, amp_run, output_dir, model_
209205
'random_rng_states_all': random_rng_states_all,
210206
'config': config,
211207
'state_dict': model.state_dict(),
212-
'optimizer': optimizer.state_dict()}
208+
'optimizer': optimizer.state_dict(),
209+
'scaler': scaler.state_dict()}
213210

214211
checkpoint_filename = "checkpoint_{}_{}.pt".format(model_name, epoch)
215212
checkpoint_path = os.path.join(output_dir, checkpoint_filename)
@@ -237,7 +234,7 @@ def get_last_checkpoint_filename(output_dir, model_name):
237234
return ""
238235

239236

240-
def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_rank):
237+
def load_checkpoint(model, optimizer, scaler, epoch, filepath, local_rank):
241238

242239
checkpoint = torch.load(filepath, map_location='cpu')
243240

@@ -250,9 +247,10 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
250247
torch.random.set_rng_state(checkpoint['random_rng_state'])
251248
else:
252249
raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")
253-
config = checkpoint['config']
254250
model.load_state_dict(checkpoint['state_dict'])
255251
optimizer.load_state_dict(checkpoint['optimizer'])
252+
scaler.load_state_dict(checkpoint['scaler'])
253+
return checkpoint['config']
256254

257255

258256
# adapted from: https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
@@ -271,7 +269,7 @@ def evaluating(model):
271269

272270

273271
def validate(model, criterion, valset, epoch, batch_iter, batch_size,
274-
world_size, collate_fn, distributed_run, rank, batch_to_gpu):
272+
world_size, collate_fn, distributed_run, rank, batch_to_gpu, amp_run):
275273
"""Handles all the validation scoring and printing"""
276274
with evaluating(model), torch.no_grad():
277275
val_sampler = DistributedSampler(valset) if distributed_run else None
@@ -288,8 +286,11 @@ def validate(model, criterion, valset, epoch, batch_iter, batch_size,
288286
iter_start_time = time.perf_counter()
289287

290288
x, y, num_items = batch_to_gpu(batch)
291-
y_pred = model(x)
292-
loss = criterion(y_pred, y)
289+
#AMP upstream autocast
290+
with torch.cuda.amp.autocast(enabled=amp_run):
291+
y_pred = model(x)
292+
loss = criterion(y_pred, y)
293+
293294
if distributed_run:
294295
reduced_val_loss = reduce_tensor(loss.data, world_size).item()
295296
reduced_num_items = reduce_tensor(num_items.data, 1).item()
@@ -398,9 +399,9 @@ def main():
398399
if args.resume_from_last:
399400
args.checkpoint_path = get_last_checkpoint_filename(args.output, model_name)
400401

401-
if args.checkpoint_path is not "":
402-
load_checkpoint(model, optimizer, start_epoch, model_config,
403-
args.amp, args.checkpoint_path, local_rank)
402+
if args.checkpoint_path != "":
403+
model_config = load_checkpoint(model, optimizer, scaler, start_epoch,
404+
args.checkpoint_path, local_rank)
404405

405406
start_epoch = start_epoch[0]
406407

@@ -450,9 +451,6 @@ def main():
450451
num_iters = 0
451452
reduced_loss = 0
452453

453-
# if overflow at the last iteration then do not save checkpoint
454-
overflow = False
455-
456454
if distributed_run:
457455
train_loader.sampler.set_epoch(epoch)
458456

@@ -492,13 +490,13 @@ def main():
492490
if args.amp:
493491
scaler.scale(loss).backward()
494492
scaler.unscale_(optimizer)
495-
grad_norm = torch.nn.utils.clip_grad_norm_(
493+
torch.nn.utils.clip_grad_norm_(
496494
model.parameters(), args.grad_clip_thresh)
497495
scaler.step(optimizer)
498496
scaler.update()
499497
else:
500498
loss.backward()
501-
grad_norm = torch.nn.utils.clip_grad_norm_(
499+
torch.nn.utils.clip_grad_norm_(
502500
model.parameters(), args.grad_clip_thresh)
503501
optimizer.step()
504502

@@ -527,12 +525,12 @@ def main():
527525
iteration, args.batch_size,
528526
world_size, collate_fn,
529527
distributed_run, local_rank,
530-
batch_to_gpu)
528+
batch_to_gpu,
529+
args.amp)
531530

532531
if (epoch % args.epochs_per_checkpoint == 0) and args.bench_class == "":
533-
save_checkpoint(model, optimizer, epoch, model_config,
534-
args.amp, args.output, args.model_name,
535-
local_rank, world_size)
532+
save_checkpoint(model, optimizer, scaler, epoch, model_config,
533+
args.output, args.model_name, local_rank, world_size)
536534
if local_rank == 0:
537535
DLLogger.flush()
538536

@@ -548,5 +546,6 @@ def main():
548546
if local_rank == 0:
549547
DLLogger.flush()
550548

549+
551550
if __name__ == '__main__':
552551
main()

PyTorch/SpeechSynthesis/Tacotron2/waveglow/model.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def __init__(self, c):
5252
bias=False)
5353

5454
# Sample a random orthonormal matrix to initialize weights
55-
W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
55+
W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
5656

5757
# Ensure determinant is 1.0 not -1.0
5858
if torch.det(W) < 0:

0 commit comments

Comments
 (0)