[WIP] Fine tune with LayerSkip #20
Conversation
Thanks @ariG23498 ! Looks great and I like the breakdown.
I have made some comments and suggestions.
fine-tune.py
Outdated
super().__init__(*args, **kwargs)
self.model = model
self.num_layers = model.config.num_hidden_layers
self.early_exit_layer = 0
We don't have to do it in this PR, but in the future `self.early_exit_layer` could be a list of layers.
This is an interesting idea. If I am understanding this correctly, you mean we could have a list of indices for the layers we want to exit early, e.g. `self.early_exit_layers = [0, 4, 8]`?
Yes. That is actually what we referred to as "rotational curriculum" in the paper.
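A minimal sketch of what that could look like (hypothetical class and attribute names, not this PR's code): keep a list of exit-layer indices and rotate through them, one index per training step, which is roughly the rotational-curriculum idea:

```python
import itertools

class LayerSkipTrainerSketch:  # hypothetical name, for illustration only
    def __init__(self, model, early_exit_layers=(0, 4, 8)):
        self.model = model
        self.num_layers = model.config.num_hidden_layers
        # Rotate through the configured exit layers, one index per training step.
        self._exit_cycle = itertools.cycle(early_exit_layers)

    def next_exit_layer(self):
        """Return the exit layer to supervise on the current step."""
        return next(self._exit_cycle)
```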
I can create an issue and then do it in another PR if you want.
Other things we might need:
- add a line to save the checkpoint to a file (the path could be a CLI argument as well; see the sketch below)
- update the README with an example command to train a model
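For the checkpoint item, a minimal sketch. It assumes `model` and `tokenizer` are the HuggingFace model and tokenizer already built in the training script, and uses the `--output_dir` flag that the README command already exposes:

```python
import argparse

parser = argparse.ArgumentParser()
# Matches the --output_dir flag shown in the README command below.
parser.add_argument("--output_dir", default="./checkpoints")
args = parser.parse_args()

# `model` and `tokenizer` stand for the fine-tuned HF model and its tokenizer
# created elsewhere in the training script.
model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
```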
Leads to:

```
### Instruction: Set alarm for 6am every day
### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Orig Time: 0.8662526607513428

From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
```

The same Instruction/Response output is generated at every exit layer; the per-layer timings are:

| Layer | Layerskip Time (s) | Speedup |
| --- | --- | --- |
| 1 | 0.5897870063781738 | 1.4687550783305965 |
| 2 | 0.35465383529663086 | 2.4425300801464984 |
| 3 | 0.36078310012817383 | 2.4010344731878877 |
| 4 | 0.3151836395263672 | 2.748406173788329 |
| 5 | 0.3399159908294678 | 2.5484316246420207 |
| 6 | 0.3960268497467041 | 2.1873584109395403 |
| 7 | 0.38035011291503906 | 2.2775138782326128 |
| 8 | 0.4112529754638672 | 2.106374208658952 |
| 9 | 0.3891477584838867 | 2.226025055691568 |
| 10 | 0.41824865341186523 | 2.0711427369457924 |
| 11 | 0.43691492080688477 | 1.9826575369675328 |
| 12 | 0.44853711128234863 | 1.931284254885316 |
| 13 | 0.45146870613098145 | 1.9187435341310743 |
| 14 | 0.4784407615661621 | 1.8105745378292801 |
| 15 | 0.4980161190032959 | 1.7394068739883695 |
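A quick check that the reported speedup is just the ratio of the two wall-clock times (values taken from the log above, exit layer 4):

```python
orig_time = 0.8662526607513428       # full-model generation time
layerskip_time = 0.3151836395263672  # generation time when exiting at layer 4
print(orig_time / layerskip_time)    # ~2.7484, matching the reported speedup
```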
README.md
Outdated
## Fine Tune
To train any supported HuggingFace model with the LayerSkip approach:
```bash
torchrun finetune_layerskip.py \
    --ckpt facebook/llama2-7B \
    --ds_ckpt some_dataset \
    --template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
    --lr 1e-4 \
    --batch_size 8 \
    --epochs 3 \
    --early_exit_loss_scale 1.0 \
    --eval_freq 50 \
    --output_dir ./checkpoints
```
Updating here the command to `train.py` (see the suggested change below). But can you also check that the command works? We could put a command that users can copy and paste into their command line (e.g., use the TopV2 dataset rather than `some_dataset`).
## Train
To train any supported HuggingFace model with the LayerSkip approach:
```bash
torchrun train.py \
    --ckpt meta-llama/Llama-2-7b-hf \
    --ds_ckpt some_dataset \
    --template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
    --lr 1e-4 \
    --batch_size 8 \
    --epochs 3 \
    --early_exit_loss_scale 1.0 \
    --eval_freq 50 \
    --output_dir ./checkpoints
```
I will check and update.
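As an aside on the `--template` argument above: a minimal sketch of how such a format string fills a TopV2-style record (illustrative field values; the exact template and column names used by the script may differ):

```python
template = "###INST: {utterance}\n\n###RES: {semantic_parse}"
row = {
    "utterance": "Set alarm for 6am every day",
    "semantic_parse": "[IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]",
}
print(template.format(**row))
```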
This PR is aimed at adding a fine-tuning script with LayerSkip.
@mostafaelhoushi would you like to review the current state and let me know what you think of it, and what is needed in the next iteration?
(Note: This is a WIP and does not yet support a lot of the goodies required for training efficiently.)
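For reviewers skimming the thread, a minimal, simplified sketch of the training objective the script is built around: supervise an intermediate layer through the shared LM head in addition to the final layer, scaled by `--early_exit_loss_scale`. The attribute paths (`model.model.norm`, `model.lm_head`) assume a Llama-style HuggingFace causal LM; this is an illustration, not this PR's actual code.

```python
import torch.nn.functional as F

def layerskip_loss(model, input_ids, labels, exit_layer, early_exit_loss_scale=1.0):
    # Standard final-layer loss, plus hidden states from every layer.
    outputs = model(input_ids=input_ids, labels=labels, output_hidden_states=True)
    final_loss = outputs.loss

    # Re-use the shared LM head on the chosen intermediate layer's hidden states.
    early_hidden = outputs.hidden_states[exit_layer]
    early_logits = model.lm_head(model.model.norm(early_hidden))

    # Shift so each position predicts the next token, as in causal LM training.
    shift_logits = early_logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    early_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )
    return final_loss + early_exit_loss_scale * early_loss
```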