train.py
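
"""Training script for an attention-based text recogniser (Theano); the
model itself is built by arsg_cnn_bi_y.create_model.

Renders synthetic text batches on the fly, optimises the model with
Adadelta, logs train/test statistics to an HDF5 file, and gradually grows
the rendered sequence length as accuracy improves (a simple curriculum).
Legacy Python 2 / Theano code (cPickle, print statements).
"""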
from deepml.solvers import rmsprop, sgd, adadelta
from deepml.activations import relu, tanh, softmax, sigmoid
from deepml.utils import floatX
import theano
import numpy as np
import pylab as pl
import time
import cPickle as pickle
import os
from utils import align_x, align_y, iter_batches, batched_wer, filter_blanks
from h5monitor import dump_h5, dump_h5_var
import h5py
from arsg_cnn_bi_y import create_model
from gen_text import tokens, render_batch

### RND ###
#srng = RandomStreams()

### DATA ###
monitor_file = 'stats.h5'    # HDF5 file for train/test statistics

### META-PARAMETERS ###
#n_in = train_x_shp[0,1]
n_in = 20                    # input feature dimension per time step
n_hid = 128                  # hidden layer size
n_cyc = 512                  # recurrent layer size
n_enc = 64                   # encoder size
n_out = len(tokens) + 1      # output classes: all tokens plus the blank
blank_symbol = len(tokens)   # index reserved for the blank symbol
batch_size = 32
test_batch_size = 128
seq_len = (10, 35)           # (min, max) rendered sequence length
CER_RATE = 1.                # accuracy threshold for growing seq_len

### ROUTINES ###
def decode(y):
    """Map a sequence of token indices back to characters."""
    return np.array([tokens[t] for t in y])
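
# filter_blanks and batched_wer are imported from the local utils module,
# which is not shown here. As an illustrative sketch only (assuming the
# blank symbol is used CTC-style), blank filtering would collapse repeated
# predictions and drop the blank index -- the real utils.filter_blanks may
# differ:
def _filter_blanks_sketch(y_pred, blank):
    out = []
    for seq in y_pred:
        decoded, prev = [], None
        for t in seq:
            if t != blank and t != prev:
                decoded.append(t)
            prev = t
        out.append(decoded)
    return out
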
### MODEL ###
def train(model):
    global seq_len
    # model contract: input/target variables, prediction, loss, trainable
    # parameters, and the attention weights alpha
    [x, y], out, cost, params, alpha = model
    tester = theano.function(inputs=[x, y], outputs=[cost, out, alpha])
    grad_updates = adadelta(
        eps=1e-9,
        cost=cost,
        params=params,
        #lr=3.,
        #grad_norm=1.
    )
    solver = theano.function(
        inputs=[x, y],
        outputs=cost,
        updates=grad_updates,
    )
    # alpha: (batch_size, len_y, len_x) attention weights
    costs = []
    t0 = time.time()
    for i in range(100000):
        bx, by = render_batch(seq_len, batch_size=batch_size)
        bx_ = np.float32(bx.transpose(0, 2, 1)) / 255.
        #bx_, _ = align_x(bx)
        by_, _ = align_y(by, filler=blank_symbol)
        loss = solver(bx_, by_)
        costs.append(loss)
        # monitor training progress every 100 iterations
        if i and i % 100 == 0:
            print 'Iteration %d, time: %.4f, loss %.8f' % (
                i, time.time() - t0, np.mean(costs))
            dump_h5(monitor_file, prefix='train_loss',
                    data=[np.mean(costs)])
            costs = []
            t0 = time.time()
        # evaluate on freshly rendered batches every 100 iterations
        if i and i % 100 == 0:
            test_costs = []
            t1 = time.time()
            cer = []
            for j in range(32):
                bx, by = render_batch(seq_len, batch_size=batch_size)
                #bx_, _ = align_x(bx)
                bx_ = np.float32(bx.transpose(0, 2, 1)) / 255.
                by_, _ = align_y(by, filler=blank_symbol)
                loss, y_pred, alpha = tester(bx_, by_)
                y_filt = filter_blanks(y_pred, blank_symbol)
                wer = batched_wer(by, y_filt)
                test_costs.append(loss)
                # note: 'cer' here is 1 - WER, i.e. an accuracy-like score
                cer.append(1 - wer)
            # checkpoint the model once per evaluation round
            with open('model.pkl', 'wb') as h:
                pickle.dump(model, h)
            dump_h5(monitor_file, prefix='test_loss',
                    data=[np.mean(test_costs)])
            dump_h5(monitor_file, prefix='test_cer', data=[np.mean(cer)])
            # store one attention map plus its input/output for inspection
            dump_h5_var(monitor_file,
                        prefix='test_alpha',
                        prefix_shape='test_alpha_shp',
                        data=alpha[0:1])
            dump_h5_var(monitor_file,
                        prefix='test_x',
                        prefix_shape='test_x_shp',
                        data=[bx_[0]])
            dump_h5_var(monitor_file,
                        prefix='test_y',
                        prefix_shape='test_y_shp',
                        data=[decode(y_filt[0])])
            print 'Iteration: %d, time: %.4f, test loss: %.8f, CER*: %.4f' % (
                i, time.time() - t1, np.mean(test_costs), np.mean(cer))
            # curriculum: grow the rendered sequence length once the
            # accuracy score clears CER_RATE
            if np.mean(cer) > CER_RATE:
                seq_len = (seq_len[0] + 1, seq_len[1] + 2)
                print 'Increasing sequence length:', seq_len

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', default=None,
                        help='Path to a pre-trained (pickled) model')
    args = parser.parse_args()
    if args.model is None:
        #model = create_model(n_in, n_out, n_enc, n_hid, n_cyc, batch_size)
        model = create_model(n_in, n_out, n_enc, n_hid, n_cyc)
    else:
        with open(args.model, 'rb') as h:
            model = pickle.load(h)
    train(model)
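
# Usage (Python 2 environment with Theano and the deepml package):
#   python train.py                # train a fresh model
#   python train.py -m model.pkl   # continue from a pickled checkpoint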