Skip to content

Commit 4bed5fa

Browse files
committed
Update Zhihu RNN
1 parent 84cac09 commit 4bed5fa

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

Diff for: Zhihu/RNN/Mnist.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,25 @@ def gen(self, batch, test=False):
2424
return self._x_train[batch], self._y_train[batch]
2525

2626
if __name__ == '__main__':
27-
print("=" * 60, "\n" + "My LSTM", "\n" + "-" * 60)
2827
generator = MnistGenerator()
29-
t = time.time()
28+
29+
print("=" * 60, "\n" + "My LSTM", "\n" + "-" * 60)
3030
tf.reset_default_graph()
31+
t = time.time()
3132
rnn = RNNWrapper()
3233
rnn.fit(28, 10, generator)
3334
print("Time Cost: {}".format(time.time() - t))
3435

3536
print("=" * 60, "\n" + "My Fast LSTM", "\n" + "-" * 60)
36-
generator = MnistGenerator()
37-
t = time.time()
3837
tf.reset_default_graph()
38+
t = time.time()
3939
rnn = RNNWrapper()
4040
rnn.fit(28, 10, generator, cell=FastLSTMCell)
4141
print("Time Cost: {}".format(time.time() - t))
4242

4343
print("=" * 60, "\n" + "Tflearn", "\n" + "-" * 60)
44-
generator = MnistGenerator()
45-
t = time.time()
4644
tf.reset_default_graph()
45+
t = time.time()
4746
net = tflearn.input_data(shape=[None, 28, 28])
4847
net = tf.concat(tflearn.lstm(net, 128, return_seq=True)[-3:], axis=1)
4948
net = tflearn.fully_connected(net, 10, activation='softmax')

Diff for: Zhihu/RNN/RNN.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ def __call__(self, x, state, scope="LSTM"):
1313
num_outputs=4 * self._num_units,
1414
activation_fn=None)
1515
r1, g1, g2, g3 = tf.split(gates, 4, 1)
16-
r1, g1, g3 = tf.nn.sigmoid(r1), tf.nn.sigmoid(g1), tf.nn.sigmoid(g3)
16+
r1 = tf.nn.sigmoid(r1)
17+
g1 = tf.nn.sigmoid(g1)
1718
g2 = tf.nn.tanh(g2)
19+
g3 = tf.nn.sigmoid(g3)
1820
h_new = h_old * r1 + g1 * g2
1921
s_new = tf.nn.tanh(h_new) * g3
2022
return s_new, tf.concat([s_new, h_new], 1)
@@ -33,8 +35,10 @@ def __call__(self, x, state, scope="LSTM"):
3335
num_outputs=4 * self._num_units,
3436
activation_fn=None)
3537
r1, g1, g2, g3 = tf.split(gates, 4, 1)
36-
r1, g1, g3 = tf.nn.sigmoid(r1), tf.nn.sigmoid(g1), tf.nn.sigmoid(g3)
38+
r1 = tf.nn.sigmoid(r1)
39+
g1 = tf.nn.sigmoid(g1)
3740
g2 = tf.nn.tanh(g2)
41+
g3 = tf.nn.sigmoid(g3)
3842
h_new = h_old * r1 + g1 * g2
3943
s_new = tf.nn.tanh(h_new) * g3
4044
return s_new, LSTMStateTuple(s_new, h_new)

0 commit comments

Comments
 (0)