Pretraining with float8
---------------------------------

Pretraining with float8 using torchao can provide `up to 1.5x speedups <https://pytorch.org/blog/training-using-float8-fsdp2/>`__ on 512 GPU clusters,
and up to `1.34-1.43x speedups <https://pytorch.org/blog/accelerating-large-scale-training-and-convergence-with-pytorch-float8-rowwise-on-crusoe-2k-h200s/>`__ on 2K H200 clusters with the latest `torchao.float8` rowwise recipe.

In this tutorial, we will show 2 ways to use the **torchao.float8** recipes for pretraining:

1. :ref:`Pretraining with torchtitan`, the official PyTorch pretraining framework with native torchao integration.
2. :ref:`Pretraining with torchao directly`, to integrate torchao's float8 training recipes into your own pretraining code.


Pretraining with torchtitan
###########################

In this tutorial we'll pretrain Llama3 8b using torchtitan with torchao's float8 training recipes: rowwise scaling and tensorwise scaling.

`Torchtitan <https://github.com/pytorch/torchtitan/>`__ is PyTorch's official pretraining framework that is natively integrated with torchao, and supports
several popular flagship models with common forms of parallelism, float8 training, distributed checkpointing and more.
See the torchtitan `docs <https://github.com/pytorch/torchtitan>`__ for additional details.

You can use this workflow to get started quickly with a "batteries included" experience. Users commonly
fork torchtitan and build on top of it when they're ready.

Prerequisites
================

1. (Recommended) Create a new virtual environment with conda or venv.
2. `Install torchao <https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation>`__.
3. `Install torchtitan <https://github.com/pytorch/torchtitan/tree/main?tab=readme-ov-file#installation>`__, including the "downloading a tokenizer" step.

You're now ready to start a pretraining job using one of the recipes below!

Rowwise scaling
===============

Run the following command from the torchtitan root directory to launch a Llama3 8b training job on 8 GPUs with float8 rowwise training:

.. code:: console

    NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --model.converters="float8" --float8.recipe_name="rowwise"

Torchtitan will automatically use FSDP2 to parallelize training when more than 1 GPU is used. To use other forms of parallelism, modify hyperparameters, or change other training configurations, you can directly edit the `llama3_8b.toml <https://github.com/pytorch/torchtitan/blob/775a486edd173ceb9be1c9b1b30af6ca2d4b4fa0/torchtitan/models/llama3/train_configs/llama3_8b.toml>`__ file or use command line flags (run the command with :code:`--help` to see more options).

You should see terminal output that looks like this:

.. code:: console

    [rank0]:[titan] 2025-06-04 08:51:48,074 - root - INFO - step: 1 loss: 12.2254 memory: 27.34GiB(28.78%) tps: 375 tflops: 21.73 mfu: 2.20%
    [rank0]:[titan] 2025-06-04 08:51:58,557 - root - INFO - step: 10 loss: 10.7069 memory: 30.99GiB(32.62%) tps: 7,034 tflops: 407.35 mfu: 41.19%
    [rank0]:[titan] 2025-06-04 08:52:10,224 - root - INFO - step: 20 loss: 8.9196 memory: 30.99GiB(32.62%) tps: 7,022 tflops: 406.65 mfu: 41.12%
    [rank0]:[titan] 2025-06-04 08:52:21,904 - root - INFO - step: 30 loss: 8.1423 memory: 30.99GiB(32.62%) tps: 7,014 tflops: 406.23 mfu: 41.08%

As you can see, ignoring the warmup steps, we are achieving ~7k TPS with 30.99GB peak memory usage. To compare performance against bfloat16 training, you can remove the :code:`--model.converters="float8" --float8.recipe_name="rowwise"` flags
and run the same command to see the baseline performance of bfloat16 training:

.. code:: console

    NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile

You should see the following output:

.. code:: console

    [rank0]:[titan] 2025-06-04 11:02:37,404 - root - INFO - step: 1 loss: 12.2611 memory: 27.22GiB(28.65%) tps: 595 tflops: 34.47 mfu: 3.49%
    [rank0]:[titan] 2025-06-04 11:02:49,027 - root - INFO - step: 10 loss: 10.4260 memory: 30.89GiB(32.51%) tps: 6,344 tflops: 367.39 mfu: 37.15%
    [rank0]:[titan] 2025-06-04 11:03:01,988 - root - INFO - step: 20 loss: 8.9482 memory: 30.89GiB(32.51%) tps: 6,321 tflops: 366.06 mfu: 37.01%
    [rank0]:[titan] 2025-06-04 11:03:14,991 - root - INFO - step: 30 loss: 8.1183 memory: 30.89GiB(32.51%) tps: 6,300 tflops: 364.89 mfu: 36.89%
    [rank0]:[titan] 2025-06-04 11:03:28,013 - root - INFO - step: 40 loss: 7.4659 memory: 30.89GiB(32.51%) tps: 6,291 tflops: 364.36 mfu: 36.84%
    [rank0]:[titan] 2025-06-04 11:03:39,769 - root - INFO - [GC] Peforming periodical GC collection. 0.02 seconds.

As you can see, the bfloat16 baseline achieves ~6.3k TPS using 30.89GB peak memory.

This means our float8 rowwise scaling recipe achieves **1.11x higher throughput** compared to the bfloat16 baseline, using nearly identical peak memory!

Note that you can achieve an even higher throughput improvement with the tensorwise scaling recipe, which sits at a different point on the performance vs. accuracy curve.

Tensorwise scaling
==================

Float8 training with tensorwise scaling is the default recipe, so we can omit the :code:`--float8.recipe_name` flag:

.. code:: console

    NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --model.converters="float8"

You should see output like the following:

.. code:: console

    [rank0]:[titan] 2025-06-04 10:52:19,648 - root - INFO - step: 1 loss: 12.2648 memory: 27.28GiB(28.71%) tps: 557 tflops: 32.29 mfu: 3.26%
    [rank0]:[titan] 2025-06-04 10:52:29,475 - root - INFO - step: 10 loss: 10.9106 memory: 30.91GiB(32.53%) tps: 7,503 tflops: 434.53 mfu: 43.94%
    [rank0]:[titan] 2025-06-04 10:52:40,166 - root - INFO - step: 20 loss: 9.0774 memory: 30.91GiB(32.53%) tps: 7,663 tflops: 443.78 mfu: 44.87%
    [rank0]:[titan] 2025-06-04 10:52:50,885 - root - INFO - step: 30 loss: 8.3233 memory: 30.91GiB(32.53%) tps: 7,643 tflops: 442.66 mfu: 44.76%
    [rank0]:[titan] 2025-06-04 10:53:01,613 - root - INFO - step: 40 loss: 7.6150 memory: 30.91GiB(32.53%) tps: 7,637 tflops: 442.27 mfu: 44.72%

As you can see, we are achieving ~7.6k TPS using 30.91GB peak memory, which is **1.21x higher throughput** compared to the bfloat16 baseline!
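
For reference, both speedup figures quoted in this tutorial follow directly from the steady-state tokens-per-second numbers in the logs above (rounded to match the prose):

.. code:: py

    # steady-state TPS, rounded from the training logs above
    bf16_tps = 6300
    float8_rowwise_tps = 7000
    float8_tensorwise_tps = 7600

    print(f"rowwise vs bf16:    {float8_rowwise_tps / bf16_tps:.2f}x")     # ~1.11x
    print(f"tensorwise vs bf16: {float8_tensorwise_tps / bf16_tps:.2f}x")  # ~1.21x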

Picking a recipe
================

**TL;DR**: rowwise scaling is better for jobs prioritizing more accurate numerics and training stability, and tensorwise scaling is better for jobs prioritizing training throughput.

The higher throughput of tensorwise scaling comes at the cost of slightly higher quantization error (i.e., reduced numerical integrity vs bfloat16) compared to rowwise scaling.
This is because rowwise scaling uses a more granular scaling factor (per row, instead of per tensor), which limits the impact of outliers that can cause underflow during scaling.
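
To build intuition for why scaling granularity matters, here is a minimal, self-contained sketch in plain PyTorch (not torchao's actual implementation) that quantizes a tensor containing one outlier with a single per-tensor scale versus per-row scales. It only assumes the float8 e4m3 format used for training, whose maximum representable magnitude is 448:

.. code:: py

    import torch

    # maximum representable magnitude of float8 e4m3 (448.0)
    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

    # a tensor of small values with a single large outlier in row 0
    x = torch.full((4, 8), 1e-3)
    x[0, 0] = 1000.0

    # tensorwise: one scale for the whole tensor, dominated by the outlier
    tensor_scale = FP8_E4M3_MAX / x.abs().max()
    x_fp8_tensorwise = (x * tensor_scale).to(torch.float8_e4m3fn)

    # rowwise: one scale per row, so the outlier only affects its own row
    row_scales = FP8_E4M3_MAX / x.abs().amax(dim=1, keepdim=True)
    x_fp8_rowwise = (x * row_scales).to(torch.float8_e4m3fn)

    # dequantize and compare reconstruction error on the non-outlier rows:
    # the tensorwise version underflows those rows to zero, the rowwise version preserves them
    err_tensorwise = (x_fp8_tensorwise.float() / tensor_scale - x)[1:].abs().max().item()
    err_rowwise = (x_fp8_rowwise.float() / row_scales - x)[1:].abs().max().item()
    print(f"tensorwise max error on non-outlier rows: {err_tensorwise:.1e}")
    print(f"rowwise max error on non-outlier rows:    {err_rowwise:.1e}")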

Below you can see the loss curves comparing bfloat16, float8 tensorwise, and float8 rowwise training for Llama3 8b on 8xH100 GPUs:

.. image:: ../static/fp8-loss-curves.png
    :alt: Loss curves for training Llama3 8b on 8xH100s with torchtitan using bfloat16, float8 tensorwise, and float8 rowwise training.


Important notes
===============

* In torchtitan, float8 training is currently only supported on 2+ GPUs, not single-GPU training.
* You must use :code:`--training.compile` to achieve high performance. torchao's float8 training recipes are built natively on top of :code:`torch.compile`, so they work out of the box!


Pretraining with torchao directly
#################################

In this tutorial we'll pretrain a toy model using torchao APIs directly.

You can use this workflow to integrate torchao into your own custom pretraining code directly.

Prerequisites
================

1. (Recommended) Create a new virtual environment with conda or venv.
2. `Install torchao <https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation>`__.

You're now ready to integrate torchao into your training code directly!

Model conversion API
====================

The torchao API for converting your model to use float8 training is: `convert_to_float8_training <https://github.com/pytorch/ao/blob/152a8e397e1383c55bf7b87a8eaa2b87ffb2c114/torchao/float8/float8_linear_utils.py#L84>`__. This API will recursively convert :code:`nn.Linear` modules in your model to use `Float8Linear <https://github.com/pytorch/ao/blob/152a8e397e1383c55bf7b87a8eaa2b87ffb2c114/torchao/float8/float8_linear.py#L254>`__.

You can use the :code:`module_filter_fn` argument to determine which :code:`nn.Linear` layers should be swapped to use :code:`Float8Linear`.

You should refer to this `performance benchmark table <https://github.com/pytorch/ao/tree/152a8e397e1383c55bf7b87a8eaa2b87ffb2c114/torchao/float8#performance>`__ to understand
what kind of performance improvement over bfloat16 you can expect for a given GEMM size.

Below is a code snippet showing how to use it:

.. code:: py

    import torch
    from torch import nn
    import torch.nn.functional as F

    from torchao.float8 import convert_to_float8_training
    from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

    if not TORCH_VERSION_AT_LEAST_2_5:
        raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

    # create model and sample input
    m = nn.Sequential(
        nn.Linear(2048, 4096),
        nn.Linear(4096, 128),
        nn.Linear(128, 1),
    ).bfloat16().cuda()
    x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
    optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

    # optional: filter modules from being eligible for float8 conversion
    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # don't convert the second nn.Linear in the Sequential (fqn '1')
        if fqn == "1":
            return False
        # don't convert linear modules with weight dimensions not divisible by 16
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    # convert specified `torch.nn.Linear` modules to `Float8Linear`
    convert_to_float8_training(m, module_filter_fn=module_filter_fn)

    # enable torch.compile for competitive performance
    m = torch.compile(m)

    # toy training loop
    for _ in range(10):
        optimizer.zero_grad()
        output = m(x)
        # use fake labels for demonstration purposes
        fake_labels = torch.ones_like(output)
        loss = F.mse_loss(output, fake_labels)
        loss.backward()
        optimizer.step()

    # save a checkpoint containing the model and optimizer state
    torch.save({
        'model': m,
        'model_state_dict': m.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pth')
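
The snippet above uses torchao's default tensorwise scaling recipe. If you want the rowwise scaling recipe shown in the torchtitan section, the recipe can be selected through a config object passed to :code:`convert_to_float8_training`. The sketch below assumes the :code:`Float8LinearConfig.from_recipe_name` API available in recent torchao releases; check the torchao float8 README for the exact spelling in your installed version:

.. code:: py

    import torch
    from torch import nn

    # assumed API: Float8LinearConfig.from_recipe_name selects the scaling recipe
    # ("tensorwise" is the default); verify against your installed torchao version
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training

    m = nn.Sequential(
        nn.Linear(2048, 4096),
        nn.Linear(4096, 2048),
    ).bfloat16().cuda()

    # select the rowwise scaling recipe and convert eligible nn.Linear modules in-place
    rowwise_config = Float8LinearConfig.from_recipe_name("rowwise")
    convert_to_float8_training(m, config=rowwise_config)

    # torch.compile is still required for competitive performance
    m = torch.compile(m)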