Tracker issue for adding LayerSkip to AO.
This is a training and inference optimization that is similar to layer-wise pruning. It's particularly interesting for LLM inference because it combines very cleanly with speculative decoding to provide up to a 1.86x speedup.
@mostafaelhoushi is interested in adding this to torchtune and in upstreaming a subset of the code to ao. See here for more details. In particular, he'd like to do this without having to alter the module definition.
This is attractive because this part of LayerSkip is not unique to LLMs and can be used for other models. (@mostafaelhoushi to fill out with relevant results).
What is being proposed:
For LayerSkip there is a training recipe and an inference recipe:
- Training recipe:
  - Layer dropout: skipping layers stochastically during training. This is what I think we can get into torch.ao, because it could benefit all types of models: transformers, CNNs, vision, text, etc. It can speed up training and might improve accuracy. (A rough sketch of how this could be applied without changing the module definition is included after this list.)
  - Early exit loss: this could also be added to torch.ao and could help different modalities, but may require more time. (A sketch is also included below.)
- Inference recipe:
  - Speculative decoding: yes, this applies to LLMs only, and has been added to gpt-fast here: pytorch-labs/gpt-fast@main...LayerSkip (a simplified sketch of the self-speculative loop is included below).
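
To make the layer dropout idea concrete, here is a minimal sketch of how it could be applied from outside the model, without touching the module definition. It assumes the model exposes its blocks as an `nn.ModuleList` (e.g. `model.layers`) and that each block maps hidden states to hidden states; `LayerDropout` and `apply_layer_dropout` are illustrative names, not an existing API.

```python
import torch
import torch.nn as nn


class LayerDropout(nn.Module):
    """Wraps a layer and skips it with probability `p` during training."""

    def __init__(self, layer: nn.Module, p: float):
        super().__init__()
        self.layer = layer
        self.p = p

    def forward(self, x, *args, **kwargs):
        # Only drop layers during training; inference always runs the full stack.
        if self.training and torch.rand(()) < self.p:
            return x  # identity: the residual stream passes through unchanged
        return self.layer(x, *args, **kwargs)


def apply_layer_dropout(layers: nn.ModuleList, max_p: float = 0.2) -> None:
    """Wrap each block in-place, with dropout probability increasing with depth
    (early layers are dropped rarely, later layers more often)."""
    n = len(layers)
    for i in range(n):
        p = max_p * i / max(n - 1, 1)
        layers[i] = LayerDropout(layers[i], p)
```

The linear-in-depth schedule above is only one possible choice; the point is that the wrapping happens entirely from the outside, so any model with a list of blocks could opt in.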
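And a sketch of what the early exit loss could look like during training, assuming per-layer hidden states have already been collected (e.g. with forward hooks), that the model's final norm and LM head (`model.norm`, `model.output` below) can be reused at each exit, and that labels are already aligned with the logits. The weighting scheme is just illustrative.

```python
import torch.nn.functional as F


def early_exit_loss(model, hidden_states, labels, scale: float = 1.0):
    """hidden_states: one [batch, seq, dim] tensor per transformer layer."""
    n = len(hidden_states)
    weights = [(i + 1) / n for i in range(n)]  # weight later exits more heavily
    total = 0.0
    for w, h in zip(weights, hidden_states):
        logits = model.output(model.norm(h))  # reuse the shared LM head at every exit
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100
        )
        total = total + w * loss
    return scale * total / sum(weights)
```

This would be added on top of the regular final-layer loss; in practice one would probably also subsample which layers get an exit loss to keep the overhead down.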
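Finally, for context on the inference side, a very rough, cache-free sketch of the self-speculative loop (the early layers draft, the full model verifies). `forward_early` and `forward_full` are assumed helpers returning logits, batch size 1 is assumed, and the actual gpt-fast implementation linked above differs (KV caches, reuse of the draft's early-layer computation, etc.).

```python
import torch


@torch.no_grad()
def self_speculative_generate(model, tokens, max_new_tokens, exit_layer, k=4):
    """Greedy self-speculative decoding: draft `k` tokens with the first
    `exit_layer` layers, then verify them with one full-model forward pass."""
    tokens = tokens.clone()
    remaining = max_new_tokens
    while remaining > 0:
        # 1) Draft up to k tokens autoregressively using only the early layers.
        n_draft = min(k, remaining)
        ctx = tokens
        drafted = []
        for _ in range(n_draft):
            logits = model.forward_early(ctx, exit_layer)  # assumed helper
            nxt = logits[:, -1].argmax(dim=-1, keepdim=True)
            drafted.append(nxt)
            ctx = torch.cat([ctx, nxt], dim=-1)
        draft = torch.cat(drafted, dim=-1)  # [1, n_draft]

        # 2) Verify the whole draft with a single full-model pass.
        full_logits = model.forward_full(torch.cat([tokens, draft], dim=-1))
        # Full-model greedy predictions at each draft position, plus one extra.
        verified = full_logits[:, tokens.size(1) - 1 :].argmax(dim=-1)

        # 3) Accept the longest prefix where draft and full model agree,
        #    then take one "free" token from the verification pass.
        agree = (verified[:, :n_draft] == draft).int().cumprod(dim=-1)
        n_accept = int(agree.sum().item())
        tokens = torch.cat(
            [tokens, draft[:, :n_accept], verified[:, n_accept : n_accept + 1]], dim=-1
        )
        remaining -= n_accept + 1
    return tokens
```

(This can overshoot `max_new_tokens` by a token; a real implementation trims the output and keeps KV caches so the early layers' work isn't repeated during verification.)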