Commit db95ce8

Support UL2 for decoder-only models
1 parent b69818d commit db95ce8

5 files changed (+216, -69 lines)

megatron/arguments.py (+18, -1)
@@ -24,7 +24,7 @@
 import torch
 import deepspeed

-from megatron.enums import PositionEmbeddingType
+from megatron.enums import PositionEmbeddingType, UL2ModelType
 import megatron
 from megatron.logging import log_levels

@@ -311,6 +311,18 @@ def parse_args(extra_args_provider=None, defaults={},
     )
     args.skip_train_iteration_range = skip_train_iteration_range

+    args.ul2_model_type = UL2ModelType(args.ul2_model_type)
+    if (
+            args.ul2_model_type is not UL2ModelType.ENCODER_DECODER
+            and args.decoder_seq_length is not None
+    ):
+        print(
+            'WARNING: `--decoder_seq_length` is ignored when '
+            '`--ul2-model-type` is not "',
+            UL2ModelType.ENCODER_DECODER.value,
+            '"!'
+        )
+
     if args.use_bnb_optimizer:
         try:
             import bitsandbytes as bnb

@@ -1028,6 +1040,11 @@ def _add_vit_args(parser):
 def _add_ul2_args(parser):
     group = parser.add_argument_group(title="UL2")

+    group.add_argument('--ul2-model-type', type=str, default='ED',
+                       choices=['ED', 'ND', 'CD'],
+                       help='What type of model to use for UL2 pretraining. '
+                            'ED = encoder-decoder; ND = non-causal decoder-only; '
+                            'CD = causal decoder-only')
     group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float,
                        default=None,
                        help='Probability of each denoising objective to be '
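For clarity, the new flag's string values round-trip through the UL2ModelType enum this commit adds to megatron/enums.py (shown further below). A minimal standalone sketch of that parsing step, with the enum inlined so it runs without the repository:

import enum


class UL2ModelType(enum.Enum):
    # Inlined copy of the enum this commit adds to megatron/enums.py.
    ENCODER_DECODER = 'ED'
    NON_CAUSAL_DECODER = 'ND'
    CAUSAL_DECODER = 'CD'


# parse_args() converts the raw --ul2-model-type string to the enum:
assert UL2ModelType('ND') is UL2ModelType.NON_CAUSAL_DECODER

# Values outside the argparse choices would raise a ValueError,
# matching the validation argparse already performs:
try:
    UL2ModelType('XX')
except ValueError:
    pass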

megatron/data/dataset_utils.py (+1)
@@ -597,6 +597,7 @@ def build_dataset(index, name):
     args = get_args()
     dataset = UL2Dataset(
         indexed_dataset=indexed_dataset,
+        model_type=args.ul2_model_type,
         denoiser_ratios=args.ul2_denoiser_ratios,
         denoisers=args.ul2_denoisers,
         mean_span_lengths=args.ul2_mean_span_lengths,

megatron/data/ul2_dataset.py (+117, -44)
@@ -15,6 +15,8 @@

 """UL2-style dataset."""

+import math
+
 import numpy as np

 from megatron import get_tokenizer
@@ -23,16 +25,34 @@
     get_samples_mapping,
     SamplingStyle
 )
-from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+from megatron.data.t5_dataset import (
+    make_history_mask,
+    merge_subsequent_masks,
+    pad_and_convert_to_numpy,
+    T5Dataset,
+)
+from megatron.enums import UL2ModelType
+
+
+def is_decoder_only(ul2_model_type):
+    """Return whether we use a decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is not UL2ModelType.ENCODER_DECODER
+
+
+def is_prefix_lm(ul2_model_type):
+    """Return whether we use a non-causal decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER


 class UL2Dataset(T5Dataset):

     def __init__(self, name, indexed_dataset, data_prefix,
-                 num_epochs, max_num_samples, denoiser_ratios,
-                 denoisers, mean_span_lengths, mask_ratios,
-                 denoiser_tokens, max_seq_length, max_seq_length_dec,
-                 short_seq_prob, seed):
+                 num_epochs, max_num_samples, model_type,
+                 denoiser_ratios, denoisers, mean_span_lengths,
+                 mask_ratios, denoiser_tokens, max_seq_length,
+                 max_seq_length_dec, short_seq_prob, seed):

         if denoiser_ratios is None:
             # Uniform
@@ -49,6 +69,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
         # Params to store.
         self.name = name
         self.seed = seed
+        self.model_type = model_type
         self.denoiser_ratios = [
             denoiser_ratio / sum(denoiser_ratios)
             for denoiser_ratio in denoiser_ratios
@@ -116,21 +137,21 @@ def __getitem__(self, idx):
                                      self.vocab_id_to_token_dict,
                                      self.cls_ids, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.denoiser_ratios, self.denoisers,
-                                     self.mean_span_lengths, self.mask_ratios,
-                                     np_rng,
-                                     self.bos_id, self.eos_id,
-                                     self.sentinel_tokens)
+                                     self.model_type, self.denoiser_ratios,
+                                     self.denoisers, self.mean_span_lengths,
+                                     self.mask_ratios, np_rng, self.bos_id,
+                                     self.eos_id, self.sentinel_tokens)


 def build_training_sample(sample, target_seq_length,
                           max_seq_length, max_seq_length_dec,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_ids, sep_id, mask_id, pad_id,
-                          denoiser_ratios, denoisers,
-                          mean_span_lengths, mask_ratios,
-                          np_rng, bos_id=None,
-                          eos_id=None, sentinel_tokens=None):
+                          model_type, denoiser_ratios,
+                          denoisers, mean_span_lengths,
+                          mask_ratios, np_rng,
+                          bos_id=None, eos_id=None,
+                          sentinel_tokens=None):
     """Build training sample.

     Arguments:
@@ -144,6 +165,7 @@ def build_training_sample(sample, target_seq_length,
         sep_id: Separator id.
         mask_id: Mask token id.
         pad_id: Padding token id.
+        model_type: What type of model is used.
         denoiser_ratios: Probability of each denoising objective to be selected.
         denoisers: What type of UL2 denoising objective the other UL2
             configurations refer to.
@@ -158,24 +180,28 @@ def build_training_sample(sample, target_seq_length,
         sentinel_tokens: unique value to be substituted for every replaced span
     """

+    # Denoiser selection
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
+
     assert target_seq_length <= max_seq_length

     # flatten sentences into one list
     tokens = [token for sentence in sample for token in sentence]

-    # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
-    truncated = len(tokens) > max_num_tokens
-    tokens = tokens[:max_num_tokens]
-
-    # Denoiser selection
-    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
-    denoiser = denoisers[denoiser_index]
-    masked_lm_prob = mask_ratios[denoiser_index]
-    mean_ngrams = mean_span_lengths[denoiser_index]
-    if mean_ngrams < 1:
-        mean_ngrams = round(len(tokens) * mean_ngrams)
-    max_ngrams = mean_ngrams * 2 - 1
+    if is_decoder_only(model_type):
+        # Keep space for repeated `extra_id` tokens; not the most data
+        # efficient since we calculate this based on the maximum number
+        # of possible `extra_id` tokens.
+        safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
+        truncated = len(tokens) > safe_max_seq_len
+        tokens = tokens[:safe_max_seq_len]
+    else:
+        # Truncate to `target_sequence_length`.
+        truncated = len(tokens) > max_num_tokens
+        tokens = tokens[:max_num_tokens]

     # Prepend objective token.
     cls_id = cls_ids.get(denoiser)
@@ -185,6 +211,11 @@ def build_training_sample(sample, target_seq_length,

     # Masking.
     max_predictions_per_seq = masked_lm_prob * len(tokens)
+    mean_ngrams = mean_span_lengths[denoiser_index]
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
+
     if denoiser == 'R' or denoiser == 'X':
         sampling_style = SamplingStyle.NORMAL
         prefix_lm = False
@@ -202,22 +233,64 @@ def build_training_sample(sample, target_seq_length,
         sampling_style=sampling_style, prefix_lm=prefix_lm,
     )

-    # Padding.
-    tokens_enc, tokens_dec_in, labels, enc_mask, \
-        dec_mask, enc_dec_mask, loss_mask \
-        = pad_and_convert_to_numpy(tokens, masked_positions,
-                                   masked_labels, pad_id, max_seq_length,
-                                   max_seq_length_dec, masked_spans,
-                                   bos_id, eos_id, sentinel_tokens)
-
-    train_sample = {
-        'text_enc': tokens_enc,
-        'text_dec': tokens_dec_in,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'truncated': int(truncated),
-        'enc_mask': enc_mask,
-        'dec_mask': dec_mask,
-        'enc_dec_mask': enc_dec_mask,
-    }
+    if is_decoder_only(model_type):
+        # Concatenate to one sequence.
+        tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
+            tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
+
+        # Move EOS tokens to end of sequence.
+        while tokens_enc[-1] == eos_id:
+            del tokens_enc[-1]
+            tokens_dec_in.append(eos_id)
+            labels.append(eos_id)
+
+        num_labels = len(labels)
+
+        # Move BOS token to start of sequence.
+        tokens_dec_in = tokens_dec_in[1:]
+        tokens = np.concatenate([
+            np.array([bos_id], dtype=np.int64),
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            tokens_dec_in,
+        ])
+        labels = np.concatenate([
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            labels,
+        ])
+
+        loss_mask = np.zeros(len(tokens), dtype=np.int64)
+        loss_mask[-num_labels:] = 1
+
+        dec_mask = make_history_mask(tokens)
+        if is_prefix_lm(model_type):
+            dec_mask[:-num_labels, :-num_labels] = 1
+
+        train_sample = {
+            'text': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'dec_mask': dec_mask,
+        }
+    else:
+        # Padding.
+        tokens_enc, tokens_dec_in, labels, enc_mask, \
+            dec_mask, enc_dec_mask, loss_mask \
+            = pad_and_convert_to_numpy(tokens, masked_positions,
+                                       masked_labels, pad_id, max_seq_length,
+                                       max_seq_length_dec, masked_spans,
+                                       bos_id, eos_id, sentinel_tokens)
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+            'enc_dec_mask': enc_dec_mask,
+        }
     return train_sample
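A quick sanity check of the decoder-only truncation headroom introduced above: since every masked span can add a sentinel token on the target side, the input is clipped so that inputs plus worst-case targets still fit the budget. A worked example with illustrative numbers (not taken from the commit):

import math

max_num_tokens = 512    # illustrative sequence budget
masked_lm_prob = 0.15   # illustrative mask ratio for the chosen denoiser

safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
assert safe_max_seq_len == 445  # 512 / 1.15 is about 445.2, floored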

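To make the decoder-only branch concrete, here is a toy sketch of the packed layout it produces: [bos] masked inputs [sep] targets, with the loss restricted to the target half and, for the non-causal ('ND') case, bidirectional attention over the prefix. make_history_mask is reimplemented locally as a plain causal mask, the decoder-input/label shift is elided, and the token ids are arbitrary:

import numpy as np


def make_history_mask(tokens):
    # Stand-in for megatron.data.t5_dataset.make_history_mask:
    # a lower-triangular (causal) attention mask.
    length = len(tokens)
    return np.tril(np.ones((length, length), dtype=np.int64))


bos_id, sep_id = 1, 2
tokens_enc = [10, 11, 12, 13]  # masked input, sentinel tokens included
labels = [20, 21, 22]          # target spans to reconstruct

tokens = np.concatenate([
    np.array([bos_id], dtype=np.int64),
    np.array(tokens_enc, dtype=np.int64),
    np.array([sep_id], dtype=np.int64),
    np.array(labels, dtype=np.int64),  # decoder inputs, shift elided
])

num_labels = len(labels)
loss_mask = np.zeros(len(tokens), dtype=np.int64)
loss_mask[-num_labels:] = 1  # train only on the target half

dec_mask = make_history_mask(tokens)
# Non-causal decoder ('ND'): the prefix attends bidirectionally.
dec_mask[:-num_labels, :-num_labels] = 1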
megatron/enums.py (+5)
@@ -33,3 +33,8 @@ class PositionEmbeddingType(enum.Enum):
     rotary = 1
     absolute = 2
     alibi = 3
+
+class UL2ModelType(enum.Enum):
+    ENCODER_DECODER = 'ED'
+    NON_CAUSAL_DECODER = 'ND'
+    CAUSAL_DECODER = 'CD'
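Assuming the repository is on the import path, the helpers added in megatron/data/ul2_dataset.py classify these values as follows (a small usage sketch):

from megatron.data.ul2_dataset import is_decoder_only, is_prefix_lm
from megatron.enums import UL2ModelType

for value in ('ED', 'ND', 'CD'):
    model_type = UL2ModelType(value)
    print(value, is_decoder_only(model_type), is_prefix_lm(model_type))
# ED False False  -> encoder-decoder
# ND True True    -> non-causal (prefix-LM) decoder-only
# CD True False   -> causal decoder-only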
