@@ -39,10 +39,16 @@ class TransformerFeedables(NamedTuple(
39
39
("input_mask" , tf .Tensor )])):
40
40
"""Additional feedables used only by the Transformer-based decoder.
41
41
42
+ Follows the shape pattern of having a batch-sized first dimension
43
+ shape(batch_size, ...)
44
+
42
45
Attributes:
43
46
input_sequence: The whole input sequence (embedded) that is fed into
44
47
the decoder in each decoding step.
45
- input_mask: Mask for masking finished sequences.
48
+ shape(batch, len, emb)
49
+ input_mask: Mask for masking finished sequences. The last dimension
50
+ is required for compatibility with the beam_search_decoder.
51
+ shape(batch, len, 1)
46
52
"""
47
53
48
54
@@ -392,14 +398,14 @@ def train_loop_result(self) -> LoopState:
392
398
decoder_ls = AutoregressiveDecoder .get_initial_loop_state (self )
393
399
394
400
input_sequence = self .embed_input_symbols (self .train_input_symbols )
401
+ input_mask = tf .transpose (self .train_mask )
402
+
395
403
last_layer = self .layer (
396
- self .depth , input_sequence , tf . transpose ( self . train_mask ) )
404
+ self .depth , input_sequence , input_mask )
397
405
398
- # We transpose input sequence and mask only to convey to
399
- # the defined shapes
400
406
tr_feedables = TransformerFeedables (
401
- input_sequence = tf . transpose ( input_sequence ) ,
402
- input_mask = self . train_mask )
407
+ input_sequence = input_sequence ,
408
+ input_mask = tf . expand_dims ( input_mask , - 1 ) )
403
409
404
410
# t_states shape: (batch, time, channels)
405
411
# dec_w shape: (channels, vocab)
@@ -453,11 +459,11 @@ def get_initial_loop_state(self) -> LoopState:
453
459
454
460
tr_feedables = TransformerFeedables (
455
461
input_sequence = tf .zeros (
456
- shape = [0 , self .batch_size , self .dimension ],
462
+ shape = [self .batch_size , 0 , self .dimension ],
457
463
dtype = tf .float32 ,
458
464
name = "input_sequence" ),
459
465
input_mask = tf .zeros (
460
- shape = [0 , self .batch_size ],
466
+ shape = [self .batch_size , 0 , 1 ],
461
467
dtype = tf .float32 ,
462
468
name = "input_mask" ))
463
469
@@ -486,16 +492,16 @@ def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
486
492
with tf .variable_scope (self ._variable_scope , reuse = tf .AUTO_REUSE ):
487
493
# shape (time, batch)
488
494
input_sequence = append_tensor (
489
- tr_feedables .input_sequence , feedables .embedded_input )
495
+ tr_feedables .input_sequence , feedables .embedded_input , 1 )
490
496
491
497
unfinished_mask = tf .to_float (tf .logical_not (feedables .finished ))
492
498
input_mask = append_tensor (
493
- tr_feedables .input_mask , unfinished_mask )
499
+ tr_feedables .input_mask ,
500
+ tf .expand_dims (unfinished_mask , - 1 ),
501
+ axis = 1 )
494
502
495
503
last_layer = self .layer (
496
- self .depth ,
497
- tf .transpose (input_sequence , [1 , 0 , 2 ]),
498
- tf .transpose (input_mask ))
504
+ self .depth , input_sequence , tf .squeeze (input_mask , - 1 ))
499
505
500
506
# (batch, state_size)
501
507
output_state = last_layer .temporal_states [:, - 1 , :]
0 commit comments