
Commit 7bcf5a6

Add part 2 of end-to-end tutorial: fine-tuning
This commit adds the QAT tutorial and the general structure for the fine-tuning tutorial, which will also cover QLoRA and float8 quantized fine-tuning. It also connects the 3 tutorial parts (pre-training, fine-tuning, and serving) into one cohesive end-to-end flow with some visuals and text.
1 parent 6a8887f commit 7bcf5a6

File tree

9 files changed: +339 -19 lines changed

docs/source/finetuning.rst

Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@

(Part 2) Fine-tuning with QAT, QLoRA, and float8
------------------------------------------------

TorchAO provides an end-to-end pre-training, fine-tuning, and serving
model optimization flow by leveraging our quantization and sparsity
techniques integrated into our partner frameworks. This is part 2 of 3
such tutorials showcasing this end-to-end flow, focusing on the
fine-tuning step.

.. image:: ../static/e2e_flow_part2.png

Fine-tuning is an important step for adapting your pre-trained model
to more domain-specific data. In this tutorial, we demonstrate 3 model
optimization techniques that can be applied to your model during fine-tuning:

1. **Quantization-Aware Training (QAT)**, for adapting your model to
   quantization numerics during fine-tuning, with the goal of mitigating
   quantization degradation in your fine-tuned model when it is eventually
   quantized, e.g. in the serving step. Check out `our blog <https://pytorch.org/blog/quantization-aware-training/>`__
   and `README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md>`__ for more details!

2. **Quantized Low-Rank Adaptation (QLoRA)**, for reducing the resource
   requirements of fine-tuning by introducing small, trainable low-rank
   matrices and freezing the original pre-trained checkpoint, a type of
   Parameter-Efficient Fine-Tuning (PEFT). Please refer to the `original
   paper <https://arxiv.org/pdf/2305.14314>`__ for more details.

3. **Float8 Quantized Fine-tuning**, for speeding up fine-tuning by
   dynamically quantizing high precision weights and activations to float8,
   similar to `pre-training in float8 <pretraining.html>`__.


Quantization-Aware Training (QAT)
##################################

The goal of Quantization-Aware Training is to adapt the model to
quantization numerics during training or fine-tuning, so as to mitigate
the inevitable quantization degradation when the model is eventually
quantized, typically during the serving step after fine-tuning.
TorchAO's QAT support has been used successfully for the recent release of
the `Llama-3.2 quantized 1B/3B <https://ai.meta.com/blog/meta-llama-quantized-lightweight-models/>`__
and the `LlamaGuard-3-8B <https://github.com/meta-llama/PurpleLlama/blob/main/Llama-Guard3/8B/MODEL_CARD.md>`__ models to improve the quality of the quantized models.

TorchAO's QAT support involves two separate steps: prepare and convert.
The prepare step "fake" quantizes activations and/or weights during
training, which means the high precision values (e.g. bf16) are mapped
to their corresponding quantized values *without* actually being cast
to the target lower precision dtype (e.g. int4). The convert step,
applied after training, replaces "fake" quantization operations in the
model with "real" quantization that does perform the dtype casting:

.. image:: ../../torchao/quantization/qat/images/qat_diagram.png

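To make "fake" quantization concrete, the following is a minimal numerical sketch (an illustration, not TorchAO's actual implementation) of fake quantizing a weight tensor to symmetric per-group int4: values are snapped to the int4 grid and immediately mapped back, so the tensor stays in high precision but carries the quantization error. The real implementation also handles gradients (e.g. via straight-through estimation); this sketch only shows the forward numerics.

.. code:: py

    import torch

    def fake_quantize_int4_per_group(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
        # Illustrative only: round-trip through the symmetric int4 grid [-8, 7]
        # without ever materializing an int4 tensor
        wg = w.reshape(-1, group_size)
        scale = wg.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 7
        q = torch.clamp(torch.round(wg / scale), -8, 7)   # values on the int4 grid
        return (q * scale).reshape(w.shape)               # back to the original dtype

    w = torch.randn(128, 64)
    w_fq = fake_quantize_int4_per_group(w)
    print(w_fq.dtype)                  # still torch.float32
    print((w - w_fq).abs().max())      # small, nonzero quantization error
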
There are multiple options for using TorchAO's QAT for fine-tuning:

1. Use our integration with `TorchTune <https://github.com/pytorch/torchtune>`__
2. Use our integration with `Axolotl <https://github.com/axolotl-ai-cloud/axolotl>`__
3. Directly use our QAT APIs with your own training loop


Option 1: TorchTune QAT Integration
===================================

TorchAO's QAT support is integrated into TorchTune's distributed fine-tuning recipe.
Instead of the following command, which applies full distributed fine-tuning without QAT:

.. code::

    # Regular fine-tuning without QAT
    tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3_2/3B_full batch_size=16

Users can run the following equivalent command instead. Note that specifying the quantizer
is optional:

.. code::

    # Fine-tuning with QAT, by default:
    #   activations are fake quantized to asymmetric per token int8
    #   weights are fake quantized to symmetric per group int4
    #   configurable through "quantizer._component_" in the command
    tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full batch_size=16

After fine-tuning, users can quantize and evaluate the resulting model as follows.
This is the same whether or not QAT was used during the fine-tuning process:

.. code::

    # Quantize model weights to int4
    tune run quantize --config quantization \
        model._component_=torchtune.models.llama3_2.llama3_2_3b \
        checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
        'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
        checkpointer.model_type=LLAMA3 \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
        quantizer.groupsize=32

    # Evaluate the int4 model on hellaswag and wikitext
    tune run eleuther_eval --config eleuther_evaluation \
        batch_size=1 \
        'tasks=[hellaswag, wikitext]' \
        model._component_=torchtune.models.llama3_2.llama3_2_3b \
        checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
        'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
        checkpointer.model_type=LLAMA3 \
        tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
        tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
        quantizer.groupsize=32

This should print the following after fine-tuning:

.. code::

    | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
    |---------|------:|------|------|--------|---|-----:|---|-----:|
    |hellaswag| 1|none |None |acc |↑ |0.5021|± |0.0050|
    | | |none |None |acc_norm|↑ |0.6797|± |0.0047|

    | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
    |--------|------:|------|------|---------------|---|------:|---|------|
    |wikitext| 2|none |None |bits_per_byte |↓ | 0.6965|± | N/A|
    | | |none |None |byte_perplexity|↓ | 1.6206|± | N/A|
    | | |none |None |word_perplexity|↓ |13.2199|± | N/A|

You can compare these values with and without QAT to see how much QAT helped mitigate quantization degradation!
For example, when fine-tuning Llama-3.2-3B on the
`OpenAssistant Conversations (OASST1) <https://huggingface.co/datasets/OpenAssistant/oasst1>`__
dataset, we find that the quantized model achieved 3.4% higher accuracy
with QAT than without, recovering 69.8% of the overall accuracy degradation
from quantization:

.. image:: ../static/qat_eval.png

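To make the "recovered degradation" metric concrete, here is one common way such a number is computed (a hedged illustration with made-up accuracies, not the exact figures behind the plot above): the quantized QAT accuracy is compared against both the quantized non-QAT baseline and the original unquantized accuracy.

.. code:: py

    # Hypothetical accuracies, for illustration only (not the real OASST1 numbers)
    acc_float = 0.560          # fine-tuned model, no quantization
    acc_quant = 0.510          # quantized without QAT
    acc_quant_qat = 0.545      # quantized with QAT

    degradation = acc_float - acc_quant                    # accuracy lost to quantization
    recovered = (acc_quant_qat - acc_quant) / degradation  # fraction recovered by QAT
    print(f"QAT recovered {recovered:.1%} of the quantization degradation")  # 70.0%
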
In addition to vanilla QAT as in the above example, TorchAO's QAT can also be composed with LoRA to yield a `1.89x training speedup <https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700>`__. This is implemented in TorchTune's `QAT + LoRA fine-tuning recipe <https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py>`__, which can be run using the following command:

.. code::

    # Fine-tuning with QAT + LoRA
    tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora batch_size=16

For more details about how QAT is set up in TorchTune, please refer to `this tutorial <https://docs.pytorch.org/torchtune/main/tutorials/qat_finetune.html>`__.


Option 2: Axolotl QAT Integration
=================================

Axolotl also recently added a QAT fine-tuning recipe that leverages TorchAO's QAT support.
To get started, try fine-tuning Llama-3.2-3B with QAT using the following command:

.. code::

    axolotl train examples/llama-3/3b-qat-fsdp2.yaml
    # once training is complete, perform the quantization step

    axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
    # you should now have a quantized model saved in ./outputs/qat_out/quantized

Please refer to the `Axolotl QAT documentation <https://docs.axolotl.ai/docs/qat.html>`__ for full details.


Option 3: TorchAO QAT API
=========================

If you prefer to use a different training framework or your own custom training loop,
you can call TorchAO's QAT APIs directly to transform the model before fine-tuning.
These APIs are what the TorchTune and Axolotl QAT integrations call under the hood.

In this example, we will fine-tune a mini version of Llama3 on a single GPU:

.. code:: py

    import torch
    from torchtune.models.llama3 import llama3

    # Set up a smaller version of llama3 to fit in a single A100 GPU
    # For smaller GPUs, adjust the model attributes accordingly
    def get_model():
        return llama3(
            vocab_size=4096,
            num_layers=16,
            num_heads=16,
            num_kv_heads=4,
            embed_dim=2048,
            max_seq_len=2048,
        ).cuda()

    # Example training loop
    def train_loop(m: torch.nn.Module):
        optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
        loss_fn = torch.nn.CrossEntropyLoss()
        for i in range(10):
            example = torch.randint(0, 4096, (2, 16)).cuda()
            target = torch.randn((2, 16, 4096)).cuda()
            output = m(example)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

Next, run the prepare step, which fake quantizes the model. In this example,
we use int8 per token dynamic activations and int4 symmetric per group weights
as our quantization scheme. Note that although we are targeting lower integer
precisions, training still performs arithmetic in higher float precision (float32)
because we are not actually casting the fake quantized values.

.. code:: py

    from torchao.quantization import (
        quantize_,
    )
    from torchao.quantization.qat import (
        FakeQuantizeConfig,
        IntXQuantizationAwareTrainingConfig,
    )
    model = get_model()

    # prepare: insert fake quantization ops
    # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
    quantize_(model, qat_config)

    # fine-tune
    train_loop(model)

After fine-tuning, we end up with a model in the original high precision.
This fine-tuned model has the exact same structure as the original model.
The only difference is that the QAT fine-tuned model has weights that are more
attuned to quantization, which will be beneficial later during inference.
The next step is to actually quantize the model:

.. code:: py

    from torchao.quantization import (
        Int8DynamicActivationInt4WeightConfig,
    )
    from torchao.quantization.qat import (
        FromIntXQuantizationAwareTrainingConfig,
    )

    # convert: transform fake quantization ops into actual quantized ops
    # swaps `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
    # quantized activation and weight tensor subclasses
    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

Now our model is ready for serving, and will typically have higher quantized
accuracy than if we did not apply the prepare step (fake quantization) during
fine-tuning.

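As a quick, optional sanity check (a minimal sketch, not part of TorchAO's documented flow), you can run a batch through the converted model to confirm that it still executes and produces finite logits:

.. code:: py

    # Hypothetical smoke test on the converted (actually quantized) model
    model.eval()
    with torch.no_grad():
        example = torch.randint(0, 4096, (2, 16)).cuda()
        logits = model(example)
    print(logits.shape)                  # torch.Size([2, 16, 4096])
    print(torch.isfinite(logits).all())  # expect tensor(True)
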
For full details of using TorchAO's QAT API, please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md>`__.

.. raw:: html

   <details>
   <summary><a>Alternative Legacy API</a></summary>

The above `quantize_` API is the recommended flow for using TorchAO QAT.
We also offer an alternative legacy "quantizer" API for specific quantization
schemes, but these are not customizable, unlike the above example.

.. code::

    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
    qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)

    # prepare: insert fake quantization ops
    # swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
    model = qat_quantizer.prepare(model)

    # train
    train_loop(model)

    # convert: transform fake quantization ops into actual quantized ops
    # swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
    model = qat_quantizer.convert(model)

.. raw:: html

   </details>


Quantized Low-Rank Adaptation (QLoRA)
#####################################

(Coming soon!)

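While this section is being written, the following minimal sketch illustrates the core idea behind QLoRA in plain PyTorch (an assumption-laden illustration, not TorchAO's or TorchTune's QLoRA API): the pre-trained weight is frozen (and, in real QLoRA, stored in a 4-bit format such as NF4 and dequantized on the fly), while only the small low-rank matrices A and B are trained.

.. code:: py

    import torch

    class LoRALinear(torch.nn.Module):
        # Illustrative only: frozen base weight + trainable low-rank update
        def __init__(self, base: torch.nn.Linear, rank: int = 8, alpha: float = 16.0):
            super().__init__()
            self.base = base
            for p in self.base.parameters():
                p.requires_grad = False  # freeze the pre-trained weight
            self.lora_a = torch.nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
            self.lora_b = torch.nn.Parameter(torch.zeros(base.out_features, rank))
            self.scaling = alpha / rank

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # In real QLoRA, self.base.weight would be a quantized (e.g. NF4) tensor;
            # here it is kept in high precision for simplicity.
            return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

    layer = LoRALinear(torch.nn.Linear(2048, 2048))
    out = layer(torch.randn(4, 2048))  # only lora_a and lora_b receive gradients
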
Float8 Quantized Fine-tuning
############################

(Coming soon!)
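
In the meantime, the rough idea is the same dynamic float8 scaling used in `pre-training in float8 <pretraining.html>`__. The sketch below (an illustration in plain PyTorch, not TorchAO's float8 training API, and assuming a PyTorch build with float8 dtypes) shows how a high precision tensor can be dynamically scaled and cast to ``torch.float8_e4m3fn`` and back:

.. code:: py

    import torch

    F8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3fn

    def to_float8_dynamic(x: torch.Tensor):
        # Illustrative only: per-tensor dynamic scaling into the float8 range
        scale = F8_MAX / x.abs().amax().clamp(min=1e-12)
        x_f8 = (x * scale).clamp(-F8_MAX, F8_MAX).to(torch.float8_e4m3fn)
        return x_f8, scale

    x = torch.randn(16, 32, dtype=torch.bfloat16)
    x_f8, scale = to_float8_dynamic(x)
    x_approx = x_f8.to(torch.bfloat16) / scale  # dequantize for comparison
    print((x - x_approx).abs().max())           # small float8 rounding error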

docs/source/index.rst

Lines changed: 5 additions & 3 deletions
@@ -37,12 +37,14 @@ for an overall introduction to the library and recent highlight and updates.
    :maxdepth: 1
    :caption: Eager Quantization Tutorials
 
+   pretraining
+   finetuning
+   serving
+   torchao_vllm_integration
    serialization
+   static_quantization
    subclass_basic
    subclass_advanced
-   static_quantization
-   pretraining
-   torchao_vllm_integration
 
 .. toctree::
    :glob:
