-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
61 lines (41 loc) · 1.61 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
A script for WaveNet training
"""
import os
import wavenet.config as config
from wavenet.model import WaveNet
from wavenet.utils.data import DataLoader
class Trainer:
    """Drive WaveNet training from parsed command-line arguments.

    The constructor builds the model and the data loader from ``args``.
    ``run`` then trains for ``args.num_steps`` steps and saves the model
    to ``args.model_dir``.
    """

    def __init__(self, args):
        """Build the WaveNet model and its data loader.

        Args:
            args: Parsed options. Read here: ``layer_size``, ``stack_size``,
                ``in_channels``, ``res_channels``, ``lr``, ``data_dir``,
                ``sample_size`` and ``sample_rate``. ``run`` later reads
                ``num_steps`` and ``model_dir``.
        """
        self.args = args
        self.wavenet = WaveNet(args.layer_size, args.stack_size,
                               args.in_channels, args.res_channels,
                               lr=args.lr)
        # The loader is told the model's receptive field.
        # Presumably this lets it size input windows; TODO confirm in DataLoader.
        self.data_loader = DataLoader(args.data_dir, self.wavenet.receptive_fields,
                                      args.sample_size, args.sample_rate, args.in_channels)

    def infinite_batch(self):
        """Yield ``(inputs, targets)`` pairs forever by re-iterating the loader.

        Each pass iterates ``self.data_loader``, and then every element it
        produces, unpacking each item into ``(inputs, targets)``.

        Raises:
            RuntimeError: If a complete pass over the loader yields no
                batches. Without this check the generator would spin
                forever without producing anything.
        """
        while True:
            produced_any = False
            for dataset in self.data_loader:
                for inputs, targets in dataset:
                    produced_any = True
                    yield inputs, targets
            if not produced_any:
                raise RuntimeError(
                    'data loader produced no (inputs, targets) batches '
                    'in a full pass; check the training data')

    def run(self):
        """Train for exactly ``self.args.num_steps`` steps, then save the model.

        Each step passes one batch to ``self.wavenet.train`` and prints the
        value it returns as the loss. The model is saved to
        ``self.args.model_dir`` afterwards, even when ``num_steps`` is 0.
        """
        # Use self.args, not the module-level global `args`, which only
        # exists when this file is executed as a script.
        num_steps = self.args.num_steps
        batches = self.infinite_batch()
        total_steps = 0
        # Check the step budget *before* training so exactly num_steps steps
        # run. The original break-after-train ran num_steps + 1 steps.
        while total_steps < num_steps:
            inputs, targets = next(batches)
            loss = self.wavenet.train(inputs, targets)
            total_steps += 1
            print('[{0}/{1}] loss: {2}'.format(total_steps, num_steps, loss))
        self.wavenet.save(self.args.model_dir)
def prepare_output_dir(args):
    """Derive the output sub-directory paths on ``args`` and create them.

    Sets ``args.log_dir``, ``args.model_dir`` and ``args.test_output_dir``
    to ``log``, ``model`` and ``test`` under ``args.output_dir``. Then
    creates each directory; ones that already exist are left alone.

    Args:
        args: Namespace with an ``output_dir`` attribute; mutated in place.
    """
    layout = (
        ('log_dir', 'log'),
        ('model_dir', 'model'),
        ('test_output_dir', 'test'),
    )
    # Assign every path first, then create directories, in the same order.
    for attr_name, sub_name in layout:
        setattr(args, attr_name, os.path.join(args.output_dir, sub_name))
    for attr_name, _ in layout:
        os.makedirs(getattr(args, attr_name), exist_ok=True)
if __name__ == '__main__':
    # Script entry point: parse options, create output folders, then train.
    # `args` stays a module-level name on purpose; keep it that way.
    args = config.parse_args()
    prepare_output_dir(args)
    Trainer(args).run()