Skip to content

Commit b69818d

Browse files
committed
Refactor span merging
1 parent 567e187 commit b69818d

File tree

1 file changed: +14 additions, −7 deletions

megatron/data/t5_dataset.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,8 @@ def build_training_sample(sample, target_seq_length,
157157
return train_sample
158158

159159

160-
def pad_and_convert_to_numpy(tokens, masked_positions,
161-
masked_labels, pad_id,
162-
max_seq_length, max_seq_length_dec,
163-
masked_spans=None, bos_id=None,
164-
eos_id=None, sentinel_tokens=None):
165-
"""Pad sequences and convert them to numpy."""
166-
160+
def merge_subsequent_masks(tokens, masked_spans=None, bos_id=None,
161+
eos_id=None, sentinel_tokens=None):
167162
sentinel_tokens = collections.deque(sentinel_tokens)
168163
t5_input = []
169164
(t5_decoder_in, t5_decoder_out) = ([bos_id], [])
@@ -189,6 +184,18 @@ def pad_and_convert_to_numpy(tokens, masked_positions,
189184

190185
# Add the remaining tokens to the t5 input
191186
t5_input.extend(tokens[start_index:])
187+
return t5_input, t5_decoder_in, t5_decoder_out
188+
189+
190+
def pad_and_convert_to_numpy(tokens, masked_positions,
191+
masked_labels, pad_id,
192+
max_seq_length, max_seq_length_dec,
193+
masked_spans=None, bos_id=None,
194+
eos_id=None, sentinel_tokens=None):
195+
"""Pad sequences and convert them to numpy."""
196+
197+
t5_input, t5_decoder_in, t5_decoder_out = merge_subsequent_masks(
198+
tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
192199

193200
# assert (len(t5_input) - len(masked_spans)) + \
194201
# (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)

Comments (0)