(Part 2) Fine-tuning with QAT, QLoRA, and float8
------------------------------------------------

TorchAO provides an end-to-end pre-training, fine-tuning, and serving
model optimization flow by leveraging our quantization and sparsity
techniques integrated into our partner frameworks. This is the second of
three tutorials showcasing this end-to-end flow, and it focuses on the
fine-tuning step.

.. image:: ../static/e2e_flow_part2.png

Fine-tuning is an important step for adapting your pre-trained model
to more domain-specific data. In this tutorial, we demonstrate three model
optimization techniques that can be applied to your model during fine-tuning:

1. **Quantization-Aware Training (QAT)**, for adapting your model to
quantization numerics during fine-tuning, with the goal of mitigating
quantization degradation in your fine-tuned model when it is eventually
quantized, e.g. in the serving step. Check out `our blog <https://pytorch.org/blog/quantization-aware-training/>`__
and `README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md>`__ for more details!

2. **Quantized Low-Rank Adaptation (QLoRA)**, for reducing the resource
requirements of fine-tuning by introducing small, trainable low-rank
matrices and freezing the original pre-trained checkpoint, a type of
Parameter-Efficient Fine-Tuning (PEFT). Please refer to the `original
paper <https://arxiv.org/pdf/2305.14314>`__ for more details, and see the
sketch after this list for the core idea.

3. **Float8 Quantized Fine-tuning**, for speeding up fine-tuning by
dynamically quantizing high precision weights and activations to float8,
similar to `pre-training in float8 <pretraining.html>`__.
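
To make the low-rank adaptation idea concrete, below is a minimal sketch of a
LoRA linear layer. This is illustrative only: `LoRALinear` is a hypothetical
helper, not TorchAO's or TorchTune's implementation, and in QLoRA the frozen
base weight would additionally be stored in a quantized format such as NF4.

.. code:: py

    import torch

    class LoRALinear(torch.nn.Module):
        """Minimal LoRA sketch: y = base(x) + (alpha / r) * B(A(x)).

        The pre-trained base weight is frozen; only the small low-rank
        factors A and B are trained.
        """
        def __init__(self, base: torch.nn.Linear, r: int = 8, alpha: int = 16):
            super().__init__()
            self.base = base
            for p in self.base.parameters():
                p.requires_grad_(False)               # freeze the pre-trained weight
            self.lora_a = torch.nn.Linear(base.in_features, r, bias=False)
            self.lora_b = torch.nn.Linear(r, base.out_features, bias=False)
            torch.nn.init.zeros_(self.lora_b.weight)  # start as a no-op update
            self.scaling = alpha / r

        def forward(self, x):
            return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

    layer = LoRALinear(torch.nn.Linear(2048, 2048))
    trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
    print(f"trainable params: {trainable}")  # ~33K instead of ~4.2M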


Quantization-Aware Training (QAT)
##################################

The goal of Quantization-Aware Training is to adapt the model to
quantization numerics during training or fine-tuning, so as to mitigate
the quantization degradation that occurs when the model is eventually
quantized, typically in the serving step after fine-tuning. TorchAO's QAT
support has been used successfully in the recent releases of the
`Llama-3.2 quantized 1B/3B <https://ai.meta.com/blog/meta-llama-quantized-lightweight-models/>`__
and `LlamaGuard-3-8B <https://github.com/meta-llama/PurpleLlama/blob/main/Llama-Guard3/8B/MODEL_CARD.md>`__ models to improve the quality of the quantized models.

TorchAO's QAT support involves two separate steps: prepare and convert.
The prepare step "fake" quantizes activations and/or weights during
training, meaning the high precision values (e.g. bf16) are mapped
to their corresponding quantized values *without* actually being cast
to the target lower precision dtype (e.g. int4). The convert step,
applied after training, replaces "fake" quantization operations in the
model with "real" quantization that does perform the dtype casting:

.. image:: ../../torchao/quantization/qat/images/qat_diagram.png
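
To make the "fake" part concrete, here is a minimal sketch of symmetric
per-group quantize-dequantize in plain PyTorch. This is illustrative only:
`fake_quantize_symmetric` is a hypothetical helper, not TorchAO's
implementation, which also defines a backward pass so gradients can flow
through the rounding.

.. code:: py

    import torch

    def fake_quantize_symmetric(w: torch.Tensor, group_size: int = 32, bits: int = 4) -> torch.Tensor:
        # Split the last dimension into groups and compute one scale per group
        orig_shape = w.shape
        wg = w.reshape(-1, group_size).float()
        qmax = 2 ** (bits - 1) - 1                              # 7 for int4
        scale = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-9) / qmax
        # Quantize then immediately dequantize: values are snapped to the
        # int4 grid, but the tensor stays in a floating point dtype
        q = torch.clamp(torch.round(wg / scale), -qmax - 1, qmax)
        return (q * scale).reshape(orig_shape).to(w.dtype)

    w = torch.randn(128, 256, dtype=torch.bfloat16)
    w_fq = fake_quantize_symmetric(w)
    print(w_fq.dtype, (w - w_fq).abs().mean())  # still bf16, small rounding error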

There are multiple options for using TorchAO's QAT for fine-tuning:

1. Use our integration with `TorchTune <https://github.com/pytorch/torchtune>`__
2. Use our integration with `Axolotl <https://github.com/axolotl-ai-cloud/axolotl>`__
3. Directly use our QAT APIs with your own training loop


Option 1: TorchTune QAT Integration
===================================

TorchAO's QAT support is integrated into TorchTune's distributed fine-tuning recipe.
Regular full distributed fine-tuning without QAT uses the following command:

.. code::

   # Regular fine-tuning without QAT
   tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3_2/3B_full batch_size=16

To fine-tune with QAT instead, run the following equivalent command. Note that
specifying the quantizer is optional:

.. code::

   # Fine-tuning with QAT, by default:
   # activations are fake quantized to asymmetric per token int8
   # weights are fake quantized to symmetric per group int4
   # configurable through "quantizer._component_" in the command
   tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full batch_size=16

After fine-tuning, users can quantize and evaluate the resulting model as follows.
This is the same whether or not QAT was used during the fine-tuning process:

.. code::

   # Quantize model weights to int4
   tune run quantize --config quantization \
   model._component_=torchtune.models.llama3_2.llama3_2_3b \
   checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
   'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
   checkpointer.model_type=LLAMA3 \
   quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
   quantizer.groupsize=32

   # Evaluate the int4 model on hellaswag and wikitext
   tune run eleuther_eval --config eleuther_evaluation \
   batch_size=1 \
   'tasks=[hellaswag, wikitext]' \
   model._component_=torchtune.models.llama3_2.llama3_2_3b \
   checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
   'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
   checkpointer.model_type=LLAMA3 \
   tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
   tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
   quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
   quantizer.groupsize=32

This should print the following after fine-tuning:

.. code::

   |  Tasks  |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
   |---------|------:|------|------|--------|---|-----:|---|-----:|
   |hellaswag|      1|none  |None  |acc     |↑  |0.5021|±  |0.0050|
   |         |       |none  |None  |acc_norm|↑  |0.6797|±  |0.0047|

   | Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
   |--------|------:|------|------|---------------|---|------:|---|------|
   |wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6965|±  |   N/A|
   |        |       |none  |None  |byte_perplexity|↓  | 1.6206|±  |   N/A|
   |        |       |none  |None  |word_perplexity|↓  |13.2199|±  |   N/A|

You can compare these values with and without QAT to see how much QAT helped
mitigate quantization degradation! For example, when fine-tuning Llama-3.2-3B on the
`OpenAssistant Conversations (OASST1) <https://huggingface.co/datasets/OpenAssistant/oasst1>`__
dataset, we find that the quantized model achieved 3.4% higher accuracy
with QAT than without, recovering 69.8% of the overall accuracy degradation
from quantization:

.. image:: ../static/qat_eval.png

In addition to vanilla QAT as in the above example, TorchAO's QAT can also be
composed with LoRA to yield a `1.89x training speedup <https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700>`__.
This is implemented in TorchTune's `QAT + LoRA fine-tuning recipe <https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py>`__,
which can be run using the following command:

.. code::

   # Fine-tuning with QAT + LoRA
   tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora batch_size=16

For more details about how QAT is set up in TorchTune, please refer to `this tutorial <https://docs.pytorch.org/torchtune/main/tutorials/qat_finetune.html>`__.


Option 2: Axolotl QAT Integration
=================================

Axolotl also recently added a QAT fine-tuning recipe that leverages TorchAO's QAT support.
To get started, try fine-tuning Llama-3.2-3B with QAT using the following command:

.. code::

   # fine-tune with QAT
   axolotl train examples/llama-3/3b-qat-fsdp2.yaml

   # once training is complete, perform the quantization step
   axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
   # you should now have a quantized model saved in ./outputs/qat_out/quantized

Please refer to the `Axolotl QAT documentation <https://docs.axolotl.ai/docs/qat.html>`__ for full details.


Option 3: TorchAO QAT API
=========================

If you prefer to use a different training framework or your own custom training loop,
you can call TorchAO's QAT APIs directly to transform the model before fine-tuning.
These APIs are what the TorchTune and Axolotl QAT integrations call under the hood.

In this example, we will fine-tune a mini version of Llama3 on a single GPU:

.. code:: py

    import torch
    from torchtune.models.llama3 import llama3

    # Set up a smaller version of llama3 to fit in a single A100 GPU
    # For smaller GPUs, adjust the model attributes accordingly
    def get_model():
        return llama3(
            vocab_size=4096,
            num_layers=16,
            num_heads=16,
            num_kv_heads=4,
            embed_dim=2048,
            max_seq_len=2048,
        ).cuda()

    # Example training loop
    def train_loop(m: torch.nn.Module):
        optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
        loss_fn = torch.nn.CrossEntropyLoss()
        for i in range(10):
            example = torch.randint(0, 4096, (2, 16)).cuda()
            target = torch.randint(0, 4096, (2, 16)).cuda()
            output = m(example)
            # CrossEntropyLoss expects (N, num_classes) logits and (N,) class
            # indices, so flatten the batch and sequence dimensions
            loss = loss_fn(output.reshape(-1, 4096), target.reshape(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

Next, run the prepare step, which fake quantizes the model. In this example,
we use int8 per token dynamic activations and int4 symmetric per group weights
as our quantization scheme. Note that although we are targeting lower integer
precisions, training still performs arithmetic in higher float precision (float32)
because we are not actually casting the fake quantized values.

.. code:: py

    from torchao.quantization import (
        quantize_,
    )
    from torchao.quantization.qat import (
        FakeQuantizeConfig,
        IntXQuantizationAwareTrainingConfig,
    )
    model = get_model()

    # prepare: insert fake quantization ops
    # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
    quantize_(model, qat_config)

    # fine-tune
    train_loop(model)
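
If you want to confirm that the prepare step took effect, a quick optional
check is to look for the swapped modules. This is a sketch that relies only on
the `FakeQuantizedLinear` class name mentioned in the comment above:

.. code:: py

    # Optional sanity check: the prepare step should have swapped every
    # `torch.nn.Linear` for a fake quantized equivalent
    swapped = [name for name, m in model.named_modules()
               if type(m).__name__ == "FakeQuantizedLinear"]
    print(f"{len(swapped)} fake quantized linear layers")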

After fine-tuning, we end up with a model in the original high precision.
This fine-tuned model has the exact same structure as the original model.
The only difference is that the QAT fine-tuned model has weights that are more
attuned to quantization, which will be beneficial later during inference.
The next step is to actually quantize the model:

.. code:: py

    from torchao.quantization import (
        Int8DynamicActivationInt4WeightConfig,
    )
    from torchao.quantization.qat import (
        FromIntXQuantizationAwareTrainingConfig,
    )

    # convert: transform fake quantization ops into actual quantized ops
    # swaps `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
    # quantized activation and weight tensor subclasses
    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

Now our model is ready for serving, and will typically have higher quantized
accuracy than if we did not apply the prepare step (fake quantization) during
fine-tuning.
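
As a quick sanity check that the converted model runs end to end, you can do
something like the following (a sketch; the exact quantized tensor subclass
you see printed may differ across TorchAO versions):

.. code:: py

    # The linears are `torch.nn.Linear` again, but their weights are now
    # quantized tensor subclasses, and inference runs end to end
    model.eval()
    with torch.no_grad():
        example = torch.randint(0, 4096, (2, 16)).cuda()
        logits = model(example)
    linear = next(m for m in model.modules() if isinstance(m, torch.nn.Linear))
    print(type(linear.weight))  # a TorchAO quantized tensor subclass
    print(logits.shape)         # torch.Size([2, 16, 4096])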

For full details of using TorchAO's QAT API, please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md>`__.

.. raw:: html

   <details>
   <summary><a>Alternative Legacy API</a></summary>

The above `quantize_` API is the recommended flow for using TorchAO QAT.
We also offer an alternative legacy "quantizer" API for specific quantization
schemes, but it is not customizable, unlike the example above.

.. code:: py

    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
    qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)

    # prepare: insert fake quantization ops
    # swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
    model = qat_quantizer.prepare(model)

    # train
    train_loop(model)

    # convert: transform fake quantization ops into actual quantized ops
    # swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
    model = qat_quantizer.convert(model)

.. raw:: html

   </details>


Quantized Low-Rank Adaptation (QLoRA)
#####################################

(Coming soon!)


Float8 Quantized Fine-tuning
############################

(Coming soon!)