Do not automatically add <EOS> token when packing
The automatic <EOS> insertion was a remnant of when the plan was to pack truncated sequences.

This change also fixes problems with decoder-only attention masks.
janEbert committed Jul 4, 2023
1 parent ff49685 commit a9736ff
Showing 2 changed files with 4 additions and 113 deletions.
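Both files receive the same two-part simplification: the overflow check no longer budgets an extra slot for an <EOS> that might be appended, and the append-plus-mask-patching machinery is deleted outright. A minimal standalone sketch of just the length check, before and after this commit (the function names and toy inputs are illustrative, not from the diff):

```python
import numpy as np

def fits_before(sample, prev_len, max_seq_len, eos_id):
    # Old behavior: reserve one extra slot whenever the sample does not
    # already end in <EOS>, because an <EOS> would then be appended.
    extra = int(sample[-1] != eos_id)
    return prev_len + len(sample) + extra <= max_seq_len

def fits_after(sample, prev_len, max_seq_len):
    # New behavior: samples are packed exactly as given; no <EOS> is
    # appended, so only the raw length matters.
    return prev_len + len(sample) <= max_seq_len

sample = np.array([5, 8, 13])  # toy token IDs, no trailing <EOS>
print(fits_before(sample, prev_len=7, max_seq_len=10, eos_id=1))  # False
print(fits_after(sample, prev_len=7, max_seq_len=10))             # True
```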
81 changes: 3 additions & 78 deletions megatron/data/t5_dataset.py
```diff
@@ -151,7 +151,6 @@ def _pack_samples(self, np_rng, idx):
                 prev_len,
                 prev_len_dec,
                 self.pad_id,
-                self.eos_id,
             )
             if maybe_lens is None:
                 # We are exceeding our sequence length already.
@@ -467,63 +466,24 @@ def update_samples_dict(
         prev_len,
         prev_len_dec,
         pad_id,
-        eos_id,
 ):
     _remove_padding(result_sample, pad_id)
 
     len_enc = len(result_sample['text_enc'])
     len_dec = len(result_sample['text_dec'])
 
     if (
-            (
-                prev_len
-                + len_enc
-                + int(result_sample['text_enc'][-1] != eos_id)
-            ) > max_seq_len
-            or (
-                prev_len_dec
-                + len_dec
-                + int(result_sample['text_dec'][-1] != eos_id)
-            ) > max_seq_len_dec
+            prev_len + len_enc > max_seq_len
+            or prev_len_dec + len_dec > max_seq_len_dec
     ):
         return None
 
-    eos_added = {
-        'text_enc': False,
-        'text_dec': False,
-        'labels': False,
-    }
-    for (key, is_enc) in zip(
-            ['text_enc', 'text_dec', 'labels'],
-            [True, False, False],
-    ):
+    for key in ['text_enc', 'text_dec', 'labels']:
         curr_sample = result_sample[key]
         offset, length = get_lens(
             key, prev_len, prev_len_dec, len_enc, len_dec)
         samples_dict[key][offset:offset + length] = curr_sample
 
-        # Add EOS token if not present.
-        if (
-                curr_sample[-1] != eos_id
-                or key == 'labels' and eos_added['text_dec']
-        ):
-            samples_dict[key][offset + length] = eos_id
-            eos_added[key] = True
-
-    need_extras = {
-        'loss_mask': False,
-        'enc_mask': False,
-        'dec_mask': False,
-        'enc_dec_mask': [False, False],
-    }
-    if eos_added['text_enc']:
-        need_extras['enc_mask'] = True
-        need_extras['enc_dec_mask'][1] = True
-    if eos_added['text_dec']:
-        need_extras['loss_mask'] = True
-        need_extras['dec_mask'] = True
-        need_extras['enc_dec_mask'][0] = True
-
     samples_dict['loss_mask'][
         prev_len_dec:prev_len_dec + len_dec,
     ] += result_sample['loss_mask']
@@ -540,42 +500,7 @@ def update_samples_dict(
         prev_len:prev_len + len_enc,
     ] += result_sample['enc_dec_mask']
 
-    if need_extras['loss_mask']:
-        samples_dict['loss_mask'][prev_len_dec + len_dec] = 1
-
-    for key in ['enc_mask', 'dec_mask']:
-        if need_extras[key]:
-            all_samples = samples_dict[key]
-            offset, length = get_lens(
-                key, prev_len, prev_len_dec, len_enc, len_dec)
-            all_samples[
-                offset + length,
-                offset:offset + length,
-            ] = 1
-            all_samples[
-                offset:offset + length,
-                offset + length,
-            ] = 1
-
-    if need_extras['enc_dec_mask'][0] or need_extras['enc_dec_mask'][1]:
-        all_samples = samples_dict['enc_dec_mask']
-        if need_extras['enc_dec_mask'][0]:
-            all_samples[
-                prev_len_dec + len_dec,
-                prev_len:prev_len + len_enc,
-            ] = 1
-        elif need_extras['enc_dec_mask'][1]:
-            all_samples[
-                prev_len_dec:prev_len_dec + len_dec,
-                prev_len + len_enc,
-            ] = 1
-
     samples_dict['truncated'] += result_sample['truncated']
 
-    if eos_added['text_enc']:
-        len_enc += 1
-    if eos_added['text_dec']:
-        len_dec += 1
-
     return len_enc, len_dec
```
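The removed need_extras bookkeeping above is also the source of the decoder-mask problem mentioned in the commit message: whenever an <EOS> was appended, the attention masks were extended by writing ones into both the row and the column at the new position. That is harmless for a bidirectional encoder mask, but the column write punches ones above the diagonal of a causal mask. A small sketch of the effect on a toy causal mask (the layout is illustrative, not the actual samples_dict buffers):

```python
import numpy as np

# Toy causal self-attention mask for one packed sample of length 4:
# entry [i, j] == 1 means position i may attend to position j.
length = 4
mask = np.zeros((length + 1, length + 1), dtype=np.int64)
mask[:length, :length] = np.tril(np.ones((length, length), dtype=np.int64))

# Mimic the removed extension for an appended <EOS> at index `length`,
# mirroring the two slice assignments deleted in this commit:
mask[length, :length] = 1  # row: <EOS> attends to the tokens before it (causal)
mask[:length, length] = 1  # column: earlier tokens attend to <EOS> (the future!)

# Any 1 strictly above the diagonal means a position can see the future.
print(np.any(np.triu(mask, k=1)))  # True -> the packed mask is no longer causal
```

With no <EOS> ever appended, none of this patching is needed, which is why the whole block could be dropped rather than fixed.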
36 changes: 1 addition & 35 deletions megatron/data/ul2_dataset.py
```diff
@@ -228,7 +228,6 @@ def _pack_samples(self, np_rng, idx, denoiser_index):
                     self.max_seq_length,
                     prev_len,
                     self.pad_id,
-                    self.eos_id,
                 )
             else:
                 maybe_lens = update_samples_dict(
@@ -239,7 +238,6 @@ def _pack_samples(self, np_rng, idx, denoiser_index):
                     prev_len,
                     prev_len_dec,
                     self.pad_id,
-                    self.eos_id,
                 )
             if maybe_lens is None:
                 # We are exceeding our sequence length already.
@@ -525,33 +523,17 @@ def update_samples_dict_decoder_only(
         max_seq_len,
         prev_len,
         pad_id,
-        eos_id,
 ):
     _remove_padding(result_sample, pad_id)
     len_enc = len(result_sample['text'])
 
-    if (
-            (
-                prev_len
-                + len_enc
-                + int(result_sample['text'][-1] != eos_id)
-            ) > max_seq_len
-    ):
+    if prev_len + len_enc > max_seq_len:
         return None
 
-    eos_added = False
     for key in ['text', 'labels']:
         curr_sample = result_sample[key]
         samples_dict[key][prev_len:prev_len + len_enc] = curr_sample
 
-        # Add EOS token if not present.
-        if (
-                curr_sample[-1] != eos_id
-                or key == 'labels' and eos_added
-        ):
-            samples_dict[key][prev_len + len_enc] = eos_id
-            eos_added = True
-
     samples_dict['loss_mask'][
         prev_len:prev_len + len_enc,
     ] += result_sample['loss_mask']
@@ -560,21 +542,5 @@ def update_samples_dict_decoder_only(
         prev_len:prev_len + len_enc,
     ] += result_sample['dec_mask']
-
-    if eos_added:
-        samples_dict['loss_mask'][prev_len + len_enc] = 1
-
-        all_samples = samples_dict['dec_mask']
-        all_samples[
-            prev_len + len_enc,
-            prev_len:prev_len + len_enc,
-        ] = 1
-        all_samples[
-            prev_len:prev_len + len_enc,
-            prev_len + len_enc,
-        ] = 1
-
-        len_enc += 1
-
     samples_dict['truncated'] += result_sample['truncated']
 
     return len_enc
```
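After the change, update_samples_dict_decoder_only reduces to: strip padding, return None if the sample no longer fits, copy text and labels, accumulate the masks, and report the consumed length. A runnable stub of that post-commit behavior and the packing loop that plausibly drives it (the stub, loop, and toy data are my assumptions; only the fit check and the return convention mirror the diff):

```python
import numpy as np

def update_samples_dict_stub(samples_dict, result_sample,
                             max_seq_len, prev_len, pad_id):
    # Post-commit behavior: no <EOS> handling at all.
    text = result_sample['text']
    text = text[text != pad_id]  # simplified stand-in for _remove_padding
    if prev_len + len(text) > max_seq_len:
        return None  # the caller stops packing into this sequence
    samples_dict['text'][prev_len:prev_len + len(text)] = text
    return len(text)

max_seq_len, pad_id = 8, 0
samples_dict = {'text': np.zeros(max_seq_len, dtype=np.int64)}
prev_len = 0
for raw in [np.array([5, 8, 13, 0]), np.array([2, 3, 0, 0]),
            np.array([9, 9, 9, 9])]:
    maybe_len = update_samples_dict_stub(
        samples_dict, {'text': raw}, max_seq_len, prev_len, pad_id)
    if maybe_len is None:
        break  # next sample would overflow max_seq_len
    prev_len += maybe_len  # advance by the exact packed length
print(samples_dict['text'], prev_len)  # [ 5  8 13  2  3  0  0  0] 5
```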
