Add UL2 data sampling and pretraining #268

Status: Open. Wants to merge 103 commits into base: main.

Commits (103):
8c42ed8  Fix typo (janEbert, Dec 13, 2022)
1e95bf4  Refactor masked LM sampling style selection (janEbert, Dec 13, 2022)
78559b4  Add more masked LM sampling styles (janEbert, Dec 13, 2022)
722277a  Allow Prefix-LM style masked LM (janEbert, Dec 13, 2022)
f6a305f  Add splits string for caching w/ samples mappings (janEbert, Jul 4, 2023)
c8aeeab  Add UL2 pretraining for T5 model (janEbert, Dec 13, 2022)
31de228  Refactor span merging (janEbert, Dec 13, 2022)
891f704  Allow non-causal GPT models (janEbert, Dec 13, 2022)
cec6b9d  Support UL2 for decoder-only models (janEbert, Dec 13, 2022)
59ee081  Add custom exceptions (janEbert, Dec 14, 2022)
b1f7703  Error out on too long sequences (janEbert, Dec 14, 2022)
4b08245  Remove additional sequence truncation (janEbert, Dec 14, 2022)
0b103f6  Prefer array-from-list creation (janEbert, Dec 14, 2022)
290b474  Remove redundant imports (janEbert, Jan 3, 2023)
1f7ada6  Fix sometimes not inserting prefixes (janEbert, Jan 3, 2023)
4cef65e  Do not insert `extra_id` tokens for PrefixLM task (janEbert, Jan 3, 2023)
22c66dc  Document `max_seq_length_dec` argument (janEbert, Jan 3, 2023)
b1ea793  Skip redundant computations (janEbert, Jan 3, 2023)
4900916  Fix PrefixLM mean location (janEbert, Jan 3, 2023)
6e3d25b  Pad decoder-only inputs to same length (janEbert, Jan 3, 2023)
9872366  Fix decoder-only attention mask shape (janEbert, Jan 3, 2023)
6a4c22f  Fix `max_ngrams` for normal sampling style (janEbert, Jan 23, 2023)
9ee32b7  Do not limit `max_predictions_per_seq` (janEbert, Jan 23, 2023)
5f59d6b  Calculate and use amount of filtered tokens (janEbert, Jan 23, 2023)
12fd16f  Document normal sampling style (janEbert, Jan 23, 2023)
0570eda  Fix PrefixLM possible spans calculation (janEbert, Jan 23, 2023)
f1b81e0  Avoid mutable pointer in arguments (janEbert, Jan 23, 2023)
e027273  Allow passing callable for getting `model_type` (janEbert, Jan 23, 2023)
b0747f2  Fix getting model type (janEbert, Jan 23, 2023)
9d62d4a  Allow recognizing when UL2 is used (janEbert, Jan 23, 2023)
5befcd5  Only add UL2 tokens if using UL2 pretrain script (janEbert, Jan 23, 2023)
f0f82b2  Support UL2 tokens for all tokenizers (janEbert, Jan 23, 2023)
086d482  Add SEP token to GPT tokenizer if using UL2 (janEbert, Jan 23, 2023)
d36e362  Fix enum name (janEbert, Jan 23, 2023)
2a89915  Fix private UL2 argument default value (janEbert, Jan 23, 2023)
e84db38  Use binary search for PrefixLM first tail index (janEbert, Jan 24, 2023)
706a58f  Calculate n-gram indices lazily (janEbert, Jan 24, 2023)
758c357  Prefer list comprehensions (janEbert, Jan 24, 2023)
45f3d26  Fix undesired list mutation (janEbert, Feb 14, 2023)
0dc5bcf  Support `<extra_id>` tokens for GPT tokenizer (janEbert, Feb 14, 2023)
41c6dd9  Fix tokenizer vocab access (janEbert, Feb 14, 2023)
6f70180  Revert inheriting from `T5Dataset` (janEbert, Feb 14, 2023)
da93f93  Fix GPT tokenizer special token handling (janEbert, Feb 14, 2023)
b1a0456  Allow selectively disabling denoiser token (janEbert, Feb 14, 2023)
c54a064  Allow not replacing masks with sentinel tokens (janEbert, Feb 14, 2023)
96bd7e3  Support not adding mask tokens in span corruption (janEbert, Feb 14, 2023)
c410699  Fix expected number of added tokens (janEbert, Feb 15, 2023)
0ceaec1  Fix non-masked data (janEbert, Feb 16, 2023)
1a78e4f  Fix unclear wording (janEbert, Feb 16, 2023)
ecbafdf  Adjust code style (janEbert, Feb 17, 2023)
d77ee1e  Fix covered index skipping (janEbert, Feb 17, 2023)
c6a4346  Prepend objective token before truncating (janEbert, Feb 17, 2023)
a3d2ec8  Automatically truncate sequences for decoder-only (janEbert, Feb 17, 2023)
ce8029e  Make `build_index_mappings` public (janEbert, Feb 17, 2023)
8603d3b  Refactor getting sample (janEbert, Feb 17, 2023)
4dcdb3d  Add sample packing to T5 dataset (janEbert, Feb 17, 2023)
c6ba640  Add sample packing to UL2 dataset (janEbert, Feb 17, 2023)
5a76d30  Fix not supplying `--pack-samples` argument (janEbert, Feb 17, 2023)
28751c7  Add support for UL2R-style implementation (janEbert, Feb 17, 2023)
63f280c  Fix T5 dataset packing (janEbert, Feb 17, 2023)
78000bf  Refactor `get_sample` to return a list (janEbert, Feb 22, 2023)
8e271cb  Fix T5 sample packing (janEbert, Feb 22, 2023)
f1bbda7  Fix UL2 sample packing (janEbert, Feb 22, 2023)
ce3ca2d  Fix desired seq length (janEbert, Feb 23, 2023)
018641b  Fix padding removal (janEbert, Feb 23, 2023)
d0bb6dc  Allow packing different denoisers together (janEbert, Feb 23, 2023)
5edffee  Allow repeating UL2 prompt token when packing (janEbert, Feb 23, 2023)
6f8e283  Refactor sample packing functions (janEbert, Feb 23, 2023)
dfe0607  Repeat prompt by default when packing UL2 (janEbert, Feb 23, 2023)
a0b3741  Fix GPT tokenizer vocab size query (janEbert, Feb 24, 2023)
7eccbf4  Handle possibly empty list (janEbert, Feb 24, 2023)
909387c  Add optional UL2 normal distribution scaling (janEbert, Mar 24, 2023)
3aac113  Refactor samples dict creation (janEbert, Apr 3, 2023)
dc2afba  Move callees under caller (janEbert, Apr 3, 2023)
4976188  Refactor dummy barriers (janEbert, Apr 3, 2023)
5ad1ae7  Refactor description creation (janEbert, Apr 3, 2023)
a10706a  Allow packing only full documents (janEbert, Apr 3, 2023)
f32c001  Use full-doc packing for T5-style datasets (janEbert, Apr 3, 2023)
cea3cb8  Fix truncating packed sequences without padding (janEbert, Apr 3, 2023)
c3b4a72  Fix unconditional usage of non-causal decoder (janEbert, May 2, 2023)
d3fe05b  Fix decoder-only and no-mask-tokens seq lengths (janEbert, Jun 7, 2023)
f916fea  Omit second objective token if without mask tokens (janEbert, Jun 7, 2023)
fdf6249  Do not add separator if S-denoising (janEbert, Jun 26, 2023)
d1f04f4  Fix number of labels calculation for decoder-only (janEbert, Jun 29, 2023)
fcbb6b8  Do not automatically add <EOS> token when packing (janEbert, Jun 29, 2023)
fe4cebf  Fix BlendableDataset size calculation (janEbert, Sep 1, 2023)
a0fb576  Fix UL2 dataset creation (janEbert, Sep 22, 2023)
6cfec5c  Allow enabling `_is_ul2` anytime (janEbert, Sep 25, 2023)
f0d56f8  Remove redundant newline (janEbert, Sep 25, 2023)
e2c6b21  Fix added tokens expectation for S-denoising (janEbert, Sep 27, 2023)
553efdc  Add C-denoiser for CLM objective (janEbert, Sep 27, 2023)
e8c3766  Always remove trailing EOS inputs (janEbert, Sep 27, 2023)
89ee472  Do not add BOS token for decoder-only (janEbert, Sep 27, 2023)
87aecd1  Fix BOS token attribute name (janEbert, Sep 27, 2023)
fee3b2a  Fix BOS token retrieval (janEbert, Sep 27, 2023)
ff81ad4  Fix BOS token ID usage (janEbert, Sep 27, 2023)
c07bda3  Fix supplying token ID (janEbert, Sep 27, 2023)
dc433db  Handle token ID being zero (janEbert, Sep 27, 2023)
fd4b103  Fix causal targets (janEbert, Sep 28, 2023)
c48b29d  Fix early return values (janEbert, Sep 28, 2023)
d3bb961  Fix C-denoiser mask probability (janEbert, Sep 28, 2023)
b85d383  Fix wrong seq length being used (janEbert, Oct 13, 2023)
0fd3b63  Fix edge case resulting in empty sequences (janEbert, Jan 5, 2024)
megatron/arguments.py (84 additions, 0 deletions)

@@ -15,6 +15,8 @@

from megatron.core.transformer import TransformerConfig

from megatron.model.enums import UL2ModelType

def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
@@ -34,6 +36,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vision_args(parser)
parser = _add_ul2_args(parser)
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)
@@ -336,6 +339,17 @@ def validate_args(args, defaults={}):
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False

args.ul2_model_type = UL2ModelType(args.ul2_model_type)
if (
args.ul2_model_type is not UL2ModelType.encoder_decoder
and args.decoder_seq_length is not None
):
print(
f'WARNING: `--decoder_seq_length` is ignored when '
f'`--ul2-model-type` is not '
f'"{UL2ModelType.encoder_decoder.value}"!'
)

if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
@@ -1137,6 +1151,13 @@ def _add_data_args(parser):
help='Probability of replacing a token with mask.')
group.add_argument('--short-seq-prob', type=float, default=0.1,
help='Probability of producing a short sequence.')
group.add_argument('--no-add-mask-tokens', action='store_false',
help='Whether not to add sentinel tokens for masked '
'spans in span corruption tasks.',
dest='add_mask_tokens')
group.add_argument('--pack-samples', action='store_true',
help='Whether to pack samples in span corruption '
'datasets (T5 or UL2). GPT dataset is always packed.')
group.add_argument('--mmap-warmup', action='store_true',
help='Warm up mmap files.')
group.add_argument('--num-workers', type=int, default=2,
@@ -1302,3 +1323,66 @@ def _add_vision_args(parser):
help='warmup teacher temperature epochs')

return parser


def _add_ul2_args(parser):
group = parser.add_argument_group(title="UL2")

group.add_argument('--ul2-model-type', type=str, default='ED',
choices=['ED', 'ND', 'CD'],
help='What type of model to use for UL2 pretraining. '
'ED = encoder-decoder; ND = non-causal decoder-only; '
'CD = causal decoder-only')
group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float,
default=None,
help='Probability of each denoising objective to be '
'selected. Uniform distribution by default.')
group.add_argument('--ul2-denoisers', nargs='+', type=str,
default=['R', 'R', 'S', 'X', 'X', 'X', 'X'],
choices=['R', 'S', 'X', 'C'],
help='What type of UL2 denoising objective the other '
'UL2 configurations refer to. "C" is a fully causal '
'objective with BOS as its denoiser token. Its '
'settings need to be provided but will be ignored.')
group.add_argument('--ul2-mean-span-lengths', nargs='+', type=float,
default=[3, 8, 0.25, 3, 8, 64, 64],
help='Mean length for sampling span lengths. '
'Numbers < 1 indicate a mean length of the sequence '
'length times that number.')
group.add_argument('--ul2-mask-ratios', nargs='+', type=float,
default=[0.15, 0.15, 0.25, 0.5, 0.5, 0.15, 0.5],
help='Ratio of masked token in the full sequence.')
group.add_argument('--ul2-r-denoiser-token', type=str, default='[R]',
help='What token to prepend for the UL2 R-denoising '
'objective. If empty, do not prepend a token for this '
'objective.')
group.add_argument('--ul2-s-denoiser-token', type=str, default='[S]',
help='What token to prepend for the UL2 S-denoising '
'objective. If empty, do not prepend a token for this '
'objective.')
group.add_argument('--ul2-x-denoiser-token', type=str, default='[X]',
help='What token to prepend for the UL2 X-denoising '
'objective. If empty, do not prepend a token for this '
'objective.')
group.add_argument('--ul2-scale-normal-std', action='store_true',
help='Whether to scale the standard deviation when '
'using a normal distribution for span length sampling.')
group.add_argument('--ul2-like-ul2r', action='store_true',
help='Whether to use the updated implementation as '
'described in the UL2R paper. This only changes the '
'implementation, not the objective configurations!')
group.add_argument('--ul2-pack-any', action='store_true',
help='When `--pack-samples` is also given, whether to '
'pack different denoisers into one sample. If not '
'given, the same denoiser is used for all packed '
'samples.')
group.add_argument('--ul2-pack-no-repeat-prompt', action='store_false',
help='When `--pack-samples` is also given and '
'`--ul2-pack-any` is *not* given, whether to '
'repeat the prompt token for each packed sample.',
dest='ul2_pack_repeat_prompt')
# Has to be `None` by default so it can be overridden by `defaults`
# in `validate_args` but still evaluate to `False`.
group.add_argument('--_is_ul2', help=argparse.SUPPRESS)

return parser
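The mixture-of-denoisers configuration above (denoiser types, mean span lengths, mask ratios) can be sketched as a small sampling routine. This is an illustrative sketch built only from the argparse defaults shown, not the PR's actual dataset code; `sample_denoiser` and its return shape are hypothetical.

```python
import random

# Defaults mirroring `--ul2-denoisers`, `--ul2-mean-span-lengths`,
# and `--ul2-mask-ratios` above.
DENOISERS = ['R', 'R', 'S', 'X', 'X', 'X', 'X']
MEAN_SPAN_LENGTHS = [3, 8, 0.25, 3, 8, 64, 64]
MASK_RATIOS = [0.15, 0.15, 0.25, 0.5, 0.5, 0.15, 0.5]

def sample_denoiser(seq_length, ratios=None, rng=random):
    """Pick one denoising objective for a sample.

    `ratios=None` means a uniform distribution over the configured
    denoisers, matching the `--ul2-denoiser-ratios` default.
    """
    if ratios is None:
        idx = rng.randrange(len(DENOISERS))
    else:
        idx = rng.choices(range(len(DENOISERS)), weights=ratios, k=1)[0]
    mean_span = MEAN_SPAN_LENGTHS[idx]
    # Per the help text: means < 1 are a fraction of the sequence length.
    if mean_span < 1:
        mean_span *= seq_length
    return DENOISERS[idx], mean_span, MASK_RATIOS[idx]
```

With the default lists, R- and X-denoising dominate the uniform draw (2 and 4 of 7 slots), which is why `--ul2-denoiser-ratios` exists as an override.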
megatron/core/transformer/enums.py (1 addition, 0 deletions)

@@ -23,3 +23,4 @@ class AttnType(enum.Enum):
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
prefix = 3
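The new `prefix` mask type corresponds to PrefixLM-style attention: bidirectional within a prefix, causal afterwards. A minimal sketch of such a mask follows; the semantics are assumed from the enum name, and the PR's actual mask construction (and its True/False convention) may differ.

```python
import numpy as np

def prefix_lm_mask(seq_length, prefix_length):
    # True marks positions a query may NOT attend to (an assumed
    # convention, used here only for illustration).
    # Start from a strictly causal mask...
    mask = np.triu(np.ones((seq_length, seq_length), dtype=bool), k=1)
    # ...then open up full bidirectional attention inside the prefix.
    mask[:prefix_length, :prefix_length] = False
    return mask
```

For `prefix_lm_mask(6, 3)`, positions 0..2 attend to each other freely, while positions 3..5 still only see themselves and earlier tokens.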
megatron/data/bert_dataset.py (2 additions, 1 deletion)

@@ -22,7 +22,7 @@
class BertDataset(torch.utils.data.Dataset):

def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
splits_string, num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed, binary_head):

# Params to store.
@@ -38,6 +38,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
data_prefix,
splits_string,
num_epochs,
max_num_samples,
self.max_seq_length - 3, # account for added tokens
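The `bert_dataset.py` change threads `splits_string` into `get_samples_mapping`, matching the commit "Add splits string for caching w/ samples mappings": the on-disk samples-mapping cache key presumably needs to distinguish different split specifications. A hypothetical sketch of such a cache name; the real Megatron-LM filename format differs.

```python
def samples_mapping_cache_name(data_prefix, splits_string, num_epochs,
                               max_num_samples, max_seq_length, seed):
    # Hypothetical cache-key builder: every parameter that changes the
    # mapping's contents, including the splits string, goes into the name.
    key = (f'{splits_string}_{num_epochs}ep_{max_num_samples}mns_'
           f'{max_seq_length}msl_{seed}s')
    return f'{data_prefix}_{key}_samples_mapping.npy'
```

Without the splits string in the key, a run using splits `969,30,1` could silently reuse a mapping cached for `100,0,0` over the same data prefix.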