
Add UL2 data sampling and pretraining #358


Open

Wants to merge 122 commits into base: main.

Changes from all commits (122 commits, all authored by janEbert):
b2fc665  Fix `PretrainedFromHF` tokenizer with T5 training (Dec 12, 2022)
13becf1  Allow passing existing causal attention masks (Dec 12, 2022)
7f50532  Refactor masked LM sampling style selection (Dec 12, 2022)
d8db189  Add more masked LM sampling styles (Dec 12, 2022)
006c4e9  Allow Prefix-LM style masked LM (Dec 12, 2022)
f802317  Add UL2 pretraining for T5 model (Dec 12, 2022)
deed87f  Refactor span merging (Dec 13, 2022)
728e076  Support UL2 for decoder-only models (Dec 13, 2022)
42ece6b  Unconditionally use safe maximum sequence length (Dec 13, 2022)
d18f84e  Add custom exceptions (Dec 14, 2022)
fa5aa68  Error out on too long sequences (Dec 14, 2022)
c7d8a8b  Remove additional sequence truncation (Dec 14, 2022)
c722516  Prefer array-from-list creation (Dec 14, 2022)
69f6e70  Remove redundant imports (Jan 2, 2023)
f08a104  Fix not inserting prefixes (Jan 3, 2023)
d2fd03e  Do not insert `extra_id` tokens for PrefixLM task (Jan 3, 2023)
daf52cc  Document `max_seq_length_dec` argument (Jan 3, 2023)
04be590  Skip redundant computations (Jan 3, 2023)
7bc5a87  Fix PrefixLM mean location (Jan 3, 2023)
775e99d  Pad decoder-only inputs to same length (Jan 3, 2023)
538c30b  Fix decoder-only attention mask shape (Jan 3, 2023)
ba4476c  Document index set selection for PrefixLM masking (Jan 23, 2023)
678fbdc  Fix `max_ngrams` for normal sampling style (Jan 23, 2023)
00479e5  Do not limit `max_predictions_per_seq` (Jan 23, 2023)
795caef  Calculate and use amount of filtered tokens (Jan 23, 2023)
689e15f  Document normal sampling style (Jan 23, 2023)
e44d0e4  Fix PrefixLM possible spans calculation (Jan 23, 2023)
075f05f  Use binary search for PrefixLM first tail index (Jan 24, 2023)
6bc7471  Calculate n-gram indices lazily (Jan 24, 2023)
a105f32  Fix code style (Jan 24, 2023)
f0fe282  Prefer list comprehensions (Jan 24, 2023)
11bd6db  Allow recognizing when UL2 is used (Feb 14, 2023)
43eee93  Support UL2 tokens for all tokenizers (Feb 14, 2023)
6686f04  Support `<extra_id>` tokens for GPT tokenizer (Feb 14, 2023)
f6128c6  Fix tokenizer vocab access (Feb 14, 2023)
8f48763  Revert inheriting from `T5Dataset` (Feb 14, 2023)
7f99a12  Fix GPT tokenizer special token handling (Feb 14, 2023)
535a306  Do inherit from `torch.utils.data.Dataset` (Feb 14, 2023)
db623b3  Add whitespace (Feb 14, 2023)
ef72280  Allow selectively disabling denoiser token (Feb 14, 2023)
001b50c  Allow not replacing masks with sentinel tokens (Feb 14, 2023)
23c052f  Support not adding mask tokens in span corruption (Feb 14, 2023)
0f4fd3f  Fix expected number of added tokens (Feb 15, 2023)
da1f4e9  Fix non-masked data (Feb 16, 2023)
55320ea  Fix unclear wording (Feb 16, 2023)
5d27b27  Adjust code style (Feb 16, 2023)
23181ab  Fix covered index skipping (Feb 17, 2023)
6032cc6  Prepend objective token before truncating (Feb 17, 2023)
c9c336f  Automatically truncate sequences for decoder-only (Feb 17, 2023)
b8003cb  Fix covered span skipping fix (Feb 17, 2023)
e3d91a6  Make `build_index_mappings` public (Feb 17, 2023)
e61e78f  Refactor getting sample (Feb 17, 2023)
c3b0a55  Add sample packing to T5 dataset (Feb 17, 2023)
c4d748b  Add sample packing to UL2 dataset (Feb 17, 2023)
689b57e  Fix typo and comment placement (Feb 17, 2023)
af204e7  Fix not supplying `--pack-samples` argument (Feb 17, 2023)
78eb035  Add support for UL2R-style implementation (Feb 17, 2023)
c03eed4  Fix T5 dataset packing (Feb 17, 2023)
9e84f06  Refactor `get_sample` to return a list (Feb 22, 2023)
5e2b4f5  Fix T5 sample packing (Feb 22, 2023)
e2a0c36  Fix UL2 sample packing (Feb 22, 2023)
c2884c8  Refactor samples dict creation (Feb 22, 2023)
7eb7923  Fix desired seq length (Feb 23, 2023)
dd4c0d0  Fix padding removal (Feb 23, 2023)
58148f8  Allow repeating UL2 prompt token when packing (Feb 23, 2023)
c41fecd  Allow packing different denoisers together (Feb 23, 2023)
057bb47  Refactor sample packing functions (Feb 23, 2023)
e2062b7  Repeat prompt by default when packing UL2 (Feb 23, 2023)
d31b89f  Support pipelining for decoder-only model (Feb 23, 2023)
17dca4f  Fix GPT tokenizer vocab size query (Feb 24, 2023)
bf9b1eb  Handle possibly empty list (Feb 24, 2023)
c4aa4cd  Fix no newline at EOF (Feb 27, 2023)
8d7a0df  Allow full prefix Prefix-LM attention sampling (Feb 27, 2023)
9bd6e1e  Support PrefixLM models (Feb 27, 2023)
ba4ab49  Allow setting number of few-shot examples (Feb 27, 2023)
9f53171  Update task/dataset name (Feb 27, 2023)
5b63d0b  Do not remove last token (Feb 28, 2023)
639b71d  Fix PrefixLM contexts (Feb 28, 2023)
127d1e4  Fix module refactor (Feb 28, 2023)
1bb788d  Fix possible `TypeError` (Feb 28, 2023)
cf5965a  Optionally add prefix tokens (Feb 28, 2023)
a538238  Automatically add UL2 tokens (Feb 28, 2023)
3a8bc35  Fix context lengths batch chunking (Mar 1, 2023)
6f0e33a  Allow different models to be loaded (Mar 1, 2023)
9c4c718  Fix context batch size padding (Mar 2, 2023)
754cf21  Add xPos embeddings (Mar 7, 2023)
08b0eaf  Add optional UL2 normal distribution scaling (Mar 7, 2023)
15622d2  Allow evaluating encoder-decoder models (Mar 7, 2023)
e5a6169  Fix not passing `scale_normal_std` (Mar 8, 2023)
d583fe9  Add T5-style GLU layers (Mar 7, 2023)
ad7de7e  Rename xPos embedding class (Mar 9, 2023)
81a68f7  Integrate xPos embedding (Mar 9, 2023)
46e145d  Handle xPos embedding (Mar 9, 2023)
482f0ea  Do not use bias for 2nd MLP layer if using T5 GLU (Mar 9, 2023)
4385f7b  Fix T5 GLU constructor arguments (Mar 9, 2023)
2d24b13  Refactor samples dict creation (Mar 9, 2023)
bd461f5  Move callees under caller (Mar 9, 2023)
35b2956  Handle empty context (Mar 10, 2023)
f0171e0  Handle more possible model types (Mar 10, 2023)
92158d8  Fix fully truncated contexts with prefix tokens (Mar 10, 2023)
3b7692f  Make T5 GLU checks safer (Mar 10, 2023)
b37d3ee  Improve import code style (Mar 20, 2023)
5959e89  Refactor dummy barriers (Mar 20, 2023)
ce8c1a5  Refactor file name creation (Mar 20, 2023)
3e52966  Allow packing only full documents (Mar 20, 2023)
23efa88  Use full-doc packing for T5-style datasets (Mar 20, 2023)
88eb98a  Fix trying to all-reduce non-existent bias (Mar 20, 2023)
59e8451  Fix truncating packed sequences without padding (Mar 21, 2023)
24d46ff  Speed up packed dataset indexing (Mar 24, 2023)
600542d  Try to exit padding removal early (Apr 3, 2023)
58831d2  Fix xPos embedding (Apr 4, 2023)
fe45cea  Fix padding loss mask (Apr 13, 2023)
15e7b98  Handle failure mode regarding non-DS checkpoints (Apr 13, 2023)
ae45a9e  Fix decoder-only and no-mask-tokens seq lengths (Jun 7, 2023)
0c91b96  Omit second objective token if without mask tokens (Jun 7, 2023)
0c246c4  Fix NumPy deprecations (Jun 7, 2023)
7ce8635  Fix supplied arguments (Jun 26, 2023)
7290181  Do not add separator if S-denoising (Jun 26, 2023)
628d847  Fix caching error (May 12, 2023)
9c727e7  Fix number of labels calculation for decoder-only (Jun 29, 2023)
4ffa951  Do not automatically add <EOS> token when packing (Jun 29, 2023)
ff5787e  Allow silently ignoring causal attention mask (Jun 29, 2023)
examples/run_evalharness_deepspeed.md (1 addition, 1 deletion)

@@ -29,7 +29,7 @@ Also make sure `data` is not on one of the limited partitions like WORKSF.
 Then install datasets for the tasks:
 ```
 python ./tasks/eval_harness/download.py --task_list
-arc_challenge,arc_easy,boolq,copa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc
+arc_challenge,arc_easy,boolq,copa,hellaswag,lambada_openai,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc
 ```
 and make sure that `export HF_DATASETS_OFFLINE=1`
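(The `lambada` entry becomes `lambada_openai`, presumably tracking the task's rename in the upstream EleutherAI evaluation harness; the same substitution appears in the two Slurm scripts below.)
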
examples/run_evalharness_deepspeed.slurm (1 addition, 1 deletion)

@@ -85,7 +85,7 @@ CMD="./tasks/eval_harness/evaluate.py \
     --seq-length $SEQ_LEN \
     --adaptive_seq_len \
     --eval_fp32 \
-    --task_list arc_challenge,arc_easy,boolq,copa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sst,webqs,wic,winogrande,wnli,wsc,triviaqa,sciq \
+    --task_list arc_challenge,arc_easy,boolq,copa,hellaswag,lambada_openai,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sst,webqs,wic,winogrande,wnli,wsc,triviaqa,sciq \
     $MEGATRON_REQUIRED_ARGS \
     "
examples/run_evalharness_tr11-176b-ml.slurm (1 addition, 1 deletion)

@@ -89,7 +89,7 @@ CMD="./tasks/eval_harness/evaluate.py \
     --bf16 \
     --inference \
     --seq-length $SEQ_LEN \
-    --task_list arc_challenge,arc_easy,boolq,copa,headqa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc \
+    --task_list arc_challenge,arc_easy,boolq,copa,headqa,hellaswag,lambada_openai,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc \
     --deepspeed \
     --deepspeed_config ds_config.json \
     --bootstrap_iters 2 \
finetune_t0_non_causal_decoder.py (5 additions, 1 deletion)

@@ -94,7 +94,11 @@ def get_batch_pipe(data):
         segment_ids=segment_ids.long(),
     )

-    if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
+    if args.position_embedding_type not in [
+        PositionEmbeddingType.alibi,
+        PositionEmbeddingType.rotary,
+        PositionEmbeddingType.xpos,
+    ]:
         raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

     return (tokens, position_ids, attention_mask), (labels, loss_mask)
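xPos joins ALiBi and rotary here because all three encode relative positions; only absolute position embeddings would require `position_ids` to be reset for packed or concatenated inputs, which is exactly what the `NotImplementedError` guards against.
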
megatron/arguments.py (79 additions, 2 deletions)

@@ -24,7 +24,7 @@
 import torch
 import deepspeed

-from megatron.enums import PositionEmbeddingType
+from megatron.enums import PositionEmbeddingType, UL2ModelType
 import megatron
 from megatron.logging import log_levels
@@ -49,6 +49,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_autoresume_args(parser)
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
+    parser = _add_ul2_args(parser)
     parser = _add_logging_args(parser)
     parser = _add_zero_args(parser)
     parser = _add_memoryopt_args(parser)
@@ -310,6 +311,17 @@
     )
     args.skip_train_iteration_range = skip_train_iteration_range

+    args.ul2_model_type = UL2ModelType(args.ul2_model_type)
+    if (
+        args.ul2_model_type is not UL2ModelType.ENCODER_DECODER
+        and args.decoder_seq_length is not None
+    ):
+        print(
+            f'WARNING: `--decoder_seq_length` is ignored when '
+            f'`--ul2-model-type` is not '
+            f'"{UL2ModelType.ENCODER_DECODER.value}"!'
+        )
+
     if args.use_bnb_optimizer:
         try:
             import bitsandbytes as bnb
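For reference, `UL2ModelType` lives in `megatron/enums.py` and its definition is not shown in this diff. Judging from the `ENCODER_DECODER.value` access above and the `--ul2-model-type` choices below, a plausible sketch (a hypothetical reconstruction, not the PR's code) is:

```
import enum


class UL2ModelType(enum.Enum):
    # Member names besides ENCODER_DECODER are guesses based on the
    # 'ED' / 'ND' / 'CD' choices accepted by --ul2-model-type.
    ENCODER_DECODER = 'ED'
    NON_CAUSAL_DECODER = 'ND'
    CAUSAL_DECODER = 'CD'
```
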
@@ -398,7 +410,7 @@ def _add_network_size_args(parser):
     group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x],
                        choices=list(PositionEmbeddingType),
                        default=PositionEmbeddingType.absolute,
-                       help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.'
+                       help='Define position embedding type ("absolute" | "rotary" | "alibi" | "xpos"). "absolute" by default.'
                        )
     group.add_argument('--glu-activation', type=str,
                        choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(),
@@ -901,6 +913,13 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
                        help='Probability of producing a short sequence.')
+    group.add_argument('--no-add-mask-tokens', action='store_false',
+                       help='Whether not to add sentinel tokens for masked '
+                            'spans in span corruption tasks.',
+                       dest='add_mask_tokens')
+    group.add_argument('--pack-samples', action='store_true',
+                       help='Whether to pack samples in span corruption '
+                            'datasets (T5 or UL2). GPT dataset is always packed.')
     group.add_argument('--mmap-warmup', action='store_true',
                        help='Warm up mmap files.')
     group.add_argument('--num-workers', type=int, default=2,
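To make `--pack-samples` concrete, here is a minimal sketch of the general packing idea, assuming `samples` are lists of token IDs; the PR's actual implementation additionally handles denoiser selection, prompt-token repetition, and full-document packing:

```
def pack_samples(samples, seq_length):
    """Greedily group samples so each group's total length fits seq_length."""
    packs, current, current_len = [], [], 0
    for sample in samples:
        # Start a new pack once the next sample would overflow the target.
        if current and current_len + len(sample) > seq_length:
            packs.append(current)
            current, current_len = [], 0
        current.append(sample)
        current_len += len(sample)
    if current:
        packs.append(current)
    return packs
```

Compared with padding every sample individually, packing wastes far fewer positions on padding tokens, which is why the GPT dataset always packs.
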
@@ -1024,6 +1043,64 @@ def _add_vit_args(parser):
     return parser


+def _add_ul2_args(parser):
+    group = parser.add_argument_group(title="UL2")
+
+    group.add_argument('--ul2-model-type', type=str, default='ED',
+                       choices=['ED', 'ND', 'CD'],
+                       help='What type of model to use for UL2 pretraining. '
+                            'ED = encoder-decoder; ND = non-causal decoder-only; '
+                            'CD = causal decoder-only')
+    group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float,
+                       default=None,
+                       help='Probability of each denoising objective to be '
+                            'selected. Uniform distribution by default.')
+    group.add_argument('--ul2-denoisers', nargs='+', type=str,
+                       default=['R', 'R', 'S', 'X', 'X', 'X', 'X'],
+                       choices=['R', 'S', 'X'],
+                       help='What type of UL2 denoising objective the other '
+                            'UL2 configurations refer to.')
+    group.add_argument('--ul2-mean-span-lengths', nargs='+', type=float,
+                       default=[3, 8, 0.25, 3, 8, 64, 64],
+                       help='Mean length for sampling span lengths. '
+                            'Numbers < 1 indicate a mean length of the sequence '
+                            'length times that number.')
+    group.add_argument('--ul2-mask-ratios', nargs='+', type=float,
+                       default=[0.15, 0.15, 0.25, 0.5, 0.5, 0.15, 0.5],
+                       help='Ratio of masked tokens in the full sequence.')
+    group.add_argument('--ul2-r-denoiser-token', type=str, default='[R]',
+                       help='What token to prepend for the UL2 R-denoising '
+                            'objective. If empty, do not prepend a token for this '
+                            'objective.')
+    group.add_argument('--ul2-s-denoiser-token', type=str, default='[S]',
+                       help='What token to prepend for the UL2 S-denoising '
+                            'objective. If empty, do not prepend a token for this '
+                            'objective.')
+    group.add_argument('--ul2-x-denoiser-token', type=str, default='[X]',
+                       help='What token to prepend for the UL2 X-denoising '
+                            'objective. If empty, do not prepend a token for this '
+                            'objective.')
+    group.add_argument('--ul2-scale-normal-std', action='store_true',
+                       help='Whether to scale the standard deviation when '
+                            'using a normal distribution for span length sampling.')
+    group.add_argument('--ul2-like-ul2r', action='store_true',
+                       help='Whether to use the updated implementation as '
+                            'described in the UL2R paper. This only changes the '
+                            'implementation, not the objective configurations!')
+    group.add_argument('--ul2-pack-any', action='store_true',
+                       help='When `--pack-samples` is also given, whether to '
+                            'pack different denoisers into one sample. If not '
+                            'given, the same denoiser is used for all packed '
+                            'samples.')
+    group.add_argument('--ul2-pack-no-repeat-prompt', action='store_false',
+                       help='When `--pack-samples` is also given and '
+                            '`--ul2-pack-any` is *not* given, whether to '
+                            'repeat the prompt token for each packed sample.',
+                       dest='ul2_pack_repeat_prompt')
+
+    return parser
+
+
 def _add_zero_args(parser):
     """Text generate arguments."""
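Pairing the UL2 defaults above positionally yields the usual seven-denoiser mixture. A small illustration of how the flags combine (`seq_length` is an assumed example value; this loop is illustrative, not part of the PR):

```
denoisers = ['R', 'R', 'S', 'X', 'X', 'X', 'X']
mean_span_lengths = [3, 8, 0.25, 3, 8, 64, 64]
mask_ratios = [0.15, 0.15, 0.25, 0.5, 0.5, 0.15, 0.5]

seq_length = 512  # assumed example value
for name, mean, ratio in zip(denoisers, mean_span_lengths, mask_ratios):
    # Per the help text, a mean < 1 means that fraction of the sequence
    # length, e.g. the S-denoiser's 0.25 -> mean prefix length of 128.
    mean_len = mean * seq_length if mean < 1 else mean
    print(f'{name}-denoiser: ~{ratio:.0%} of tokens masked, mean span {mean_len:g}')
```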