@@ -157,13 +157,8 @@ def build_training_sample(sample, target_seq_length,
157
157
return train_sample
158
158
159
159
160
- def pad_and_convert_to_numpy (tokens , masked_positions ,
161
- masked_labels , pad_id ,
162
- max_seq_length , max_seq_length_dec ,
163
- masked_spans = None , bos_id = None ,
164
- eos_id = None , sentinel_tokens = None ):
165
- """Pad sequences and convert them to numpy."""
166
-
160
+ def merge_subsequent_masks (tokens , masked_spans = None , bos_id = None ,
161
+ eos_id = None , sentinel_tokens = None ):
167
162
sentinel_tokens = collections .deque (sentinel_tokens )
168
163
t5_input = []
169
164
(t5_decoder_in , t5_decoder_out ) = ([bos_id ], [])
@@ -189,6 +184,18 @@ def pad_and_convert_to_numpy(tokens, masked_positions,
189
184
190
185
# Add the remaining tokens to the t5 input
191
186
t5_input .extend (tokens [start_index :])
187
+ return t5_input , t5_decoder_in , t5_decoder_out
188
+
189
+
190
+ def pad_and_convert_to_numpy (tokens , masked_positions ,
191
+ masked_labels , pad_id ,
192
+ max_seq_length , max_seq_length_dec ,
193
+ masked_spans = None , bos_id = None ,
194
+ eos_id = None , sentinel_tokens = None ):
195
+ """Pad sequences and convert them to numpy."""
196
+
197
+ t5_input , t5_decoder_in , t5_decoder_out = merge_subsequent_masks (
198
+ tokens , masked_spans , bos_id , eos_id , sentinel_tokens )
192
199
193
200
# assert (len(t5_input) - len(masked_spans)) + \
194
201
# (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)
0 commit comments