@@ -15,24 +15,44 @@
 
 """UL2-style dataset."""
 
+import math
+
 import numpy as np
 
 from megatron import get_tokenizer
 from megatron.data.dataset_utils import (
     create_masked_lm_predictions,
     get_samples_mapping,
     SamplingStyle
 )
-from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+from megatron.data.t5_dataset import (
+    make_history_mask,
+    merge_subsequent_masks,
+    pad_and_convert_to_numpy,
+    T5Dataset,
+)
+from megatron.enums import UL2ModelType
+
+
+def is_decoder_only(ul2_model_type):
+    """Return whether we use a decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is not UL2ModelType.ENCODER_DECODER
+
+
+def is_prefix_lm(ul2_model_type):
+    """Return whether we use a non-causal decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER
 
 
 class UL2Dataset(T5Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
-                 num_epochs, max_num_samples, denoiser_ratios,
-                 denoisers, mean_span_lengths, mask_ratios,
-                 denoiser_tokens, max_seq_length, max_seq_length_dec,
-                 short_seq_prob, seed):
+                 num_epochs, max_num_samples, model_type,
+                 denoiser_ratios, denoisers, mean_span_lengths,
+                 mask_ratios, denoiser_tokens, max_seq_length,
+                 max_seq_length_dec, short_seq_prob, seed):
 
         if denoiser_ratios is None:
             # Uniform distribution by default.
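
Note on the new helpers: they dispatch on `megatron.enums.UL2ModelType`, which is imported but not part of this diff. A minimal sketch of what it presumably looks like; the member names are implied by the two helpers, while the values are hypothetical:

```python
import enum


class UL2ModelType(enum.Enum):
    # Hypothetical values; only the member names are implied by the diff.
    ENCODER_DECODER = 'encoder_decoder'
    CAUSAL_DECODER = 'causal_decoder'
    NON_CAUSAL_DECODER = 'non_causal_decoder'
```

Under this reading, `is_decoder_only` is True for both decoder variants, while `is_prefix_lm` singles out the non-causal (prefix-LM) one.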
@@ -52,6 +72,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
                          short_seq_prob, seed)
 
         # Params to store.
+        self.model_type = model_type
         self.denoiser_ratios = [
             denoiser_ratio / sum(denoiser_ratios)
             for denoiser_ratio in denoiser_ratios
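
The stored ratios are normalized to sum to 1 so they can later serve as the probability vector `p` of `np_rng.choice`. A quick self-contained illustration with hypothetical weights:

```python
import numpy as np

# Hypothetical mixture of three denoisers weighted 1:1:2.
ratios = [1, 1, 2]
probs = [r / sum(ratios) for r in ratios]   # -> [0.25, 0.25, 0.5]

rng = np.random.RandomState(1234)
denoiser_index = rng.choice(np.arange(len(probs)), p=probs)
# denoiser_index is 0, 1, or 2, drawn with the normalized weights.
```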
@@ -97,21 +118,21 @@ def __getitem__(self, idx):
                                      self.vocab_id_to_token_dict,
                                      self.cls_ids, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.denoiser_ratios, self.denoisers,
-                                     self.mean_span_lengths, self.mask_ratios,
-                                     np_rng,
-                                     self.bos_id, self.eos_id,
-                                     self.sentinel_tokens)
+                                     self.model_type, self.denoiser_ratios,
+                                     self.denoisers, self.mean_span_lengths,
+                                     self.mask_ratios, np_rng, self.bos_id,
+                                     self.eos_id, self.sentinel_tokens)
 
 
 def build_training_sample(sample, target_seq_length,
                           max_seq_length, max_seq_length_dec,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_ids, sep_id, mask_id, pad_id,
-                          denoiser_ratios, denoisers,
-                          mean_span_lengths, mask_ratios,
-                          np_rng, bos_id=None,
-                          eos_id=None, sentinel_tokens=None):
+                          model_type, denoiser_ratios,
+                          denoisers, mean_span_lengths,
+                          mask_ratios, np_rng,
+                          bos_id=None, eos_id=None,
+                          sentinel_tokens=None):
     """Build training sample.
 
     Arguments:
@@ -125,6 +146,7 @@ def build_training_sample(sample, target_seq_length,
         sep_id: Separator id.
         mask_id: Mask token id.
         pad_id: Padding token id.
+        model_type: What type of model is used.
         denoiser_ratios: Probability of each denoising objective to be selected.
         denoisers: What type of UL2 denoising objective the other UL2
             configurations refer to.
@@ -139,24 +161,28 @@ def build_training_sample(sample, target_seq_length,
         sentinel_tokens: unique value to be substituted for every replaced span
     """
 
+    # Denoiser selection
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
+
     assert target_seq_length <= max_seq_length
 
     # flatten sentences into one list
     tokens = [token for sentence in sample for token in sentence]
 
-    # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
-    truncated = len(tokens) > max_num_tokens
-    tokens = tokens[:max_num_tokens]
-
-    # Denoiser selection
-    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
-    denoiser = denoisers[denoiser_index]
-    masked_lm_prob = mask_ratios[denoiser_index]
-    mean_ngrams = mean_span_lengths[denoiser_index]
-    if mean_ngrams < 1:
-        mean_ngrams = round(len(tokens) * mean_ngrams)
-    max_ngrams = mean_ngrams * 2 - 1
+    if is_decoder_only(model_type):
+        # Keep space for repeated `extra_id` tokens; not the most data
+        # efficient since we calculate this based on the maximum number
+        # of possible `extra_id` tokens.
+        safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
+        truncated = len(tokens) > safe_max_seq_len
+        tokens = tokens[:safe_max_seq_len]
+    else:
+        # Truncate to `target_sequence_length`.
+        truncated = len(tokens) > max_num_tokens
+        tokens = tokens[:max_num_tokens]
 
     # Prepend objective token.
     cls_id = cls_ids.get(denoiser)
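
The `safe_max_seq_len` heuristic above reserves room for the sentinel tokens that masking inserts later: in the worst case roughly `masked_lm_prob * len(tokens)` extra positions are needed once inputs and targets are packed into a single decoder-only sequence. A worked example with hypothetical numbers:

```python
import math

# Hypothetical budget: 512 positions and a 15% masking ratio.
max_num_tokens = 512
masked_lm_prob = 0.15

safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
assert safe_max_seq_len == 445
# 512 - 445 = 67 positions stay free for the sentinel (`extra_id`)
# tokens that span corruption may add on the input and target side.
```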
@@ -166,6 +192,11 @@ def build_training_sample(sample, target_seq_length,
 
     # Masking.
     max_predictions_per_seq = masked_lm_prob * len(tokens)
+    mean_ngrams = mean_span_lengths[denoiser_index]
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
+
     if denoiser == 'R' or denoiser == 'X':
         sampling_style = SamplingStyle.NORMAL
         prefix_lm = False
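
Note that a mean span length below 1 is interpreted as a fraction of the (already truncated) sequence rather than an absolute token count, and the maximum n-gram size is chosen so the average span stays at `mean_ngrams`. A short sketch with hypothetical values:

```python
# Hypothetical S-denoiser-style setting: the mean span length is given
# as a fraction of the sequence instead of an absolute count.
tokens = list(range(400))   # 400 tokens left after truncation
mean_ngrams = 0.25          # a quarter of the sequence

if mean_ngrams < 1:
    mean_ngrams = round(len(tokens) * mean_ngrams)   # -> 100
max_ngrams = mean_ngrams * 2 - 1                     # -> 199, keeps the mean at 100
```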
@@ -183,22 +214,64 @@ def build_training_sample(sample, target_seq_length,
         sampling_style=sampling_style, prefix_lm=prefix_lm,
     )
 
-    # Padding.
-    tokens_enc, tokens_dec_in, labels, enc_mask, \
-        dec_mask, enc_dec_mask, loss_mask \
-        = pad_and_convert_to_numpy(tokens, masked_positions,
-                                   masked_labels, pad_id, max_seq_length,
-                                   max_seq_length_dec, masked_spans,
-                                   bos_id, eos_id, sentinel_tokens)
-
-    train_sample = {
-        'text_enc': tokens_enc,
-        'text_dec': tokens_dec_in,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'truncated': int(truncated),
-        'enc_mask': enc_mask,
-        'dec_mask': dec_mask,
-        'enc_dec_mask': enc_dec_mask,
-    }
+    if is_decoder_only(model_type):
+        # Concatenate to one sequence.
+        tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
+            tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
+
+        # Move EOS tokens to end of sequence.
+        while tokens_enc[-1] == eos_id:
+            del tokens_enc[-1]
+            tokens_dec_in.append(eos_id)
+            labels.append(eos_id)
+
+        num_labels = len(labels)
+
+        # Move BOS token to start of sequence.
+        tokens_dec_in = tokens_dec_in[1:]
+        tokens = np.concatenate([
+            np.array([bos_id], dtype=np.int64),
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            tokens_dec_in,
+        ])
+        labels = np.concatenate([
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            labels,
+        ])
+
+        loss_mask = np.zeros(len(tokens), dtype=np.int64)
+        loss_mask[-num_labels:] = 1
+
+        dec_mask = make_history_mask(tokens)
+        if is_prefix_lm(model_type):
+            dec_mask[:-num_labels, :-num_labels] = 1
+
+        train_sample = {
+            'text': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'dec_mask': dec_mask,
+        }
+    else:
+        # Padding.
+        tokens_enc, tokens_dec_in, labels, enc_mask, \
+            dec_mask, enc_dec_mask, loss_mask \
+            = pad_and_convert_to_numpy(tokens, masked_positions,
+                                       masked_labels, pad_id, max_seq_length,
+                                       max_seq_length_dec, masked_spans,
+                                       bos_id, eos_id, sentinel_tokens)
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+            'enc_dec_mask': enc_dec_mask,
+        }
 
     return train_sample
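
To make the decoder-only branch concrete, here is a toy walkthrough (hypothetical token ids) of the final packed layout, the loss mask, and the prefix-LM tweak; it assumes `make_history_mask` yields a standard lower-triangular causal mask:

```python
import numpy as np

BOS, SEP, EOS = 1, 2, 3                 # hypothetical special-token ids
tokens_enc = [10, 11, 12]               # masked input, EOS already stripped
tokens_dec_in = [20, 21]                # target span, BOS already stripped
labels = [20, 21, EOS]                  # decoder labels, EOS moved to the end
num_labels = len(labels)

tokens = np.concatenate([[BOS], tokens_enc, [SEP], tokens_dec_in])
labels = np.concatenate([tokens_enc, [SEP], labels])
# tokens: [1, 10, 11, 12, 2, 20, 21]
# labels: [10, 11, 12, 2, 20, 21, 3]    (tokens shifted left by one)

loss_mask = np.zeros(len(tokens), dtype=np.int64)
loss_mask[-num_labels:] = 1             # only target positions contribute loss

# Assumed behavior of make_history_mask: lower-triangular causal mask.
dec_mask = np.tril(np.ones((len(tokens), len(tokens)), dtype=np.int64))
# Prefix LM (non-causal decoder): the input prefix attends bidirectionally.
dec_mask[:-num_labels, :-num_labels] = 1
```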