
Commit ab85833

Support UL2 for decoder-only models
1 parent deed87f commit ab85833

5 files changed: +216 -69 lines changed

megatron/arguments.py

Lines changed: 18 additions & 1 deletion
@@ -24,7 +24,7 @@
 import torch
 import deepspeed
 
-from megatron.enums import PositionEmbeddingType
+from megatron.enums import PositionEmbeddingType, UL2ModelType
 import megatron
 from megatron.logging import log_levels
 
@@ -311,6 +311,18 @@ def parse_args(extra_args_provider=None, defaults={},
     )
     args.skip_train_iteration_range = skip_train_iteration_range
 
+    args.ul2_model_type = UL2ModelType(args.ul2_model_type)
+    if (
+        args.ul2_model_type is not UL2ModelType.ENCODER_DECODER
+        and args.decoder_seq_length is not None
+    ):
+        print(
+            'WARNING: `--decoder_seq_length` is ignored when '
+            '`--ul2-model-type` is not "',
+            UL2ModelType.ENCODER_DECODER.value,
+            '"!'
+        )
+
     if args.use_bnb_optimizer:
         try:
             import bitsandbytes as bnb
@@ -1028,6 +1040,11 @@ def _add_vit_args(parser):
 def _add_ul2_args(parser):
     group = parser.add_argument_group(title="UL2")
 
+    group.add_argument('--ul2-model-type', type=str, default='ED',
+                       choices=['ED', 'ND', 'CD'],
+                       help='What type of model to use for UL2 pretraining. '
+                            'ED = encoder-decoder; ND = non-causal decoder-only; '
+                            'CD = causal decoder-only')
     group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float,
                        default=None,
                        help='Probability of each denoising objective to be '
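The net effect of the argument changes is small enough to demonstrate standalone. Below is a minimal sketch (a hypothetical argparse setup mirroring the diff, not the full Megatron parser) of how the CLI string is promoted to a UL2ModelType member and when the warning fires:

import argparse
import enum


class UL2ModelType(enum.Enum):
    ENCODER_DECODER = 'ED'
    NON_CAUSAL_DECODER = 'ND'
    CAUSAL_DECODER = 'CD'


parser = argparse.ArgumentParser()
parser.add_argument('--ul2-model-type', type=str, default='ED',
                    choices=['ED', 'ND', 'CD'])
parser.add_argument('--decoder-seq-length', type=int, default=None)

# Example invocation: a non-causal decoder-only run that also passes
# a decoder sequence length.
args = parser.parse_args(['--ul2-model-type', 'ND',
                          '--decoder-seq-length', '128'])

# Promote the raw string to an enum member, as parse_args() now does.
args.ul2_model_type = UL2ModelType(args.ul2_model_type)

# Decoder-only models have no separate decoder sequence, so the length
# argument would be silently meaningless; the commit warns instead.
if (args.ul2_model_type is not UL2ModelType.ENCODER_DECODER
        and args.decoder_seq_length is not None):
    print('WARNING: `--decoder_seq_length` is ignored when '
          f'`--ul2-model-type` is not "{UL2ModelType.ENCODER_DECODER.value}"!')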

megatron/data/dataset_utils.py

Lines changed: 1 addition & 0 deletions
@@ -597,6 +597,7 @@ def build_dataset(index, name):
         args = get_args()
         dataset = UL2Dataset(
             indexed_dataset=indexed_dataset,
+            model_type=args.ul2_model_type,
             denoiser_ratios=args.ul2_denoiser_ratios,
             denoisers=args.ul2_denoisers,
             mean_span_lengths=args.ul2_mean_span_lengths,

megatron/data/ul2_dataset.py

Lines changed: 117 additions & 44 deletions
@@ -15,6 +15,8 @@
 
 """UL2-style dataset."""
 
+import math
+
 import numpy as np
 
 from megatron import get_tokenizer
@@ -23,16 +25,34 @@
     get_samples_mapping,
     SamplingStyle
 )
-from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+from megatron.data.t5_dataset import (
+    make_history_mask,
+    merge_subsequent_masks,
+    pad_and_convert_to_numpy,
+    T5Dataset,
+)
+from megatron.enums import UL2ModelType
+
+
+def is_decoder_only(ul2_model_type):
+    """Return whether we use a decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is not UL2ModelType.ENCODER_DECODER
+
+
+def is_prefix_lm(ul2_model_type):
+    """Return whether we use a non-causal decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER
 
 
 class UL2Dataset(T5Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
-                 num_epochs, max_num_samples, denoiser_ratios,
-                 denoisers, mean_span_lengths, mask_ratios,
-                 denoiser_tokens, max_seq_length, max_seq_length_dec,
-                 short_seq_prob, seed):
+                 num_epochs, max_num_samples, model_type,
+                 denoiser_ratios, denoisers, mean_span_lengths,
+                 mask_ratios, denoiser_tokens, max_seq_length,
+                 max_seq_length_dec, short_seq_prob, seed):
 
         if denoiser_ratios is None:
             # Uniform distribution by default.
@@ -52,6 +72,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
                          short_seq_prob, seed)
 
         # Params to store.
+        self.model_type = model_type
         self.denoiser_ratios = [
             denoiser_ratio / sum(denoiser_ratios)
             for denoiser_ratio in denoiser_ratios
@@ -97,21 +118,21 @@ def __getitem__(self, idx):
                                      self.vocab_id_to_token_dict,
                                      self.cls_ids, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.denoiser_ratios, self.denoisers,
-                                     self.mean_span_lengths, self.mask_ratios,
-                                     np_rng,
-                                     self.bos_id, self.eos_id,
-                                     self.sentinel_tokens)
+                                     self.model_type, self.denoiser_ratios,
+                                     self.denoisers, self.mean_span_lengths,
+                                     self.mask_ratios, np_rng, self.bos_id,
+                                     self.eos_id, self.sentinel_tokens)
 
 
 def build_training_sample(sample, target_seq_length,
                           max_seq_length, max_seq_length_dec,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_ids, sep_id, mask_id, pad_id,
-                          denoiser_ratios, denoisers,
-                          mean_span_lengths, mask_ratios,
-                          np_rng, bos_id=None,
-                          eos_id=None, sentinel_tokens=None):
+                          model_type, denoiser_ratios,
+                          denoisers, mean_span_lengths,
+                          mask_ratios, np_rng,
+                          bos_id=None, eos_id=None,
+                          sentinel_tokens=None):
     """Build training sample.
 
     Arguments:
@@ -125,6 +146,7 @@ def build_training_sample(sample, target_seq_length,
         sep_id: Separator id.
         mask_id: Mask token id.
         pad_id: Padding token id.
+        model_type: What type of model is used.
        denoiser_ratios: Probability of each denoising objective to be selected.
         denoisers: What type of UL2 denoising objective the other UL2
             configurations refer to.
@@ -139,24 +161,28 @@
         sentinel_tokens: unique value to be substituted for every replaced span
     """
 
+    # Denoiser selection
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
+
     assert target_seq_length <= max_seq_length
 
     # flatten sentences into one list
     tokens = [token for sentence in sample for token in sentence]
 
-    # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
-    truncated = len(tokens) > max_num_tokens
-    tokens = tokens[:max_num_tokens]
-
-    # Denoiser selection
-    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
-    denoiser = denoisers[denoiser_index]
-    masked_lm_prob = mask_ratios[denoiser_index]
-    mean_ngrams = mean_span_lengths[denoiser_index]
-    if mean_ngrams < 1:
-        mean_ngrams = round(len(tokens) * mean_ngrams)
-    max_ngrams = mean_ngrams * 2 - 1
+    if is_decoder_only(model_type):
+        # Keep space for repeated `extra_id` tokens; not the most data
+        # efficient since we calculate this based on the maximum number
+        # of possible `extra_id` tokens.
+        safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
+        truncated = len(tokens) > safe_max_seq_len
+        tokens = tokens[:safe_max_seq_len]
+    else:
+        # Truncate to `target_sequence_length`.
+        truncated = len(tokens) > max_num_tokens
+        tokens = tokens[:max_num_tokens]
 
     # Prepend objective token.
     cls_id = cls_ids.get(denoiser)
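
A quick sanity check on the reserve arithmetic above (values chosen for illustration, not taken from the commit): with a target length of 512 and a mask ratio of 0.15, the prefix is cut to floor(512 / 1.15) = 445 tokens, leaving 67 positions for the worst case in which every masked token becomes its own sentinel span:

import math

max_num_tokens = 512   # target_seq_length (illustrative)
masked_lm_prob = 0.15  # mask ratio of the sampled denoiser (illustrative)

safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
print(safe_max_seq_len)                   # 445
print(max_num_tokens - safe_max_seq_len)  # 67 positions held in reserve
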
@@ -166,6 +192,11 @@ def build_training_sample(sample, target_seq_length,
 
     # Masking.
     max_predictions_per_seq = masked_lm_prob * len(tokens)
+    mean_ngrams = mean_span_lengths[denoiser_index]
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
+
     if denoiser == 'R' or denoiser == 'X':
         sampling_style = SamplingStyle.NORMAL
         prefix_lm = False
@@ -183,22 +214,64 @@
         sampling_style=sampling_style, prefix_lm=prefix_lm,
     )
 
-    # Padding.
-    tokens_enc, tokens_dec_in, labels, enc_mask, \
-        dec_mask, enc_dec_mask, loss_mask \
-        = pad_and_convert_to_numpy(tokens, masked_positions,
-                                   masked_labels, pad_id, max_seq_length,
-                                   max_seq_length_dec, masked_spans,
-                                   bos_id, eos_id, sentinel_tokens)
-
-    train_sample = {
-        'text_enc': tokens_enc,
-        'text_dec': tokens_dec_in,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'truncated': int(truncated),
-        'enc_mask': enc_mask,
-        'dec_mask': dec_mask,
-        'enc_dec_mask': enc_dec_mask,
-    }
+    if is_decoder_only(model_type):
+        # Concatenate to one sequence.
+        tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
+            tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
+
+        # Move EOS tokens to end of sequence.
+        while tokens_enc[-1] == eos_id:
+            del tokens_enc[-1]
+            tokens_dec_in.append(eos_id)
+            labels.append(eos_id)
+
+        num_labels = len(labels)
+
+        # Move BOS token to start of sequence.
+        tokens_dec_in = tokens_dec_in[1:]
+        tokens = np.concatenate([
+            np.array([bos_id], dtype=np.int64),
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            tokens_dec_in,
+        ])
+        labels = np.concatenate([
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            labels,
+        ])
+
+        loss_mask = np.zeros(len(tokens), dtype=np.int64)
+        loss_mask[-num_labels:] = 1
+
+        dec_mask = make_history_mask(tokens)
+        if is_prefix_lm(model_type):
+            dec_mask[:-num_labels, :-num_labels] = 1
+
+        train_sample = {
+            'text': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'dec_mask': dec_mask,
+        }
+    else:
+        # Padding.
+        tokens_enc, tokens_dec_in, labels, enc_mask, \
+            dec_mask, enc_dec_mask, loss_mask \
+            = pad_and_convert_to_numpy(tokens, masked_positions,
+                                       masked_labels, pad_id, max_seq_length,
+                                       max_seq_length_dec, masked_spans,
+                                       bos_id, eos_id, sentinel_tokens)
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+            'enc_dec_mask': enc_dec_mask,
+        }
     return train_sample
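
To make the decoder-only branch concrete: the sample is packed into a single stream, [BOS] prefix [SEP] targets, the loss is confined to the final num_labels positions, and the 'ND' variant opens up the prefix block of the causal mask. A minimal NumPy sketch of the two masks follows (with a local stand-in for make_history_mask, which the commit imports from megatron.data.t5_dataset):

import numpy as np


def history_mask(length):
    # Local stand-in for t5_dataset.make_history_mask: a lower-triangular
    # causal mask (position i may attend to positions j <= i).
    return np.tril(np.ones((length, length), dtype=np.int64))


# Toy packed sequence: [BOS] + 5 prefix tokens + [SEP] + 3 target tokens.
num_tokens = 10
num_labels = 3

loss_mask = np.zeros(num_tokens, dtype=np.int64)
loss_mask[-num_labels:] = 1  # only target positions contribute to the loss

dec_mask = history_mask(num_tokens)  # causal decoder ('CD') stops here
# Non-causal decoder ('ND'): prefix positions attend to each other
# bidirectionally, while target positions remain causal.
dec_mask[:-num_labels, :-num_labels] = 1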

megatron/enums.py

Lines changed: 5 additions & 0 deletions

@@ -33,3 +33,8 @@ class PositionEmbeddingType(enum.Enum):
     rotary = 1
     absolute = 2
     alibi = 3
+
+class UL2ModelType(enum.Enum):
+    ENCODER_DECODER = 'ED'
+    NON_CAUSAL_DECODER = 'ND'
+    CAUSAL_DECODER = 'CD'
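
For reference, the new enum round-trips cleanly with the helper predicates added in ul2_dataset.py (both reproduced here so the snippet is self-contained):

import enum


class UL2ModelType(enum.Enum):
    ENCODER_DECODER = 'ED'
    NON_CAUSAL_DECODER = 'ND'
    CAUSAL_DECODER = 'CD'


def is_decoder_only(ul2_model_type):
    return ul2_model_type is not UL2ModelType.ENCODER_DECODER


def is_prefix_lm(ul2_model_type):
    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER


for value in ['ED', 'ND', 'CD']:
    model_type = UL2ModelType(value)  # CLI string -> enum member
    print(value, is_decoder_only(model_type), is_prefix_lm(model_type))
# ED False False
# ND True True
# CD True False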