Skip to content

Commit 89eb9c2

Browse files
varisdjlibovicky
authored andcommitted
get_initial_loop_state code factorization
1 parent 8f7f2e0 commit 89eb9c2

File tree

3 files changed

+31
-29
lines changed

3 files changed

+31
-29
lines changed

neuralmonkey/decoders/autoregressive.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,14 @@ def runtime_logprobs(self) -> tf.Tensor:
378378
def output_dimension(self) -> int:
379379
raise NotImplementedError("Abstract property")
380380

381-
def get_initial_loop_state(self) -> LoopState:
381+
def get_initial_feedables(self) -> DecoderFeedables:
382+
return DecoderFeedables(
383+
step=tf.constant(0, tf.int32),
384+
finished=tf.zeros([self.batch_size], dtype=tf.bool),
385+
embedded_input=self.embed_input_symbols(self.go_symbols),
386+
other=None)
382387

388+
def get_initial_histories(self) -> DecoderHistories:
383389
output_states = tf.zeros(
384390
shape=[0, self.batch_size, self.embedding_size],
385391
dtype=tf.float32,
@@ -400,25 +406,21 @@ def get_initial_loop_state(self) -> LoopState:
400406
dtype=tf.float32,
401407
name="hist_logits")
402408

403-
feedables = DecoderFeedables(
404-
step=tf.constant(0, tf.int32),
405-
finished=tf.zeros([self.batch_size], dtype=tf.bool),
406-
embedded_input=self.embed_input_symbols(self.go_symbols),
407-
other=None)
408-
409-
histories = DecoderHistories(
409+
return DecoderHistories(
410410
logits=logits,
411411
output_states=output_states,
412412
output_mask=output_mask,
413413
output_symbols=output_symbols,
414414
other=None)
415415

416-
constants = DecoderConstants(train_inputs=self.train_inputs)
416+
def get_initial_constants(self) -> DecoderConstants:
417+
return DecoderConstants(train_inputs=self.train_inputs)
417418

419+
def get_initial_loop_state(self) -> LoopState:
418420
return LoopState(
419-
histories=histories,
420-
constants=constants,
421-
feedables=feedables)
421+
feedables=self.get_initial_feedables(),
422+
histories=self.get_initial_histories(),
423+
constants=self.get_initial_constants())
422424

423425
def loop_continue_criterion(self, *args) -> tf.Tensor:
424426
"""Decide whether to break out of the while loop.

neuralmonkey/decoders/decoder.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typeguard import check_argument_types
55

66
from neuralmonkey.decoders.autoregressive import (
7-
AutoregressiveDecoder, LoopState)
7+
AutoregressiveDecoder, DecoderFeedables, DecoderHistories, LoopState)
88
from neuralmonkey.attention.base_attention import BaseAttention
99
from neuralmonkey.vocabulary import Vocabulary
1010
from neuralmonkey.model.sequence import EmbeddedSequence
@@ -357,17 +357,20 @@ def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
357357

358358
return (output, new_feedables, new_histories)
359359

360-
def get_initial_loop_state(self) -> LoopState:
361-
default_ls = AutoregressiveDecoder.get_initial_loop_state(self)
362-
feedables = default_ls.feedables
363-
histories = default_ls.histories
360+
def get_initial_feedables(self) -> DecoderFeedables:
361+
feedables = AutoregressiveDecoder.get_initial_feedables(self)
364362

365363
rnn_feedables = RNNFeedables(
366364
prev_contexts=[tf.zeros([self.batch_size, a.context_vector_size])
367365
for a in self.attentions],
368366
prev_rnn_state=self.initial_state,
369367
prev_rnn_output=self.initial_state)
370368

369+
return feedables._replace(other=rnn_feedables)
370+
371+
def get_initial_histories(self) -> DecoderHistories:
372+
histories = AutoregressiveDecoder.get_initial_histories(self)
373+
371374
rnn_histories = RNNHistories(
372375
rnn_outputs=tf.zeros(
373376
shape=[0, self.batch_size, self.rnn_size],
@@ -376,10 +379,7 @@ def get_initial_loop_state(self) -> LoopState:
376379
attention_histories=[a.initial_loop_state()
377380
for a in self.attentions if a is not None])
378381

379-
return LoopState(
380-
histories=histories._replace(other=rnn_histories),
381-
constants=default_ls.constants,
382-
feedables=feedables._replace(other=rnn_feedables))
382+
return histories._replace(other=rnn_histories)
383383

384384
def finalize_loop(self, final_loop_state: LoopState,
385385
train_mode: bool) -> None:

neuralmonkey/decoders/transformer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -452,10 +452,8 @@ def train_loop_result(self) -> LoopState:
452452
histories=histories,
453453
constants=decoder_ls.constants)
454454

455-
def get_initial_loop_state(self) -> LoopState:
456-
default_ls = AutoregressiveDecoder.get_initial_loop_state(self)
457-
feedables = default_ls.feedables
458-
histories = default_ls.histories
455+
def get_initial_feedables(self) -> DecoderFeedables:
456+
feedables = AutoregressiveDecoder.get_initial_feedables(self)
459457

460458
tr_feedables = TransformerFeedables(
461459
input_sequence=tf.zeros(
@@ -467,6 +465,11 @@ def get_initial_loop_state(self) -> LoopState:
467465
dtype=tf.float32,
468466
name="input_mask"))
469467

468+
return feedables._replace(other=tr_feedables)
469+
470+
def get_initial_histories(self) -> DecoderHistories:
471+
histories = AutoregressiveDecoder.get_initial_histories(self)
472+
470473
# TODO: record histories properly
471474
tr_histories = tf.zeros([])
472475
# tr_histories = TransformerHistories(
@@ -479,10 +482,7 @@ def get_initial_loop_state(self) -> LoopState:
479482
# self.n_heads_enc)
480483
# for a in range(self.depth)])
481484

482-
return LoopState(
483-
histories=histories._replace(other=tr_histories),
484-
constants=default_ls.constants,
485-
feedables=feedables._replace(other=tr_feedables))
485+
return histories._replace(other=tr_histories)
486486

487487
def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
488488
feedables = loop_state.feedables

0 commit comments

Comments
 (0)