
[RFC] Add LayerSkip to AO #633

Open

@jcaip

Description

Tracker issue for adding LayerSkip to AO.

This is a training and inference optimization that is similar to layer-wise pruning. It's particularly interesting for LLM inference because it combines very cleanly with speculative decoding to provide up to a 1.86x speedup.

@mostafaelhoushi is interested in adding this to torchtune and in upstreaming a subset of the code to ao. See here for more details. In particular, he's interested in doing this without having to alter the module definition.

This is attractive because this part of LayerSkip is not unique to LLMs and can be used for other models. (@mostafaelhoushi to fill out with relevant results).

What is being proposed:

For LayerSkip there is a training recipe and an inference recipe:

  • training recipe:
    • layer dropout: skipping layers stochastically during training. This is the part I think we can get into torch.ao, because it could benefit all types of models: transformers, CNNs, vision, text, etc. It can speed up training and may improve accuracy. See the first sketch after this list.
    • early exit loss: this could also be added to torch.ao and could help different modalities, but it may require more time. See the second sketch after this list.
  • inference recipe:
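
A minimal sketch of what layer dropout could look like without altering the module definition (the constraint mentioned above): patch each layer's `forward` at the instance level so it acts as the identity with some probability during training. The names (`layers`, `max_skip_prob`), the per-batch coin flip, and the linear depth schedule are illustrative assumptions, not the exact LayerSkip or torchtune API.

```python
import torch
import torch.nn as nn


def apply_layer_dropout(layers: nn.ModuleList, max_skip_prob: float = 0.2) -> None:
    """Patch each layer so it is skipped (identity) with some probability while
    training. Deeper layers are skipped more often (linear schedule)."""
    num_layers = len(layers)
    for idx, layer in enumerate(layers):
        # Linear schedule: the first layer is never skipped, the last layer is
        # skipped with probability `max_skip_prob`.
        p_skip = max_skip_prob * idx / max(num_layers - 1, 1)
        original_forward = layer.forward

        def patched_forward(x, *args, _orig=original_forward, _p=p_skip, _mod=layer, **kwargs):
            # Only skip stochastically in training mode; at eval time every layer runs.
            # Assumes a residual-style block that takes and returns a tensor of the
            # same shape, so passing the input through unchanged is valid.
            if _mod.training and torch.rand(()).item() < _p:
                return x
            return _orig(x, *args, **kwargs)

        layer.forward = patched_forward


# Usage: any model exposing a ModuleList of residual blocks, e.g. transformer
# decoder layers or CNN residual blocks.
# apply_layer_dropout(model.layers, max_skip_prob=0.2)
```

Because this only touches module instances, the same helper could apply to transformers, CNNs, or other architectures with residual blocks, which is the reason it seems like a good fit for ao.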
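A minimal sketch of the early exit loss idea: capture intermediate hidden states with forward hooks, reuse the final head on them, and add a weighted auxiliary cross-entropy to the main loss. `model.layers`, `model.norm`, `model.lm_head`, the assumption that `model(input_ids)` returns logits, and the exit weights are all illustrative assumptions, not the exact LayerSkip recipe.

```python
import torch
import torch.nn.functional as F


def early_exit_loss(model, input_ids, labels, exit_every: int = 4):
    """Next-token loss at the final layer plus weighted losses at intermediate exits."""
    hidden_states = []

    def capture(_module, _inputs, output):
        # Some layer implementations return tuples; keep only the hidden state.
        hidden_states.append(output[0] if isinstance(output, tuple) else output)

    handles = [
        layer.register_forward_hook(capture)
        for i, layer in enumerate(model.layers)
        if (i + 1) % exit_every == 0
    ]
    try:
        final_logits = model(input_ids)  # one normal forward pass
    finally:
        for h in handles:
            h.remove()

    def ce(logits):
        # Standard next-token cross-entropy (shift logits vs. labels by one).
        return F.cross_entropy(logits[:, :-1].flatten(0, 1), labels[:, 1:].flatten())

    loss = ce(final_logits)
    # Reuse the final norm + head on each captured hidden state; weight earlier
    # exits less than the final output (illustrative schedule).
    for depth, h_state in enumerate(hidden_states, start=1):
        exit_logits = model.lm_head(model.norm(h_state))
        loss = loss + (depth / len(hidden_states)) * 0.1 * ce(exit_logits)
    return loss
```

This is the piece that trains intermediate layers to produce usable predictions, which is what the inference recipe (early exit plus speculative decoding, as mentioned above) builds on.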
