-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtrain_DNN.py
84 lines (67 loc) · 2.09 KB
/
train_DNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 15 09:42:41 2016
@author: Jason
"""
import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Flatten, Activation, SpatialDropout2D, Reshape
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import ELU, PReLU, LeakyReLU
from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D, AtrousConv2D, ZeroPadding2D
from keras.layers.local import LocallyConnected2D
from keras.optimizers import *
from keras.utils.io_utils import HDF5Matrix
from keras.callbacks import *
import time
import numpy as np
import h5py
import sys
# The path to the HDF5 training file is a required positional argument;
# print a usage hint and exit with a non-zero status when it is missing.
if len(sys.argv) < 2:
    # Parenthesised single-argument print behaves identically on
    # Python 2 and Python 3.
    print("Usage: python train_DNN.py data.h5")
    sys.exit(1)
# Date tag; not referenced anywhere else in this script.
date = "20160508"

# Feature geometry.
# NOTE(review): presumably FRAMESIZE is the FFT length and FRAMEWIDTH the
# number of context frames on each side of the current one -- confirm
# against the feature-extraction code that produced the HDF5 file.
FRAMESIZE = 512
FRAMEWIDTH = 2
FBIN = FRAMESIZE // 2 + 1                  # 257 for FRAMESIZE = 512
input_dim = (2 * FRAMEWIDTH + 1) * FBIN    # 5 * 257 = 1285

# Training schedule passed to model.fit().
BATCHSIZE = 200
EPOCH = 30
print('model building...')

# Fully connected regression network:
#   input_dim -> NUM_HIDDEN_LAYERS x (Dense(HIDDEN_UNITS) + ELU) -> FBIN
# The input and output widths come from the feature constants defined
# above (input_dim = 1285, FBIN = 257 with the current settings) instead of
# being repeated as literals, so the network stays in sync if FRAMESIZE or
# FRAMEWIDTH is changed.
HIDDEN_UNITS = 2048
NUM_HIDDEN_LAYERS = 5

model = Sequential()

# The first hidden layer also declares the input shape.
model.add(Dense(HIDDEN_UNITS, input_shape=(input_dim,)))
model.add(ELU())

# Remaining hidden layers. No dropout is applied between them (the original
# script kept a Dropout(0.05) after each ELU commented out).
for _ in range(NUM_HIDDEN_LAYERS - 1):
    model.add(Dense(HIDDEN_UNITS))
    model.add(ELU())

# Linear output layer: no activation is added after it (a tanh activation
# was commented out in the original script).
model.add(Dense(FBIN))

model.summary()

# Adam optimiser with an explicit learning rate; trained on mean squared error.
adam = Adam(lr=0.0002, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss='mse', optimizer=adam)
# HDF5 file with the training data, taken from the command line
# (e.g. "/mnt/hd-01/user_sylar/MHINTSYPD_100NS/data_257_spectrum.h5").
data_path = sys.argv[1]

print('data loading...')
# Inputs are read from the "trnoisy" dataset and targets from "trclean".
X_train = HDF5Matrix(data_path, "trnoisy")
y_train = HDF5Matrix(data_path, "trclean")

# Write the model to model.hdf5 whenever the monitored training loss
# reaches a new minimum (no validation data is passed to fit()).
checkpointer = ModelCheckpoint(
    filepath="model.hdf5",
    monitor="loss",
    mode="min",
    verbose=0,
    save_best_only=True,
)

print('training...')
# shuffle="batch" is the Keras option intended for HDF5-backed data: it
# shuffles in batch-sized chunks rather than by individual sample index.
hist = model.fit(
    X_train,
    y_train,
    epochs=EPOCH,
    batch_size=BATCHSIZE,
    verbose=1,
    shuffle="batch",
    callbacks=[checkpointer],
)