
Commit 744d1f2

committed
Add part 2 of end-to-end tutorial: fine-tuning
This commit adds the QAT tutorial and the general structure for the fine-tuning tutorial, which will also include QLoRA and float8 quantized fine-tuning. It also connects the 3 tutorial parts (pre-training, fine-tuning, and serving) into one cohesive end-to-end flow with some visuals and text.
1 parent 63a91d7 commit 744d1f2

File tree

8 files changed, +334 -18 lines changed


docs/source/finetuning.rst

Lines changed: 289 additions & 0 deletions
@@ -0,0 +1,289 @@
(Part 2) Fine-tuning with QAT, QLoRA, and float8
------------------------------------------------

TorchAO provides an end-to-end pre-training, fine-tuning, and serving
model optimization flow by leveraging our quantization and sparsity
techniques integrated into our partner frameworks. This is part 2 of 3
such tutorials showcasing this end-to-end flow, focusing on the
fine-tuning step.

.. image:: ../static/e2e_flow_part2.png

Fine-tuning is an important step for adapting your pre-trained model
to more domain-specific data. In this tutorial, we demonstrate 3 model
optimization techniques that can be applied to your model during fine-tuning:

1. **Quantization-Aware Training (QAT)**, for adapting your model to
   quantization numerics during fine-tuning, with the goal of mitigating
   quantization degradations in your fine-tuned model when it is quantized
   eventually, e.g. in the serving step.

2. **Quantized Low-Rank Adaptation (QLoRA)**, for reducing the resource
   requirement of fine-tuning by introducing small, trainable low-rank
   matrices and freezing the original pre-trained checkpoint, a type of
   Parameter-Efficient Fine-Tuning (PEFT). See the conceptual sketch after
   this list for the low-rank idea.

3. **Float8 Quantized Fine-tuning**, for speeding up fine-tuning by
   dynamically quantizing high precision weights and activations to float8,
   similar to `pre-training in float8 <pretraining.html>`__.

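The low-rank idea behind (Q)LoRA in (2) above can be sketched in a few lines.
The snippet below is illustrative only and is not TorchAO's QLoRA implementation;
in particular, it omits the 4-bit quantization of the frozen base weight that
puts the "Q" in QLoRA:

.. code:: py

    import torch
    import torch.nn as nn

    class LoRALinear(nn.Module):
        """Frozen base linear layer plus a small trainable low-rank update."""

        def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
            super().__init__()
            self.base = base
            self.base.weight.requires_grad_(False)  # freeze the pre-trained weight
            self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
            self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
            self.scaling = alpha / rank

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # frozen base projection + small trainable low-rank update;
            # in QLoRA, self.base.weight would additionally be stored in a
            # quantized format (e.g. 4-bit), and only lora_a/lora_b are trained
            return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
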
Quantization-Aware Training (QAT)
##################################

The goal of Quantization-Aware Training is to adapt the model to
quantization numerics during training or fine-tuning, so as to mitigate
the inevitable quantization degradation when the model is eventually
quantized, typically during the serving step after fine-tuning.
TorchAO's QAT support has been used successfully for the recent release of
the `Llama-3.2 quantized 1B/3B <https://ai.meta.com/blog/meta-llama-quantized-lightweight-models/>`__
and the `LlamaGuard-3-8B <https://github.com/meta-llama/PurpleLlama/blob/main/Llama-Guard3/8B/MODEL_CARD.md>`__ models to improve the quality of the quantized models.

TorchAO's QAT support involves two separate steps: prepare and convert.
The prepare step "fake" quantizes activations and/or weights during
training, which means the high precision values (e.g. bf16) are mapped
to their corresponding quantized values *without* actually casting them
to the target lower precision dtype (e.g. int4). The convert step,
applied after training, replaces the "fake" quantization operations in the
model with "real" quantization that does perform the dtype casting:

.. image:: ../../torchao/quantization/qat/images/qat_diagram.png

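To make "fake" quantization concrete, here is a minimal sketch (illustrative
only, not TorchAO's internal implementation) of a symmetric int4
quantize-dequantize round trip that stays entirely in float precision:

.. code:: py

    import torch

    w = torch.randn(8, 8)                             # high precision weight (e.g. fp32/bf16)
    scale = w.abs().amax() / 7                        # map max magnitude into int4's [-8, 7]
    w_q = torch.clamp(torch.round(w / scale), -8, 7)  # "quantize" (values stay in float dtype)
    w_fq = w_q * scale                                # "dequantize" back to the original scale

    print(w_fq.dtype)             # still torch.float32: nothing was cast to int4
    print(w_fq.unique().numel())  # but at most 16 distinct values remain per tensor
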
There are multiple options for using TorchAO's QAT for fine-tuning:

1. Directly use our QAT APIs with your own training loop
2. Use our integration with `TorchTune <https://github.com/pytorch/torchtune>`__
3. Use our integration with `Axolotl <https://github.com/axolotl-ai-cloud/axolotl>`__


Option 1: TorchAO QAT API
=========================

First, set up the model for fine-tuning on a single GPU:

.. code:: py

    import torch
    from torchtune.models.llama3 import llama3

    # Set up smaller version of llama3 to fit in a single GPU
    def get_model():
        return llama3(
            vocab_size=4096,
            num_layers=16,
            num_heads=16,
            num_kv_heads=4,
            embed_dim=2048,
            max_seq_len=2048,
        ).cuda()

    # Example training loop
    def train_loop(m: torch.nn.Module):
        optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
        loss_fn = torch.nn.CrossEntropyLoss()
        for i in range(10):
            example = torch.randint(0, 4096, (2, 16)).cuda()
            target = torch.randn((2, 16, 4096)).cuda()
            output = m(example)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

Next, run the prepare step, which fake quantizes the model. In this example,
we use int8 per token dynamic activations and int4 symmetric per group weights
as our quantization scheme. Note that although we are targeting lower integer
precisions, training still performs arithmetic in higher float precision (float32)
because we are not actually casting the fake quantized values.

.. code:: py

    from torchao.quantization import (
        quantize_,
    )
    from torchao.quantization.qat import (
        FakeQuantizeConfig,
        IntXQuantizationAwareTrainingConfig,
    )
    model = get_model()

    # prepare: insert fake quantization ops
    # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
    quantize_(model, qat_config)

    # fine-tune
    train_loop(model)

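If you want to verify that the prepare step swapped the linear layers, a quick
check along these lines can help (a sketch; it assumes ``FakeQuantizedLinear``
is importable from ``torchao.quantization.qat.linear``):

.. code:: py

    from torchao.quantization.qat.linear import FakeQuantizedLinear

    num_fake_quantized = sum(
        1 for m in model.modules() if isinstance(m, FakeQuantizedLinear)
    )
    print(f"{num_fake_quantized} linear layers are fake quantized")
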
After fine-tuning, we end up with a model in the original high precision.
This fine-tuned model has the exact same structure as the original model.
The only difference is the QAT fine-tuned model has weights that are more
attuned to quantization, which will be beneficial later during inference.
The next step is to actually quantize the model:

.. code:: py

    from torchao.quantization import (
        Int8DynamicActivationInt4WeightConfig,
    )
    from torchao.quantization.qat import (
        FromIntXQuantizationAwareTrainingConfig,
    )

    # convert: transform fake quantization ops into actual quantized ops
    # swaps `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
    # quantized activation and weight tensor subclasses
    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

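Before handing the model off to the serving step, you can optionally run a quick
smoke test (a sketch, not part of the official flow): after convert, the linear
weights are quantized tensor subclasses and the model runs inference in eager mode.

.. code:: py

    first_linear = next(m for m in model.modules() if isinstance(m, torch.nn.Linear))
    print(type(first_linear.weight))  # expect a torchao quantized tensor subclass

    sample = torch.randint(0, 4096, (2, 16)).cuda()
    with torch.no_grad():
        logits = model(sample)
    print(logits.shape)               # (2, 16, 4096) for the toy model above
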
Now our model is ready for serving, and will typically have higher quantized
accuracy than if we did not apply the prepare step (fake quantization) during
fine-tuning. For example, when fine-tuning Llama-3.2-3B on the
`OpenAssistant Conversations (OASST1) <https://huggingface.co/datasets/OpenAssistant/oasst1>`__
dataset, we find that the quantized model achieved 3.4% higher accuracy
with QAT than without, recovering 69.8% of the overall accuracy degradation
from quantization:

.. image:: ../static/qat_eval.png

For full details of using TorchAO's QAT API, please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md>`__.

.. raw:: html

   <details>
   <summary><a>Alternative Legacy API</a></summary>

The above ``quantize_`` API is the recommended flow for using TorchAO QAT.
We also offer an alternative, legacy "quantizer" API for specific quantization
schemes. Unlike the config-based example above, these quantizers are not customizable.

.. code::

    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
    qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)

    # prepare: insert fake quantization ops
    # swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
    model = qat_quantizer.prepare(model)

    # train
    train_loop(model)

    # convert: transform fake quantization ops into actual quantized ops
    # swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
    model = qat_quantizer.convert(model)

.. raw:: html

   </details>

Option 2: TorchTune QAT Integration
===================================

TorchAO's QAT support is integrated into TorchTune's distributed fine-tuning recipe.
Instead of the following command, which applies full distributed fine-tuning without QAT:

.. code::

    tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3_2/3B_full \
        epochs=1 \
        batch_size=16 \
        dataset._component_=torchtune.datasets.alpaca_cleaned_dataset

Users can run the following command instead to fine-tune with QAT. Note that specifying
the quantizer is optional:

.. code::

    tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full \
        epochs=1 \
        batch_size=16 \
        dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
        quantizer.groupsize=32

After fine-tuning, users can quantize and evaluate the resulting model as follows.
This is the same whether or not QAT was used during the fine-tuning process:

.. code::

    tune run quantize --config quantization \
        model._component_=torchtune.models.llama3_2.llama3_2_3b \
        checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
        'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
        checkpointer.model_type=LLAMA3 \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
        quantizer.groupsize=32

    tune run eleuther_eval --config eleuther_evaluation \
        batch_size=1 \
        'tasks=[hellaswag, wikitext]' \
        model._component_=torchtune.models.llama3_2.llama3_2_3b \
        checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
        'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
        checkpointer.model_type=LLAMA3 \
        tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
        tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
        quantizer.groupsize=32

The evaluation step should print something like the following:

.. code::

    | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
    |---------|------:|------|------|--------|---|-----:|---|-----:|
    |hellaswag| 1|none |None |acc |↑ |0.5021|± |0.0050|
    | | |none |None |acc_norm|↑ |0.6797|± |0.0047|

    | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
    |--------|------:|------|------|---------------|---|------:|---|------|
    |wikitext| 2|none |None |bits_per_byte |↓ | 0.6965|± | N/A|
    | | |none |None |byte_perplexity|↓ | 1.6206|± | N/A|
    | | |none |None |word_perplexity|↓ |13.2199|± | N/A|

You can compare these values with and without QAT to see how much QAT helped mitigate quantization degradation!

In addition to vanilla QAT as in the above example, TorchAO's QAT can also be composed with LoRA to yield a `1.89x training speedup <https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700>`__. This is implemented in TorchTune's `QAT + LoRA fine-tuning recipe <https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py>`__, which can be run using the following command:

.. code::

    tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora \
        epochs=1 \
        batch_size=16 \
        dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
        quantizer.groupsize=32

For more details about how QAT is set up in TorchTune, please refer to `this tutorial <https://docs.pytorch.org/torchtune/main/tutorials/qat_finetune.html>`__.

Option 3: Axolotl QAT Integration
=================================

Axolotl also recently added a QAT fine-tuning recipe that leverages TorchAO's QAT support.
To get started, try fine-tuning Llama-3.2-3B with QAT using the following command:

.. code::

    axolotl train examples/llama-3/3b-qat-fsdp2.yaml
    # once training is complete, perform the quantization step

    axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
    # you should now have a quantized model saved in ./outputs/qat_out/quantized

Please refer to the `Axolotl QAT documentation <https://docs.axolotl.ai/docs/qat.html>`__ for full details.

Quantized Low-Rank Adaptation (QLoRA)
#####################################

(Coming soon!)


Float8 Quantized Fine-tuning
############################

(Coming soon!)
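
The rough shape of this flow mirrors `pre-training in float8 <pretraining.html>`__.
As a minimal sketch (assuming a float8-capable GPU such as an H100, and reusing the
toy ``get_model()`` and ``train_loop()`` from the QAT example above):

.. code:: py

    import torch
    from torchao.float8 import convert_to_float8_training

    model = get_model()

    # swap `torch.nn.Linear` layers for float8 training linears that dynamically
    # cast weights and activations to float8 during the forward/backward pass
    convert_to_float8_training(model)

    # torch.compile is needed to realize the float8 speedups
    model = torch.compile(model)

    train_loop(model)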

docs/source/index.rst

Lines changed: 5 additions & 3 deletions
@@ -37,9 +37,11 @@ for an overall introduction to the library and recent highlight and updates.
     :maxdepth: 1
     :caption: Tutorials
 
+    pretraining
+    finetuning
+    serving
+    torchao_vllm_integration
     serialization
+    static_quantization
     subclass_basic
     subclass_advanced
-    static_quantization
-    pretraining
-    torchao_vllm_integration

docs/source/pretraining.rst

Lines changed: 28 additions & 15 deletions
@@ -1,21 +1,29 @@
-Pretraining with float8
+(Part 1) Pre-training with float8
 ---------------------------------
 
-Pretraining with float8 using torchao can provide `up to 1.5x speedups <https://pytorch.org/blog/training-using-float8-fsdp2/>`__ on 512 GPU clusters,
+TorchAO provides an end-to-end pre-training, fine-tuning, and serving
+model optimization flow by leveraging our quantization and sparsity
+techniques integrated into our partner frameworks. This is part 1 of 3
+such tutorials showcasing this end-to-end flow, focusing on the
+pre-training step.
+
+.. image:: ../static/e2e_flow_part1.png
+
+Pre-training with float8 using torchao can provide `up to 1.5x speedups <https://pytorch.org/blog/training-using-float8-fsdp2/>`__ on 512 GPU clusters,
 and up to `1.34-1.43x speedups <https://pytorch.org/blog/accelerating-large-scale-training-and-convergence-with-pytorch-float8-rowwise-on-crusoe-2k-h200s/>`__ on 2K H200 clusters with the latest `torchao.float8` rowwise recipe.
 
-In this tutorial, we will show 2 ways to use the **torchao.float8** recipes for pretraining:
+In this tutorial, we will show 2 ways to use the **torchao.float8** recipes for pre-training:
 
-1. :ref:`Pretraining with torchtitan`, the offical PyTorch pretraining framework with native torchao integration.
-2. :ref:`Pretraining with torchao directly`, to integrate torchao's float8 training recipes into your own pretraining code.
+1. :ref:`Pre-training with torchtitan`, the official PyTorch pre-training framework with native torchao integration.
+2. :ref:`Pre-training with torchao directly`, to integrate torchao's float8 training recipes into your own pre-training code.
 
 
-Pretraining with torchtitan
+Pre-training with torchtitan
 ###########################
 
-In this tutorial we'll pretrain Llama3 8b using torchtitan with torchao's float8 training recipes: rowwise scaling and tensorwise scaling.
+In this tutorial we'll pre-train Llama3-8B using torchtitan with torchao's float8 training recipes: rowwise scaling and tensorwise scaling.
 
-`Torchtitan <https://github.com/pytorch/torchtitan/>`__ is PyTorch's official pretraining framework that is natively integrated with torchao, and supports
+`Torchtitan <https://github.com/pytorch/torchtitan/>`__ is PyTorch's official pre-training framework that is natively integrated with torchao, and supports
 several popular flagship models with common forms of parallelism, float8 training, distributed checkpointing and more.
 See the torchtitan `docs <https://github.com/pytorch/torchtitan>`__ for additional details.
 
@@ -29,12 +37,12 @@ Prerequisites
 2. `Install torchao <https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation>`__.
 3. `Install torchtitan <https://github.com/pytorch/torchtitan/tree/main?tab=readme-ov-file#installation>`__, including the "downloading a tokenizer" step.
 
-You're now ready to start a pretraining job using one of the recipes below!
+You're now ready to start a pre-training job using one of the recipes below!
 
 Rowwise scaling
 ===============
 
-Run the following command from torchtitan root directory to launch a Llama3 8b training job on 8 GPUs with float8 rowwise training:
+Run the following command from the torchtitan root directory to launch a Llama3-8B training job on 8 GPUs with float8 rowwise training:
 
 .. code:: console
 
@@ -104,10 +112,10 @@ Picking a recipe
 The higher throughput of tensorwise scaling comes at the cost of slightly higher quantization error (i.e., reduced numerical integrity vs bfloat16) compared to rowwise scaling.
 This is because rowwise scaling uses a more granular scaling factor (per row, instead of per tensor), which limits the impact of outliers that can cause underflow during scaling.
 
-Below you can see the loss curves comparing bfloat16, float8 tensorwise, and float8 rowwise training for training Llama3 8b on 8xH100 GPUs:
+Below you can see the loss curves comparing bfloat16, float8 tensorwise, and float8 rowwise training for training Llama3-8B on 8xH100 GPUs:
 
 .. image:: ../static/fp8-loss-curves.png
-   :alt: Loss curves for training Llama3 8b on 8xH100s with torchtitan using bfloat16, float8 tensorwise, and float8 rowwise training.
+   :alt: Loss curves for training Llama3-8B on 8xH100s with torchtitan using bfloat16, float8 tensorwise, and float8 rowwise training.
 
 
 Important notes
@@ -117,12 +125,12 @@ Important notes
 * You must use :code:`--training.compile` to achieve high performance. torchao float8 training recipes are built natively on top of :code:`torch.compile`, so it will work out of the box!
 
 
-Pretraining with torchao directly
+Pre-training with torchao directly
 #################################
 
-In this tutorial we'll pretrain a toy model using torchao APIs directly.
+In this tutorial we'll pre-train a toy model using torchao APIs directly.
 
-You can use this workflow to integrate torchao into your own custom pretraining code directly.
+You can use this workflow to integrate torchao into your own custom pre-training code directly.
 
 Prerequisites
 ================
@@ -200,3 +208,8 @@ Below is a code snippet showing how to use it:
         'model_state_dict': m.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
     }, 'checkpoint.pth')
+
+
+After pre-training your model, you can optionally fine-tune it on more domain-specific datasets
+and adapt it for eventual quantization during serving. In the `next part <finetuning.html>`__ of
+this tutorial, we will explore a few model optimization options during the fine-tuning step.
