Commit dee6855

autoregressive refactor: compatibility with beam_search_decoder
1 parent a803290 · commit dee6855

17 files changed · +71 −42 lines changed

neuralmonkey/decoders/autoregressive.py

Lines changed: 9 additions & 3 deletions

@@ -56,6 +56,9 @@ class DecoderHistories(NamedTuple(
     This should only record decoding history and the decoding should not be
     dependent on these values.
 
+    Attributes defined here (and in the `other` substructure) should always
+    be time-major, i.e. of shape ``(time, batch, ...)``.
+
     Attributes:
         logits: A tensor of shape ``(time, batch, vocabulary)`` which contains
             the unnormalized output scores of words in a vocabulary.
@@ -93,6 +96,9 @@ class DecoderFeedables(NamedTuple(
 
     The decoder should be able to decode in each step only using this.
 
+    Attributes defined here (and in the `other` substructure) should always
+    be batch-major, i.e. of shape ``(batch, ...)``.
+
     Attributes:
         step: A scalar int tensor, stores the number of the current time step.
         finished: A boolean tensor of shape ``(batch)``, which says whether
@@ -276,7 +282,7 @@ def train_logits(self) -> tf.Tensor:
     @tensor
     def train_output_states(self) -> tf.Tensor:
         train_result = LoopState(*self.train_loop_result)
-        return train_result.histories.decoder_outputs
+        return train_result.histories.output_states
 
     @tensor
     def train_logprobs(self) -> tf.Tensor:
@@ -319,12 +325,12 @@ def runtime_logits(self) -> tf.Tensor:
     @tensor
     def runtime_output_states(self) -> tf.Tensor:
         runtime_result = LoopState(*self.runtime_loop_result)
-        return runtime_result.histories.decoder_outputs
+        return runtime_result.histories.output_states
 
     @tensor
     def runtime_mask(self) -> tf.Tensor:
         runtime_result = LoopState(*self.runtime_loop_result)
-        return runtime_result.histories.mask
+        return runtime_result.histories.output_mask
 
     @tensor
     def decoded(self) -> tf.Tensor:
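
The docstring additions above state the invariant the rest of the refactor relies on: histories grow one slice per step along a leading time axis, while feedables describe only the current step, one row per batch element. A minimal sketch of the convention (the tensor names are illustrative, not the actual NamedTuple fields):

    import tensorflow as tf

    batch, vocab, emb = 4, 100, 8

    # Feedables are batch-major: current step only, one row per sequence.
    embedded_input = tf.zeros([batch, emb])        # (batch, emb)

    # Histories are time-major: one slice per decoding step along axis 0.
    logits_history = tf.zeros([0, batch, vocab])   # (time, batch, vocab)
    step_logits = tf.zeros([batch, vocab])         # output of a single step

    # Growing the history by one step keeps it time-major.
    logits_history = tf.concat(
        [logits_history, tf.expand_dims(step_logits, 0)], axis=0)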

neuralmonkey/decoders/beam_search_decoder.py

Lines changed: 11 additions & 4 deletions

@@ -279,13 +279,14 @@ def get_initial_loop_state(self) -> BeamSearchLoopState:
         # time, the placeholder replacement is done on the whole structures, as
         # you can see below.
 
+        logits = dec_next_ls.histories.logits[-1, :, :]
         search_state = SearchState(
             logprob_sum=tf.tile(
                 tf.expand_dims([0.0] + [-INF] * (self.beam_size - 1), 0),
                 [self.batch_size, 1],
                 name="bs_logprob_sum"),
             prev_logprobs=tf.reshape(
-                tf.nn.log_softmax(dec_next_ls.feedables.prev_logits),
+                tf.nn.log_softmax(logits),
                 [self.batch_size, self.beam_size, len(self.vocabulary)]),
             lengths=tf.zeros(
                 [self.batch_size, self.beam_size], dtype=tf.int32,
@@ -296,13 +297,14 @@ def get_initial_loop_state(self) -> BeamSearchLoopState:
         # We add the input_symbol to token_ids during search_results
         # initialization for simpler beam_body implementation
 
+        input_symbols = dec_next_ls.histories.output_symbols[-1, :]
         search_results = SearchResults(
             scores=tf.zeros(
                 shape=[self.batch_size, self.beam_size],
                 dtype=tf.float32,
                 name="beam_scores"),
             token_ids=tf.reshape(
-                feedables.input_symbol,
+                input_symbols,
                 [1, self.batch_size, self.beam_size],
                 name="beam_tokens"))
@@ -505,11 +507,15 @@ def body(*args: Any) -> BeamSearchLoopState:
                 dec_loop_state.feedables)
 
             next_feedables = next_feedables._replace(
-                input_symbol=tf.reshape(next_word_ids, [-1]),
+                embedded_input=self.parent_decoder.embed_input_symbols(
+                    tf.reshape(next_word_ids, [-1])),
                 finished=tf.reshape(next_finished, [-1]))
 
             # histories have shape [len, batch, ...]
             def gather_fn(x):
+                if len(x.shape.dims) < 2:
+                    return x
+
                 return partial_transpose(
                     gather_flat(
                         partial_transpose(x, [1, 0]),
@@ -528,10 +534,11 @@ def gather_fn(x):
             # CALL THE DECODER BODY FUNCTION
             next_loop_state = decoder_body(*dec_loop_state)
 
+            logits = next_loop_state.histories.logits[-1, :, :]
             next_search_state = SearchState(
                 logprob_sum=next_beam_logprob_sum,
                 prev_logprobs=tf.reshape(
-                    tf.nn.log_softmax(next_loop_state.feedables.prev_logits),
+                    tf.nn.log_softmax(logits),
                     [self.batch_size, self.beam_size, len(self.vocabulary)]),
                 lengths=next_beam_lengths,
                 finished=next_finished)
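
Both `SearchState` constructions now read the last time slice of the time-major `histories.logits` instead of the removed `feedables.prev_logits`, and `gather_fn` now passes tensors with fewer than two dimensions through unchanged, presumably because they carry no per-beam axis to gather. A rough sketch of the new logits access pattern, with made-up sizes:

    import tensorflow as tf

    batch_size, beam_size, vocab_size = 2, 3, 50

    # Time-major logits history over a flat (batch * beam) dimension.
    logits_history = tf.zeros([5, batch_size * beam_size, vocab_size])

    # Last decoding step only, then per-hypothesis log-probabilities.
    logits = logits_history[-1, :, :]              # (batch * beam, vocab)
    prev_logprobs = tf.reshape(
        tf.nn.log_softmax(logits),
        [batch_size, beam_size, vocab_size])       # (batch, beam, vocab)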

neuralmonkey/decoders/decoder.py

Lines changed: 6 additions & 0 deletions

@@ -334,6 +334,12 @@ def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
             cell_output, self.dropout_keep_prob, self.train_mode)
 
         with tf.name_scope("rnn_output_projection"):
+            if self.embedding_size != self.output_dimension:
+                raise ValueError(
+                    "The dimension ({}) of the output projection must be "
+                    "the same as the dimension of the input embedding "
+                    "({})".format(self.output_dimension,
+                                  self.embedding_size))
             # pylint: disable=not-callable
             output = self.output_projection(
                 cell_output, loop_state.feedables.embedded_input,
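
The new guard turns a silent shape mismatch into an immediate configuration error. A standalone restatement of the check (the function name here is ours, not from the repo):

    def check_output_projection(output_dimension: int,
                                embedding_size: int) -> None:
        # Same condition and message as the guard added above.
        if embedding_size != output_dimension:
            raise ValueError(
                "The dimension ({}) of the output projection must be the "
                "same as the dimension of the input embedding ({})".format(
                    output_dimension, embedding_size))

    check_output_projection(9, 9)   # passes silently
    check_output_projection(7, 9)   # raises ValueError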

neuralmonkey/decoders/transformer.py

Lines changed: 19 additions & 13 deletions

@@ -39,10 +39,16 @@ class TransformerFeedables(NamedTuple(
         ("input_mask", tf.Tensor)])):
     """Additional feedables used only by the Transformer-based decoder.
 
+    Follows the shape pattern of having a batch-sized first dimension,
+    i.e. shape ``(batch_size, ...)``.
+
     Attributes:
         input_sequence: The whole input sequence (embedded) that is fed into
             the decoder in each decoding step.
-        input_mask: Mask for masking finished sequences.
+            shape ``(batch, len, emb)``
+        input_mask: Mask for masking finished sequences. The last dimension
+            is required for compatibility with the beam_search_decoder.
+            shape ``(batch, len, 1)``
     """
@@ -392,14 +398,14 @@ def train_loop_result(self) -> LoopState:
         decoder_ls = AutoregressiveDecoder.get_initial_loop_state(self)
 
         input_sequence = self.embed_input_symbols(self.train_input_symbols)
+        input_mask = tf.transpose(self.train_mask)
+
         last_layer = self.layer(
-            self.depth, input_sequence, tf.transpose(self.train_mask))
+            self.depth, input_sequence, input_mask)
 
-        # We transpose input sequence and mask only to convey to
-        # the defined shapes
         tr_feedables = TransformerFeedables(
-            input_sequence=tf.transpose(input_sequence),
-            input_mask=self.train_mask)
+            input_sequence=input_sequence,
+            input_mask=tf.expand_dims(input_mask, -1))
 
         # t_states shape: (batch, time, channels)
         # dec_w shape: (channels, vocab)
@@ -453,11 +459,11 @@ def get_initial_loop_state(self) -> LoopState:
 
         tr_feedables = TransformerFeedables(
             input_sequence=tf.zeros(
-                shape=[0, self.batch_size, self.dimension],
+                shape=[self.batch_size, 0, self.dimension],
                 dtype=tf.float32,
                 name="input_sequence"),
             input_mask=tf.zeros(
-                shape=[0, self.batch_size],
+                shape=[self.batch_size, 0, 1],
                 dtype=tf.float32,
                 name="input_mask"))
@@ -486,16 +492,16 @@ def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
         with tf.variable_scope(self._variable_scope, reuse=tf.AUTO_REUSE):
             # shape (time, batch)
             input_sequence = append_tensor(
-                tr_feedables.input_sequence, feedables.embedded_input)
+                tr_feedables.input_sequence, feedables.embedded_input, 1)
 
             unfinished_mask = tf.to_float(tf.logical_not(feedables.finished))
             input_mask = append_tensor(
-                tr_feedables.input_mask, unfinished_mask)
+                tr_feedables.input_mask,
+                tf.expand_dims(unfinished_mask, -1),
+                axis=1)
 
             last_layer = self.layer(
-                self.depth,
-                tf.transpose(input_sequence, [1, 0, 2]),
-                tf.transpose(input_mask))
+                self.depth, input_sequence, tf.squeeze(input_mask, -1))
 
             # (batch, state_size)
             output_state = last_layer.temporal_states[:, -1, :]
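
With `TransformerFeedables` now batch-major, the growing input sequence and mask are extended along axis 1, and the mask keeps a trailing singleton dimension so the beam-search gather can treat it like any other ``(batch, ...)`` tensor. A toy illustration of the two appends (shapes only, using `tf.concat` directly in place of `append_tensor`):

    import tensorflow as tf

    batch, emb = 4, 8

    input_sequence = tf.zeros([batch, 0, emb])  # (batch, len, emb), empty
    input_mask = tf.zeros([batch, 0, 1])        # (batch, len, 1), empty

    embedded_input = tf.ones([batch, emb])      # current step, (batch, emb)
    unfinished = tf.ones([batch, 1])            # 1.0 while still decoding

    # append_tensor(..., axis=1) expands the new step and concatenates:
    input_sequence = tf.concat(
        [input_sequence, tf.expand_dims(embedded_input, 1)], axis=1)
    input_mask = tf.concat(
        [input_mask, tf.expand_dims(unfinished, 1)], axis=1)
    # input_sequence: (batch, 1, emb); input_mask: (batch, 1, 1)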

neuralmonkey/tf_utils.py

Lines changed: 5 additions & 2 deletions

@@ -219,14 +219,17 @@ def layer_norm(x: tf.Tensor, epsilon: float = 1e-6) -> tf.Tensor:
     return norm_x * gamma + beta
 
 
-def append_tensor(tensor: tf.Tensor, appendval: tf.Tensor) -> tf.Tensor:
+def append_tensor(tensor: tf.Tensor,
+                  appendval: tf.Tensor,
+                  axis: int = 0) -> tf.Tensor:
     """Append an ``N``-D Tensor to an ``(N+1)``-D Tensor.
 
     Arguments:
         tensor: The original Tensor
         appendval: The Tensor to add
+        axis: The axis along which to append
 
     Returns:
         An ``(N+1)``-D Tensor with ``appendval`` on the last position.
     """
-    return tf.concat([tensor, tf.expand_dims(appendval, 0)], 0)
+    return tf.concat([tensor, tf.expand_dims(appendval, axis)], axis)
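
The generalized helper keeps its old behaviour as the `axis=0` default while supporting the batch-major (`axis=1`) appends used by the Transformer decoder above. A quick usage sketch:

    import tensorflow as tf

    def append_tensor(tensor, appendval, axis=0):
        # Same body as the patched helper in neuralmonkey/tf_utils.py.
        return tf.concat([tensor, tf.expand_dims(appendval, axis)], axis)

    step = tf.ones([4, 8])  # one decoding step, (batch, emb)

    time_major = append_tensor(tf.zeros([3, 4, 8]), step)           # (4, 4, 8)
    batch_major = append_tensor(tf.zeros([4, 3, 8]), step, axis=1)  # (4, 4, 8)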

neuralmonkey/trainers/rl_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -123,7 +123,7 @@ def _score_with_reward_function(references: np.array,
         sample_loop_result = self.decoder.decoding_loop(
             train_mode=False, sample=True, temperature=self.temperature)
         sample_logits = sample_loop_result.histories.logits
-        sample_decoded = sample_loop_result.histories.outputs
+        sample_decoded = sample_loop_result.histories.output_symbols
 
         # rewards, shape (batch)
         # simulate from reference

tests/bahdanau.ini

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ supress_unk=True
 
 [dec_maxout_output]
 class=decoders.output_projection.maxout_output
-maxout_size=7
+maxout_size=9
 
 [trainer1]
 class=trainers.cross_entropy_trainer.CrossEntropyTrainer

tests/bpe.ini

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ encoder=<encoder>
 class=decoders.decoder.Decoder
 name="decoder"
 encoders=[<encoder>]
-embedding_size=9
+embedding_size=10
 attentions=[<attention>]
 dropout_keep_prob=0.5
 data_id="target_bpe"

tests/factored.ini

Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ name="decoder"
 encoders=[<encoder>]
 attentions=[<attention>]
 rnn_size=32
-embedding_size=20
+embedding_size=32
 dropout_keep_prob=0.5
 data_id="target"
 max_output_len=10

tests/flat-multiattention.ini

Lines changed: 4 additions & 4 deletions

@@ -72,7 +72,7 @@ class=decoders.decoder.Decoder
 name="decoder_flat_noshare_nosentinel"
 attentions=[<flat_noshare_nosentinel>]
 encoders=[<encoder>, <imagenet>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -92,7 +92,7 @@ class=decoders.decoder.Decoder
 name="decoder_flat_share_nosentinel"
 attentions=[<flat_share_nosentinel>]
 encoders=[<encoder>, <imagenet>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -112,7 +112,7 @@ class=decoders.decoder.Decoder
 name="decoder_flat_share_sentinel"
 attentions=[<flat_share_sentinel>]
 encoders=[<encoder>, <imagenet>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -132,7 +132,7 @@ class=decoders.decoder.Decoder
 name="decoder_flat_noshare_sentinel"
 attentions=[<flat_noshare_sentinel>]
 encoders=[<encoder>, <imagenet>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
