
Commit 7da9916

Add part 2 of end-to-end tutorial: fine-tuning
This commit adds the QAT tutorial and the general structure for the fine-tuning tutorial, which will also include QLoRA and float8 quantized fine-tuning. It also connects the 3 tutorial parts (pre-training, fine-tuning, and serving) into one cohesive end-to-end flow with some visuals and text.
1 parent e29b9bd commit 7da9916

File tree

8 files changed: +336 -18 lines changed


docs/source/finetuning.rst

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
(Part 2) Fine-tuning with QAT, QLoRA, and float8
------------------------------------------------

TorchAO provides an end-to-end pre-training, fine-tuning, and serving
model optimization flow by leveraging our quantization and sparsity
techniques integrated into our partner frameworks. This is part 2 of 3
such tutorials showcasing this end-to-end flow, focusing on the
fine-tuning step.

.. image:: ../static/e2e_flow_part2.png

Fine-tuning is an important step for adapting your pre-trained model
to more domain-specific data. In this tutorial, we demonstrate 3 model
optimization techniques that can be applied to your model during fine-tuning:

1. **Quantization-Aware Training (QAT)**, for adapting your model to
   quantization numerics during fine-tuning, with the goal of mitigating
   quantization degradations in your fine-tuned model when it is eventually
   quantized, e.g. in the serving step. Check out `our blog <https://pytorch.org/blog/quantization-aware-training/>`__
   for more details!

2. **Quantized Low-Rank Adaptation (QLoRA)**, for reducing the resource
   requirements of fine-tuning by introducing small, trainable low-rank
   matrices and freezing the original pre-trained checkpoint, a type of
   Parameter-Efficient Fine-Tuning (PEFT). Please refer to the `original
   paper <https://arxiv.org/pdf/2305.14314>`__ for more details.

3. **Float8 Quantized Fine-tuning**, for speeding up fine-tuning by
   dynamically quantizing high precision weights and activations to float8,
   similar to `pre-training in float8 <pretraining.html>`__.


Quantization-Aware Training (QAT)
##################################

The goal of Quantization-Aware Training is to adapt the model to
quantization numerics during training or fine-tuning, so as to mitigate
the inevitable quantization degradation when the model is eventually
quantized, typically during the serving step after fine-tuning.
TorchAO's QAT support has been used successfully for the recent release of
the `Llama-3.2 quantized 1B/3B <https://ai.meta.com/blog/meta-llama-quantized-lightweight-models/>`__
and the `LlamaGuard-3-8B <https://github.com/meta-llama/PurpleLlama/blob/main/Llama-Guard3/8B/MODEL_CARD.md>`__ models to improve the quality of the quantized models.

TorchAO's QAT support involves two separate steps: prepare and convert.
The prepare step "fake" quantizes activations and/or weights during
training, which means the high precision values (e.g. bf16) are mapped
to their corresponding quantized values *without* actually casting them
to the target lower precision dtype (e.g. int4). The convert step,
applied after training, replaces "fake" quantization operations in the
model with "real" quantization that does perform the dtype casting:

.. image:: ../../torchao/quantization/qat/images/qat_diagram.png

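To make the "fake" vs. "real" distinction concrete, here is a minimal
illustrative sketch (plain PyTorch, not TorchAO's actual implementation) of
symmetric int4 per-group fake quantization: the values are rounded onto the
int4 grid and immediately dequantized, so the tensor stays in high precision
but only takes on values representable in the target scheme.

.. code:: py

    import torch

    def fake_quantize_int4_per_group(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
        """Illustrative only: symmetric int4 fake quantization with one scale per group."""
        orig_shape = w.shape
        wg = w.reshape(-1, group_size)                   # [num_groups, group_size]
        scale = wg.abs().amax(dim=1, keepdim=True) / 7   # map each group's max to int4's max (7)
        scale = scale.clamp(min=1e-6)                    # guard against all-zero groups
        q = torch.clamp(torch.round(wg / scale), -8, 7)  # "quantize" onto the int4 grid
        return (q * scale).reshape(orig_shape)           # dequantize right away: no dtype casting

    w = torch.randn(128, 64)
    w_fq = fake_quantize_int4_per_group(w)
    print(w_fq.dtype)             # still torch.float32
    print(w_fq.unique().numel())  # far fewer distinct values than in w

The "real" quantization inserted by the convert step performs an actual cast to
the low precision dtype instead of this round trip.
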
There are multiple options for using TorchAO's QAT for fine-tuning:

1. Directly use our QAT APIs with your own training loop
2. Use our integration with `TorchTune <https://github.com/pytorch/torchtune>`__
3. Use our integration with `Axolotl <https://github.com/axolotl-ai-cloud/axolotl>`__


Option 1: TorchAO QAT API
=========================

First, set up the model for fine-tuning on a single GPU:

.. code:: py

    import torch
    from torchtune.models.llama3 import llama3

    # Set up a smaller version of llama3 to fit in a single GPU
    def get_model():
        return llama3(
            vocab_size=4096,
            num_layers=16,
            num_heads=16,
            num_kv_heads=4,
            embed_dim=2048,
            max_seq_len=2048,
        ).cuda()

    # Example training loop
    def train_loop(m: torch.nn.Module):
        optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
        loss_fn = torch.nn.CrossEntropyLoss()
        for i in range(10):
            example = torch.randint(0, 4096, (2, 16)).cuda()
            target = torch.randn((2, 16, 4096)).cuda()
            output = m(example)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

Next, run the prepare step, which fake quantizes the model. In this example,
we use int8 per token dynamic activations and int4 symmetric per group weights
as our quantization scheme. Note that although we are targeting lower integer
precisions, training still performs arithmetic in higher float precision (float32)
because we are not actually casting the fake quantized values.

.. code:: py

    from torchao.quantization import (
        quantize_,
    )
    from torchao.quantization.qat import (
        FakeQuantizeConfig,
        IntXQuantizationAwareTrainingConfig,
    )
    model = get_model()

    # prepare: insert fake quantization ops
    # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
    quantize_(model, qat_config)

    # fine-tune
    train_loop(model)

After fine-tuning, we end up with a model in the original high precision.
This fine-tuned model has the exact same structure as the original model.
The only difference is that the QAT fine-tuned model has weights that are more
attuned to quantization, which will be beneficial later during inference.
The next step is to actually quantize the model:

.. code:: py

    from torchao.quantization import (
        Int8DynamicActivationInt4WeightConfig,
    )
    from torchao.quantization.qat import (
        FromIntXQuantizationAwareTrainingConfig,
    )

    # convert: transform fake quantization ops into actual quantized ops
    # swaps `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
    # quantized activation and weight tensor subclasses
    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

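As an optional sanity check (a small sketch, not part of the original flow), you
can confirm the convert step took effect by inspecting a linear layer's weight,
which should now be a TorchAO quantized tensor subclass rather than a plain
`torch.Tensor`:

.. code:: py

    # print the weight type of the first linear layer found in the converted model
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            print(name, type(module.weight))
            break
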
Now our model is ready for serving, and will typically have higher quantized
accuracy than if we did not apply the prepare step (fake quantization) during
fine-tuning. For example, when fine-tuning Llama-3.2-3B on the
`OpenAssistant Conversations (OASST1) <https://huggingface.co/datasets/OpenAssistant/oasst1>`__
dataset, we find that the quantized model achieved 3.4% higher accuracy
with QAT than without, recovering 69.8% of the overall accuracy degradation
from quantization:

.. image:: ../static/qat_eval.png

For full details of using TorchAO's QAT API, please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md>`__.

.. raw:: html

   <details>
   <summary><a>Alternative Legacy API</a></summary>

The above `quantize_` API is the recommended flow for using TorchAO QAT.
We also offer an alternative legacy "quantizer" API for specific quantization
schemes, but these quantizers are not as customizable as the example above.

.. code::

    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
    qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)

    # prepare: insert fake quantization ops
    # swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
    model = qat_quantizer.prepare(model)

    # train
    train_loop(model)

    # convert: transform fake quantization ops into actual quantized ops
    # swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
    model = qat_quantizer.convert(model)

.. raw:: html

   </details>

Option 2: TorchTune QAT Integration
===================================

TorchAO's QAT support is integrated into TorchTune's distributed fine-tuning recipe.
Instead of the following command, which applies full distributed fine-tuning without QAT:

.. code::

    tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3_2/3B_full \
        epochs=1 \
        batch_size=16 \
        dataset._component_=torchtune.datasets.alpaca_cleaned_dataset

Users can run the following equivalent command instead. Note that specifying the quantizer
is optional:

.. code::

    tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full \
        epochs=1 \
        batch_size=16 \
        dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
        quantizer.groupsize=32

After fine-tuning, users can quantize and evaluate the resulting model as follows.
This is the same whether or not QAT was used during the fine-tuning process:

.. code::

    tune run quantize --config quantization \
        model._component_=torchtune.models.llama3_2.llama3_2_3b \
        checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
        'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
        checkpointer.model_type=LLAMA3 \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
        quantizer.groupsize=32

    tune run eleuther_eval --config eleuther_evaluation \
        batch_size=1 \
        'tasks=[hellaswag, wikitext]' \
        model._component_=torchtune.models.llama3_2.llama3_2_3b \
        checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
        'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
        checkpointer.model_type=LLAMA3 \
        tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
        tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
        quantizer.groupsize=32

This should print the following after fine-tuning:

.. code::

    | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
    |---------|------:|------|------|--------|---|-----:|---|-----:|
    |hellaswag| 1|none |None |acc |↑ |0.5021|± |0.0050|
    | | |none |None |acc_norm|↑ |0.6797|± |0.0047|

    | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
    |--------|------:|------|------|---------------|---|------:|---|------|
    |wikitext| 2|none |None |bits_per_byte |↓ | 0.6965|± | N/A|
    | | |none |None |byte_perplexity|↓ | 1.6206|± | N/A|
    | | |none |None |word_perplexity|↓ |13.2199|± | N/A|

You can compare these values with and without QAT to see how much QAT helped mitigate quantization degradation!

In addition to vanilla QAT as in the above example, TorchAO's QAT can also be composed with LoRA to yield a `1.89x training speedup <https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700>`__. This is implemented in TorchTune's `QAT + LoRA fine-tuning recipe <https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py>`__, which can be run using the following command:

.. code::

    tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora \
        epochs=1 \
        batch_size=16 \
        dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
        quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
        quantizer.groupsize=32

For more details about how QAT is set up in TorchTune, please refer to `this tutorial <https://docs.pytorch.org/torchtune/main/tutorials/qat_finetune.html>`__.


Option 3: Axolotl QAT Integration
=================================

Axolotl also recently added a QAT fine-tuning recipe that leverages TorchAO's QAT support.
To get started, try fine-tuning Llama-3.2-3B with QAT using the following command:

.. code::

    axolotl train examples/llama-3/3b-qat-fsdp2.yaml

    # once training is complete, perform the quantization step
    axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
    # you should now have a quantized model saved in ./outputs/qat_out/quantized

Please refer to the `Axolotl QAT documentation <https://docs.axolotl.ai/docs/qat.html>`__ for full details.


Quantized Low-Rank Adaptation (QLoRA)
#####################################

(Coming soon!)
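In the meantime, here is a minimal conceptual sketch of the QLoRA idea. The
`QLoRALinear` module below is purely illustrative (it is not a TorchAO or
TorchTune API): the pre-trained weight is frozen, and in QLoRA it would
additionally be stored in a quantized format such as NF4, while only the small
low-rank matrices are trained.

.. code:: py

    import torch
    import torch.nn as nn

    class QLoRALinear(nn.Module):
        """Illustrative only: frozen (quantized) base weight + trainable low-rank update."""

        def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
            super().__init__()
            self.base = base
            for p in self.base.parameters():
                p.requires_grad_(False)               # freeze the pre-trained weights
            self.lora_a = nn.Linear(base.in_features, rank, bias=False)
            self.lora_b = nn.Linear(rank, base.out_features, bias=False)
            nn.init.zeros_(self.lora_b.weight)        # adapter starts as a no-op
            self.scaling = alpha / rank

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # frozen base path + small trainable low-rank correction
            return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

    layer = QLoRALinear(nn.Linear(2048, 2048))
    trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
    print(trainable)  # only the low-rank matrices contribute trainable parameters

TorchTune's QLoRA recipes follow this general pattern, storing the frozen base
weights with TorchAO's NF4 tensor.
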

Float8 Quantized Fine-tuning
############################

(Coming soon!)
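Until then, the sketch below shows the rough shape of this flow using
TorchAO's `convert_to_float8_training` API, which is covered in more detail in
the `pre-training tutorial <pretraining.html>`__. It reuses the toy `get_model`
and `train_loop` helpers from the QAT example above, and the module filter
shown is just one possible choice:

.. code:: py

    import torch
    from torchao.float8 import convert_to_float8_training

    model = get_model()

    # example filter: skip converting the output projection
    def module_filter_fn(module: torch.nn.Module, fqn: str) -> bool:
        return fqn != "output"

    # swap eligible `torch.nn.Linear` modules to float8 dynamic quantized training
    convert_to_float8_training(model, module_filter_fn=module_filter_fn)

    # fine-tune as usual; matmuls run in float8 on supported hardware (e.g. H100)
    train_loop(model)
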

docs/source/index.rst

Lines changed: 5 additions & 3 deletions
@@ -37,12 +37,14 @@ for an overall introduction to the library and recent highlight and updates.
    :maxdepth: 1
    :caption: Eager Quantization Tutorials

+   pretraining
+   finetuning
+   serving
+   torchao_vllm_integration
    serialization
+   static_quantization
    subclass_basic
    subclass_advanced
-   static_quantization
-   pretraining
-   torchao_vllm_integration

 .. toctree::
    :glob:
