-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy path main.py
66 lines (56 loc) · 1.64 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
import argparse
import os
import json
import glob
import random
import collections
import math
import time
import sys
from data_utils import *
from ops import *
from model import *
from dnn_model import *
from cwgan import *
# --- Run configuration -------------------------------------------------------
# Every setting keeps the value that used to be hard-coded here, so running
# `python main.py` with no arguments behaves exactly as before.  The values can
# now also be overridden on the command line instead of by editing this file,
# e.g. `python main.py --mode stage2` or
# `python main.py --mode test --test_path stage2_model/model_20171109/`.
def parse_run_config(argv=None):
    """Parse the run configuration from ``argv``.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.
            Arguments this parser does not define are ignored rather than
            rejected (``parse_known_args``), so extra launcher flags do not
            abort the run -- previously no arguments were parsed at all.

    Returns:
        An ``argparse.Namespace`` with the attributes ``use_waveform``,
        ``batch_size``, ``learning_rate``, ``iters``, ``mode``, ``log_path``,
        ``model_path``, ``model_path2``, ``test_path``, ``test_list`` and
        ``record_name``.  ``test_path`` falls back to ``model_path`` when it
        is not given explicitly.

    Raises:
        SystemExit: if a defined option gets an invalid value (for example a
            ``--mode`` other than stage1/stage2/test), via argparse's normal
            error handling.
    """
    parser = argparse.ArgumentParser(description='Train or test the GAN.')
    parser.add_argument('--mode', default='stage1',
                        choices=['stage1', 'stage2', 'test'],
                        help='stage to train, or "test" (default: %(default)s)')
    # Default True; the flag switches to the spec_Generator/spec_Discriminator pair.
    parser.add_argument('--spectrogram', dest='use_waveform',
                        action='store_false',
                        help='use the spectrogram models instead of the '
                             'waveform models')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--iters', type=int, default=45000)
    parser.add_argument('--log_path', default='stage1_log/')
    parser.add_argument('--model_path', default='stage1_model/model_20171109/')
    parser.add_argument('--model_path2', default='stage2_model/model_20171109/')
    parser.add_argument('--test_path', default=None,
                        help='model directory used in test mode '
                             '(default: the value of --model_path)')
    parser.add_argument('--test_list',
                        default="/mnt/gv0/user_sylar/segan_data/noisy_test_list")
    parser.add_argument('--record_name', default="/data_wave.tfrecord")
    args, _unknown = parser.parse_known_args(argv)
    if args.test_path is None:
        # Same as the old `test_path = model_path`: test from the stage-1
        # model directory unless told otherwise.
        args.test_path = args.model_path
    return args


_run_config = parse_run_config()
use_waveform = _run_config.use_waveform
batch_size = _run_config.batch_size
learning_rate = _run_config.learning_rate
iters = _run_config.iters
mode = _run_config.mode  # stage1, stage2, test
log_path = _run_config.log_path
model_path = _run_config.model_path
model_path2 = _run_config.model_path2
test_path = _run_config.test_path
test_list = _run_config.test_list
record_name = _run_config.record_name
del _run_config
# Pick the generator/discriminator pair matching the input representation:
# the waveform models, or the spec_* (spectrogram) variants.
G = Generator() if use_waveform else spec_Generator()
D = Discriminator() if use_waveform else spec_Discriminator()
def check_dir(path_name):
    """Ensure the directory ``path_name`` exists, creating it if needed.

    If the path already exists a notice is printed and nothing else is done.
    Missing parent directories are created too: paths such as
    ``stage1_model/model_20171109/`` are nested, and ``tf.gfile.MkDir``
    fails when the parent does not exist yet, whereas ``tf.gfile.MakeDirs``
    creates the whole chain.

    Args:
        path_name: path of the directory to ensure.
    """
    if tf.gfile.Exists(path_name):
        print('Folder already exists: {}\n'.format(path_name))
    else:
        tf.gfile.MakeDirs(path_name)
# Make sure both model directories (stage 1, then stage 2) exist up front.
for _model_dir in (model_path, model_path2):
    check_dir(_model_dir)
del _model_dir  # keep the loop variable out of the module namespace
# Input pipeline, pinned to the CPU device.
with tf.device('cpu'):
    reader = dataPreprocessor(record_name, use_waveform=use_waveform)
    clean, noisy = reader.read_and_decode(
        batch_size=batch_size,
        num_threads=32,
    )

# NOTE: the GAN is built outside any explicit device scope; an enclosing
# `with tf.device('gpu'):` used to sit here and is currently disabled.
# Argument order matters: the noisy batch is passed before the clean one.
gan = GradientPenaltyWGAN(
    G, D,
    noisy, clean,
    log_path, model_path, model_path2,
    use_waveform,
    lr=learning_rate,
)
# Dispatch on `mode`: any non-test mode is handed to gan.train together with
# the iteration count; "test" builds an input placeholder and calls gan.test.
if mode != 'test':
    gan.train(mode, iters)
else:
    # Placeholder for the noisy test input; its inner dimensions differ
    # between the waveform and the spectrogram models.
    x_test = tf.placeholder(
        "float",
        [None, 1, 16384, 1] if use_waveform else [None, 1, 257, 32],
        name='test_noisy',
    )
    gan.test(x_test, test_path, test_list)