Skip to content

Commit 84cac09

Browse files
committed
Update Zhihu RNN
1 parent dcd7fa3 commit 84cac09

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

Zhihu/RNN/RNN.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def fit(self, im, om, generator, hidden_units=128, cell=LSTMCell):
7373
initial_state=self._cell.zero_state(tf.shape(self._tfx)[0], tf.float32)
7474
)
7575
self._get_output(rnn_outputs)
76-
loss = -tf.reduce_mean(
77-
self._tfy * tf.log(self._output + 1e-8) + (1 - self._tfy) * tf.log(1 - self._output + 1e-8)
78-
)
76+
loss = tf.nn.softmax_cross_entropy_with_logits(logits=self._output, labels=self._tfy)
7977
train_step = tf.train.AdamOptimizer(0.01).minimize(loss)
8078
self._sess.run(tf.global_variables_initializer())
8179
for _ in range(10):

0 commit comments

Comments
 (0)