Skip to content

Commit d0da72f

Browse files
jlibovicky authored and jindrahelcl committed
fix </s> stripping
1 parent bc0caeb commit d0da72f

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

neuralmonkey/decoders/autoregressive.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,14 +479,16 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
479479

480480
@tensor
481481
def temporal_states(self) -> tf.Tensor:
482+
# strip the last symbol which is </s>
482483
return tf.cond(
483484
self.train_mode,
484-
lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-2],
485+
lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-1],
485486
lambda: tf.transpose(
486-
self.runtime_output_states, [1, 0, 2])[:, :-2])
487+
self.runtime_output_states, [1, 0, 2])[:, :-1])
487488

488489
@tensor
489490
def temporal_mask(self) -> tf.Tensor:
491+
# strip the last symbol which is </s>
490492
return tf.cond(
491493
self.train_mode,
492494
lambda: tf.transpose(self.train_mask, [1, 0])[:, :-1],

neuralmonkey/decoders/sequence_labeler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class SequenceLabeler(ModelPart):
1717
1818
Note that when the labeler is stacked on an autoregressive decoder, it
1919
labels the symbol that is currently generated by the decoder, i.e., the
20-
decoder's state has not yet been updated by putting the decoded symbol on
20+
decoder state has not yet been updated by putting the decoded symbol on
2121
its input.
2222
"""
2323

0 commit comments

Comments
 (0)