
[WIP] Fine tune with LayerSkip #20

Open · wants to merge 8 commits into main

Conversation

ariG23498

This PR aims to add a fine-tuning script with LayerSkip.

@mostafaelhoushi would you like to review the current state and let me know what you think of it, and what you would need in the next iteration?

(Note: this is a WIP and does not yet support many of the features required for efficient training.)

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 17, 2025
Collaborator

@mostafaelhoushi left a comment

Thanks @ariG23498 ! Looks great and I like the breakdown.
I have made some comments and suggestions.

fine-tune.py (Outdated)

```python
super().__init__(*args, **kwargs)
self.model = model
self.num_layers = model.config.num_hidden_layers
self.early_exit_layer = 0
```
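For reference, a minimal sketch of how `self.early_exit_layer` might feed an early-exit loss during training. It assumes a Llama-style Hugging Face causal LM (exposing `model.lm_head` and `model.model.norm`), labels padded with the `-100` ignore index, and a hypothetical helper name `layerskip_loss`; this is illustrative, not the PR's actual implementation:

```python
import torch.nn.functional as F


def layerskip_loss(model, input_ids, attention_mask, labels,
                   early_exit_layer=0, early_exit_loss_scale=1.0):
    """Sketch: final-layer loss plus an early-exit loss from one intermediate layer."""
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        output_hidden_states=True,
    )
    final_loss = outputs.loss  # causal-LM loss computed on the last layer

    # hidden_states[0] is the embedding output, so transformer layer i is index i + 1.
    hidden = outputs.hidden_states[early_exit_layer + 1]
    hidden = model.model.norm(hidden)      # assumes a Llama-style final norm
    early_logits = model.lm_head(hidden)   # reuse the shared LM head for the early exit

    # Shift logits/labels the same way the causal LM does internally.
    shift_logits = early_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    early_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )

    return final_loss + early_exit_loss_scale * early_loss
```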
Collaborator


We don't have to do it in this PR, but in the future, self.early_exit_layer could be a list of layers

Author


This is an interesting idea. If I am understanding this correctly, you mean we could have a list of indices for the layers we want to exit from early, e.g. self.early_exit_layers = [0, 4, 8]?

Collaborator


Yes. That is actually what we referred to as "rotational curriculum" in the paper.

Author


I can create an issue and then do it in another PR if you want.
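If the rotational curriculum does land in a follow-up PR, its core could be as small as cycling through the list of exit layers per training step. A hedged sketch, reusing the hypothetical `layerskip_loss` above (the layer indices, `dataloader`, `model`, and `optimizer` are placeholders):

```python
# Hypothetical rotational curriculum: rotate the early-exit layer every step.
early_exit_layers = [0, 4, 8]  # candidate exit layers, as discussed above

for step, batch in enumerate(dataloader):
    early_exit_layer = early_exit_layers[step % len(early_exit_layers)]
    loss = layerskip_loss(model, **batch, early_exit_layer=early_exit_layer)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```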

Collaborator

@mostafaelhoushi left a comment

Other things we might need:

  • add a line to save checkpoint to file (the path could be a CLI argument as well)
  • update README with example command to train a model
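For the first bullet, a minimal sketch of checkpoint saving with the path taken from a CLI argument might look like this (the --output_dir flag and save_checkpoint helper are assumptions, not the script's actual interface):

```python
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Directory to write model checkpoints to.")
    return parser.parse_args()


def save_checkpoint(model, tokenizer, output_dir):
    # Hugging Face models and tokenizers can serialize themselves to a directory.
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


# Hypothetical usage inside the training script:
# args = parse_args()
# ... training loop ...
# save_checkpoint(model, tokenizer, args.output_dir)
```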

@ariG23498 marked this pull request as ready for review on January 20, 2025, 05:05
ariG23498 and others added 5 commits January 20, 2025 21:34
1. added ignore index in the loss function
2. inputs and labels are the same, since the shift is handled inside the causal model's loss function
Leads to:
```
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Orig Time: 0.8662526607513428
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.5897870063781738
For Layer: 1 Speedup: 1.4687550783305965
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.35465383529663086
For Layer: 2 Speedup: 2.4425300801464984
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.36078310012817383
For Layer: 3 Speedup: 2.4010344731878877
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.3151836395263672
For Layer: 4 Speedup: 2.748406173788329
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.3399159908294678
For Layer: 5 Speedup: 2.5484316246420207
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.3960268497467041
For Layer: 6 Speedup: 2.1873584109395403
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.38035011291503906
For Layer: 7 Speedup: 2.2775138782326128
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.4112529754638672
For Layer: 8 Speedup: 2.106374208658952
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.3891477584838867
For Layer: 9 Speedup: 2.226025055691568
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.41824865341186523
For Layer: 10 Speedup: 2.0711427369457924
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.43691492080688477
For Layer: 11 Speedup: 1.9826575369675328
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.44853711128234863
For Layer: 12 Speedup: 1.931284254885316
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.45146870613098145
For Layer: 13 Speedup: 1.9187435341310743
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.4784407615661621
For Layer: 14 Speedup: 1.8105745378292801
### Instruction: Set alarm for 6am every day
 ### Response:  [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Layerskip Time: 0.4980161190032959
For Layer: 15 Speedup: 1.7394068739883695
```
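To illustrate the first two commit notes (ignore index in the loss, labels identical to the inputs), a minimal collation sketch might look like the following; the function name and padding details are assumptions, and the key point is that Hugging Face causal LMs shift labels internally and skip positions marked with -100:

```python
def build_batch(texts, tokenizer, max_length=512):
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Labels are the inputs themselves: the causal LM shifts them internally
    # when computing the loss, so no manual offset is needed here.
    labels = enc["input_ids"].clone()
    # Padding positions are excluded from the loss via the ignore index (-100).
    labels[enc["attention_mask"] == 0] = -100
    enc["labels"] = labels
    return enc
```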
README.md Outdated
Comment on lines 48 to 62
## Fine Tune

To train any supported HuggingFace model with the LayerSkip approach:
```bash
torchrun finetune_layerskip.py \
--ckpt facebook/llama2-7B \
--ds_ckpt some_dataset \
--template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
--lr 1e-4 \
--batch_size 8 \
--epochs 3 \
--early_exit_loss_scale 1.0 \
--eval_freq 50 \
--output_dir ./checkpoints
```
Collaborator

Updating the command here to train.py. But can you also check that the command works? It would be good to have a command that users can copy and paste into their terminal (e.g., use the TopV2 dataset rather than some_dataset).

Suggested change
## Fine Tune
To train any supported HuggingFace model with the LayerSkip approach:
```bash
torchrun finetune_layerskip.py \
--ckpt facebook/llama2-7B \
--ds_ckpt some_dataset \
--template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
--lr 1e-4 \
--batch_size 8 \
--epochs 3 \
--early_exit_loss_scale 1.0 \
--eval_freq 50 \
--output_dir ./checkpoints
```
## Train
To train any supported HuggingFace model with the LayerSkip approach:
```bash
torchrun train.py \
--ckpt meta-llama/Llama-2-7b-hf \
--ds_ckpt some_dataset \
--template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
--lr 1e-4 \
--batch_size 8 \
--epochs 3 \
--early_exit_loss_scale 1.0 \
--eval_freq 50 \
--output_dir ./checkpoints
```

Author

I will check and update.
