 """UL2-style dataset."""
 
+import math
+
 import numpy as np
 
 from megatron import get_tokenizer
 from megatron.data.dataset_utils import (
     get_samples_mapping,
     SamplingStyle
 )
-from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+from megatron.data.t5_dataset import (
+    make_history_mask,
+    merge_subsequent_masks,
+    pad_and_convert_to_numpy,
+    T5Dataset,
+)
+from megatron.enums import UL2ModelType
+
+
+def is_decoder_only(ul2_model_type):
+    """Return whether we use a decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is not UL2ModelType.ENCODER_DECODER
+
+
+def is_prefix_lm(ul2_model_type):
+    """Return whether we use a non-causal decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER
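For orientation, the UL2ModelType enum imported above is not part of this diff. Based on the two members referenced by these helpers, it presumably looks roughly like the sketch below; the CAUSAL_DECODER member and all values are assumptions:

    # Hypothetical sketch of megatron/enums.py (not shown in this diff).
    import enum

    class UL2ModelType(enum.Enum):
        ENCODER_DECODER = 'ED'       # referenced by is_decoder_only
        NON_CAUSAL_DECODER = 'ND'    # referenced by is_prefix_lm
        CAUSAL_DECODER = 'CD'        # assumed third member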
 
 
 class UL2Dataset(T5Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
-                 num_epochs, max_num_samples, denoiser_ratios,
-                 denoisers, mean_span_lengths, mask_ratios,
-                 denoiser_tokens, max_seq_length, max_seq_length_dec,
-                 short_seq_prob, seed):
+                 num_epochs, max_num_samples, model_type,
+                 denoiser_ratios, denoisers, mean_span_lengths,
+                 mask_ratios, denoiser_tokens, max_seq_length,
+                 max_seq_length_dec, short_seq_prob, seed):
 
         if denoiser_ratios is None:
             # Uniform
@@ -49,6 +69,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
         # Params to store.
         self.name = name
         self.seed = seed
+        self.model_type = model_type
         self.denoiser_ratios = [
             denoiser_ratio / sum(denoiser_ratios)
             for denoiser_ratio in denoiser_ratios
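The fallback above means a `None` ratio list behaves like a uniform mixture, since the stored ratios are renormalized to sum to one either way. A minimal sketch of that behavior; the seven-denoiser mixture is an assumed example configuration, not taken from this diff:

    denoisers = ['R', 'R', 'S', 'X', 'X', 'X', 'X']  # assumed example mixture
    denoiser_ratios = [1] * len(denoisers)           # uniform fallback
    denoiser_ratios = [r / sum(denoiser_ratios) for r in denoiser_ratios]
    # Each entry is now 1/7, so np_rng.choice(..., p=denoiser_ratios)
    # draws every denoiser with equal probability.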
@@ -116,21 +137,21 @@ def __getitem__(self, idx):
                                      self.vocab_id_to_token_dict,
                                      self.cls_ids, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.denoiser_ratios, self.denoisers,
-                                     self.mean_span_lengths, self.mask_ratios,
-                                     np_rng,
-                                     self.bos_id, self.eos_id,
-                                     self.sentinel_tokens)
+                                     self.model_type, self.denoiser_ratios,
+                                     self.denoisers, self.mean_span_lengths,
+                                     self.mask_ratios, np_rng, self.bos_id,
+                                     self.eos_id, self.sentinel_tokens)
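The np_rng threaded through this call is the per-sample NumPy RandomState created in the inherited T5Dataset.__getitem__ before this point; in upstream Megatron that is seeded per index, roughly as below (an assumption about code outside this diff):

    # Assumed context from T5Dataset.__getitem__ (not part of this diff):
    np_rng = np.random.RandomState(seed=(self.seed + idx))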
 
 
 def build_training_sample(sample, target_seq_length,
                           max_seq_length, max_seq_length_dec,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_ids, sep_id, mask_id, pad_id,
-                          denoiser_ratios, denoisers,
-                          mean_span_lengths, mask_ratios,
-                          np_rng, bos_id=None,
-                          eos_id=None, sentinel_tokens=None):
+                          model_type, denoiser_ratios,
+                          denoisers, mean_span_lengths,
+                          mask_ratios, np_rng,
+                          bos_id=None, eos_id=None,
+                          sentinel_tokens=None):
     """Build training sample.
 
     Arguments:
@@ -144,6 +165,7 @@ def build_training_sample(sample, target_seq_length,
         sep_id: Separator id.
         mask_id: Mask token id.
         pad_id: Padding token id.
+        model_type: What type of model is used.
         denoiser_ratios: Probability of each denoising objective to be selected.
         denoisers: What type of UL2 denoising objective the other UL2
             configurations refer to.
@@ -158,24 +180,28 @@ def build_training_sample(sample, target_seq_length,
         sentinel_tokens: unique value to be substituted for every replaced span
     """
 
+    # Denoiser selection
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
+
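The objective is now drawn up front so that the truncation below can depend on the selected mask ratio. Illustratively, with assumed denoiser labels and ratios:

    denoisers = ['R', 'S', 'X']            # assumed labels
    denoiser_ratios = [0.25, 0.25, 0.5]    # assumed, already normalized
    rng = np.random.RandomState(1234)
    denoiser_index = rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
    # Roughly half of all samples draw the 'X' (extreme denoising) objective.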
     assert target_seq_length <= max_seq_length
 
     # flatten sentences into one list
     tokens = [token for sentence in sample for token in sentence]
 
-    # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
-    truncated = len(tokens) > max_num_tokens
-    tokens = tokens[:max_num_tokens]
-
-    # Denoiser selection
-    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
-    denoiser = denoisers[denoiser_index]
-    masked_lm_prob = mask_ratios[denoiser_index]
-    mean_ngrams = mean_span_lengths[denoiser_index]
-    if mean_ngrams < 1:
-        mean_ngrams = round(len(tokens) * mean_ngrams)
-    max_ngrams = mean_ngrams * 2 - 1
+    if is_decoder_only(model_type):
+        # Keep space for repeated `extra_id` tokens; not the most data
+        # efficient since we calculate this based on the maximum number
+        # of possible `extra_id` tokens.
+        safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
+        truncated = len(tokens) > safe_max_seq_len
+        tokens = tokens[:safe_max_seq_len]
+    else:
+        # Truncate to `target_sequence_length`.
+        truncated = len(tokens) > max_num_tokens
+        tokens = tokens[:max_num_tokens]
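The headroom calculation guards against the merged decoder-only sequence (inputs plus sentinel-separated targets) overflowing the token budget: masking at rate masked_lm_prob can add up to masked_lm_prob * len(tokens) target-side tokens. A worked example with assumed values:

    import math
    max_num_tokens = 512                    # assumed budget
    masked_lm_prob = 0.15                   # assumed 'R' mask ratio
    safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
    print(safe_max_seq_len)                 # 445
    # 445 input tokens plus at most ~0.15 * 445 ≈ 67 target-side tokens
    # still fit the 512-token budget once both parts are packed together.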
 
     # Prepend objective token.
     cls_id = cls_ids.get(denoiser)
@@ -185,6 +211,11 @@ def build_training_sample(sample, target_seq_length,
 
     # Masking.
     max_predictions_per_seq = masked_lm_prob * len(tokens)
+    mean_ngrams = mean_span_lengths[denoiser_index]
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
+
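Moving this block after truncation matters: a fractional mean_span_lengths entry is interpreted relative to the now-final token count. For example, with assumed configuration values:

    mean_span_lengths = [3, 0.5, 32]   # assumed R/S/X configuration
    tokens_len = 445
    mean_ngrams = mean_span_lengths[1]
    if mean_ngrams < 1:
        mean_ngrams = round(tokens_len * mean_ngrams)  # 0.5 -> 222 tokens
    max_ngrams = mean_ngrams * 2 - 1                   # 443
    # max_ngrams = 2 * mean - 1 bounds the sampled span sizes
    # symmetrically around the mean.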
     if denoiser == 'R' or denoiser == 'X':
         sampling_style = SamplingStyle.NORMAL
         prefix_lm = False
@@ -202,22 +233,64 @@ def build_training_sample(sample, target_seq_length,
         sampling_style=sampling_style, prefix_lm=prefix_lm,
     )
 
-    # Padding.
-    tokens_enc, tokens_dec_in, labels, enc_mask, \
-        dec_mask, enc_dec_mask, loss_mask \
-        = pad_and_convert_to_numpy(tokens, masked_positions,
-                                   masked_labels, pad_id, max_seq_length,
-                                   max_seq_length_dec, masked_spans,
-                                   bos_id, eos_id, sentinel_tokens)
-
-    train_sample = {
-        'text_enc': tokens_enc,
-        'text_dec': tokens_dec_in,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'truncated': int(truncated),
-        'enc_mask': enc_mask,
-        'dec_mask': dec_mask,
-        'enc_dec_mask': enc_dec_mask,
-    }
+    if is_decoder_only(model_type):
+        # Concatenate to one sequence.
+        tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
+            tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
+
+        # Move EOS tokens to end of sequence.
+        while tokens_enc[-1] == eos_id:
+            del tokens_enc[-1]
+            tokens_dec_in.append(eos_id)
+            labels.append(eos_id)
+
+        num_labels = len(labels)
+
+        # Move BOS token to start of sequence.
+        tokens_dec_in = tokens_dec_in[1:]
+        tokens = np.concatenate([
+            np.array([bos_id], dtype=np.int64),
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            tokens_dec_in,
+        ])
+        labels = np.concatenate([
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            labels,
+        ])
+
+        loss_mask = np.zeros(len(tokens), dtype=np.int64)
+        loss_mask[-num_labels:] = 1
+
+        dec_mask = make_history_mask(tokens)
+        if is_prefix_lm(model_type):
+            dec_mask[:-num_labels, :-num_labels] = 1
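The final two lines here are the whole difference between the causal and non-causal (prefix-LM) decoder variants: the history mask starts out causal, and the prefix block is then opened up for bidirectional attention. A minimal sketch, assuming make_history_mask yields a lower-triangular 0/1 matrix:

    import numpy as np
    seq_len, num_labels = 6, 2
    dec_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int64))
    dec_mask[:-num_labels, :-num_labels] = 1
    # The first 4 (prefix) positions now attend to each other freely,
    # while the last 2 (target) positions still see only their history.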
+
+        train_sample = {
+            'text': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'dec_mask': dec_mask,
+        }
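To make the packing concrete, here is a toy trace of the rearrangement above with invented token ids; the merge_subsequent_masks output format is assumed from how it is used here, and the EOS-moving loop is skipped for brevity:

    bos_id, eos_id, sep_id, extra_id_0 = 1, 2, 3, 100  # invented ids
    tokens_enc = [5, 6, extra_id_0, 9]
    tokens_dec_in = [bos_id, extra_id_0, 7, 8]
    labels = [extra_id_0, 7, 8, eos_id]
    num_labels = len(labels)                 # 4
    tokens_dec_in = tokens_dec_in[1:]
    tokens = [bos_id] + tokens_enc + [sep_id] + tokens_dec_in
    labels = tokens_enc + [sep_id] + labels
    assert labels == tokens[1:] + [eos_id]   # next-token targets, shifted by one
    # loss_mask selects only the final num_labels positions, i.e. the
    # span targets and the closing EOS.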
+    else:
+        # Padding.
+        tokens_enc, tokens_dec_in, labels, enc_mask, \
+            dec_mask, enc_dec_mask, loss_mask \
+            = pad_and_convert_to_numpy(tokens, masked_positions,
+                                       masked_labels, pad_id, max_seq_length,
+                                       max_seq_length_dec, masked_spans,
+                                       bos_id, eos_id, sentinel_tokens)
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+            'enc_dec_mask': enc_dec_mask,
+        }
     return train_sample
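Note that the two branches return differently keyed samples: decoder-only samples carry 'text' and 'dec_mask', while encoder-decoder samples keep the T5-style 'text_enc'/'text_dec' keys plus all three masks. A hypothetical consumer-side check; the key names come from the dicts above, everything else is assumed:

    sample = dataset[0]
    if 'text' in sample:    # decoder-only / prefix-LM sample format
        tokens, attn_mask = sample['text'], sample['dec_mask']
    else:                   # encoder-decoder sample format
        tokens, attn_mask = sample['text_enc'], sample['enc_mask']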