Skip to content

Commit 1761715

Browse files
committed
Update Zhihu RNN
1 parent 4bed5fa commit 1761715

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

Zhihu/RNN/Mnist.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,29 @@
1010
class MnistGenerator:
1111
def __init__(self, im=None, om=None):
1212
self._im, self._om = im, om
13+
self._cursor = self._indices = None
1314
self._x, self._y = DataUtil.get_dataset("mnist", "../../_Data/mnist.txt", quantized=True, one_hot=True)
1415
self._x = self._x.reshape(-1, 28, 28)
1516
self._x_train, self._x_test = self._x[:1800], self._x[1800:]
1617
self._y_train, self._y_test = self._y[:1800], self._y[1800:]
1718

19+
def refresh(self):
20+
self._cursor = 0
21+
self._indices = np.random.permutation(len(self._x_train))
22+
1823
def gen(self, batch, test=False):
1924
if batch == 0:
2025
if test:
2126
return self._x_test, self._y_test
2227
return self._x_train, self._y_train
23-
batch = np.random.choice(len(self._x_train), batch)
24-
return self._x_train[batch], self._y_train[batch]
28+
end = min(self._cursor + batch, len(self._x_train))
29+
start, self._cursor = self._cursor, end
30+
if start == end:
31+
self.refresh()
32+
end = batch
33+
start = self._cursor = 0
34+
indices = self._indices[start:end]
35+
return self._x_train[indices], self._y_train[indices]
2536

2637
if __name__ == '__main__':
2738
generator = MnistGenerator()
@@ -46,8 +57,8 @@ def gen(self, batch, test=False):
4657
net = tflearn.input_data(shape=[None, 28, 28])
4758
net = tf.concat(tflearn.lstm(net, 128, return_seq=True)[-3:], axis=1)
4859
net = tflearn.fully_connected(net, 10, activation='softmax')
49-
net = tflearn.regression(net, optimizer='adam', batch_size=64,
60+
net = tflearn.regression(net, optimizer='adam', batch_size=64, learning_rate=0.001,
5061
loss='categorical_crossentropy')
5162
model = tflearn.DNN(net, tensorboard_verbose=0)
52-
model.fit(*generator.gen(0), n_epoch=10, validation_set=generator.gen(0, True), show_metric=True)
63+
model.fit(*generator.gen(0), validation_set=generator.gen(0, True), show_metric=True)
5364
print("Time Cost: {}".format(time.time() - t))

Zhihu/RNN/RNN.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,11 @@ def fit(self, im, om, generator, hidden_units=128, cell=LSTMCell):
7878
)
7979
self._get_output(rnn_outputs)
8080
loss = tf.nn.softmax_cross_entropy_with_logits(logits=self._output, labels=self._tfy)
81-
train_step = tf.train.AdamOptimizer(0.01).minimize(loss)
81+
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
8282
self._sess.run(tf.global_variables_initializer())
8383
for _ in range(10):
84-
for __ in range(28):
84+
self._generator.refresh()
85+
for __ in range(29):
8586
x_batch, y_batch = self._generator.gen(64)
8687
self._sess.run(train_step, {self._tfx: x_batch, self._tfy: y_batch})
8788
self._verbose()

0 commit comments

Comments
 (0)