
Commit 488ecd4

[BE] [docs] Add float8 pretraining tutorial to docsite (#2304)
* add float8 pretraining tutorial
* make empty commit to trigger ci
* remove references to e2e tutorial
1 parent 0d9631b commit 488ecd4

File tree

4 files changed, +204 -0 lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
 # ones.
 extensions = [
     "sphinx.ext.autodoc",
+    "sphinx.ext.autosectionlabel",
     "sphinx.ext.autosummary",
     "sphinx.ext.doctest",
     "sphinx.ext.intersphinx",

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -40,3 +40,4 @@ for an overall introduction to the library and recent highlight and updates.
    serialization
    subclass_basic
    subclass_advanced
+   pretraining

docs/source/pretraining.rst

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
Pretraining with float8
---------------------------------

Pretraining with float8 using torchao can provide `up to 1.5x speedups <https://pytorch.org/blog/training-using-float8-fsdp2/>`__ on 512 GPU clusters,
and up to `1.34-1.43x speedups <https://pytorch.org/blog/accelerating-large-scale-training-and-convergence-with-pytorch-float8-rowwise-on-crusoe-2k-h200s/>`__ on 2K H200 clusters with the latest `torchao.float8` rowwise recipe.

In this tutorial, we will show two ways to use the **torchao.float8** recipes for pretraining:

1. :ref:`Pretraining with torchtitan`, the official PyTorch pretraining framework with native torchao integration.
2. :ref:`Pretraining with torchao directly`, to integrate torchao's float8 training recipes into your own pretraining code.


Pretraining with torchtitan
###########################

In this tutorial we'll pretrain Llama3 8b using torchtitan with torchao's float8 training recipes: rowwise scaling and tensorwise scaling.

`Torchtitan <https://github.com/pytorch/torchtitan/>`__ is PyTorch's official pretraining framework that is natively integrated with torchao, and supports
several popular flagship models with common forms of parallelism, float8 training, distributed checkpointing, and more.
See the torchtitan `docs <https://github.com/pytorch/torchtitan>`__ for additional details.

You can use this workflow to get started quickly with a "batteries included" experience. Users commonly
fork torchtitan and build on top of it when they're ready.

Prerequisites
================

1. (Recommended) Create a new virtual environment with conda or venv.
2. `Install torchao <https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation>`__.
3. `Install torchtitan <https://github.com/pytorch/torchtitan/tree/main?tab=readme-ov-file#installation>`__, including the "downloading a tokenizer" step.

You're now ready to start a pretraining job using one of the recipes below!

Rowwise scaling
===============

Run the following command from the torchtitan root directory to launch a Llama3 8b training job on 8 GPUs with float8 rowwise training:

.. code:: console

    NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --model.converters="float8" --float8.recipe_name="rowwise"

Torchtitan will automatically use FSDP2 to parallelize training when more than 1 GPU is used. To use other forms of parallelism, modify hyperparameters, or change other training configurations, you can directly edit the `llama3_8b.toml <https://github.com/pytorch/torchtitan/blob/775a486edd173ceb9be1c9b1b30af6ca2d4b4fa0/torchtitan/models/llama3/train_configs/llama3_8b.toml>`__ file or use command line flags (run the command with :code:`--help` to see more options).

You should see terminal output that looks like this:

.. code:: console

    [rank0]:[titan] 2025-06-04 08:51:48,074 - root - INFO - step: 1 loss: 12.2254 memory: 27.34GiB(28.78%) tps: 375 tflops: 21.73 mfu: 2.20%
    [rank0]:[titan] 2025-06-04 08:51:58,557 - root - INFO - step: 10 loss: 10.7069 memory: 30.99GiB(32.62%) tps: 7,034 tflops: 407.35 mfu: 41.19%
    [rank0]:[titan] 2025-06-04 08:52:10,224 - root - INFO - step: 20 loss: 8.9196 memory: 30.99GiB(32.62%) tps: 7,022 tflops: 406.65 mfu: 41.12%
    [rank0]:[titan] 2025-06-04 08:52:21,904 - root - INFO - step: 30 loss: 8.1423 memory: 30.99GiB(32.62%) tps: 7,014 tflops: 406.23 mfu: 41.08%

As you can see, ignoring the warmup steps, we are achieving around 7k TPS with 30.99GiB peak memory usage. To compare performance against bfloat16 training, you can remove the :code:`--model.converters="float8" --float8.recipe_name="rowwise"` flags
and run the same command to see the baseline performance of bfloat16 training:

.. code:: console

    NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile

You should see the following output:

.. code:: console

    [rank0]:[titan] 2025-06-04 11:02:37,404 - root - INFO - step: 1 loss: 12.2611 memory: 27.22GiB(28.65%) tps: 595 tflops: 34.47 mfu: 3.49%
    [rank0]:[titan] 2025-06-04 11:02:49,027 - root - INFO - step: 10 loss: 10.4260 memory: 30.89GiB(32.51%) tps: 6,344 tflops: 367.39 mfu: 37.15%
    [rank0]:[titan] 2025-06-04 11:03:01,988 - root - INFO - step: 20 loss: 8.9482 memory: 30.89GiB(32.51%) tps: 6,321 tflops: 366.06 mfu: 37.01%
    [rank0]:[titan] 2025-06-04 11:03:14,991 - root - INFO - step: 30 loss: 8.1183 memory: 30.89GiB(32.51%) tps: 6,300 tflops: 364.89 mfu: 36.89%
    [rank0]:[titan] 2025-06-04 11:03:28,013 - root - INFO - step: 40 loss: 7.4659 memory: 30.89GiB(32.51%) tps: 6,291 tflops: 364.36 mfu: 36.84%
    [rank0]:[titan] 2025-06-04 11:03:39,769 - root - INFO - [GC] Peforming periodical GC collection. 0.02 seconds.

As you can see, the bfloat16 baseline achieves ~6.3k TPS using 30.89GiB peak memory.

This means our float8 rowwise scaling recipe achieves **1.11x higher throughput** compared to the bfloat16 baseline, using nearly identical peak memory!

Note that you can achieve an even larger throughput improvement using the tensorwise scaling recipe, which sits at a different point on the performance vs. accuracy curve.

Tensorwise scaling
==================

Float8 training with tensorwise scaling is the default recipe, so we can omit the :code:`--float8.recipe_name` flag:

.. code:: console

    NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --model.converters="float8"

You should see output like the following:

.. code:: console

    [rank0]:[titan] 2025-06-04 10:52:19,648 - root - INFO - step: 1 loss: 12.2648 memory: 27.28GiB(28.71%) tps: 557 tflops: 32.29 mfu: 3.26%
    [rank0]:[titan] 2025-06-04 10:52:29,475 - root - INFO - step: 10 loss: 10.9106 memory: 30.91GiB(32.53%) tps: 7,503 tflops: 434.53 mfu: 43.94%
    [rank0]:[titan] 2025-06-04 10:52:40,166 - root - INFO - step: 20 loss: 9.0774 memory: 30.91GiB(32.53%) tps: 7,663 tflops: 443.78 mfu: 44.87%
    [rank0]:[titan] 2025-06-04 10:52:50,885 - root - INFO - step: 30 loss: 8.3233 memory: 30.91GiB(32.53%) tps: 7,643 tflops: 442.66 mfu: 44.76%
    [rank0]:[titan] 2025-06-04 10:53:01,613 - root - INFO - step: 40 loss: 7.6150 memory: 30.91GiB(32.53%) tps: 7,637 tflops: 442.27 mfu: 44.72%

As you can see, we are achieving ~7.6k TPS using 30.91GiB peak memory, which is **1.21x higher throughput** compared to the bfloat16 baseline!
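
To make the comparison concrete, here is a quick back-of-the-envelope check of the speedups quoted above (a minimal sketch using the approximate steady-state TPS values from the logs; exact numbers will vary by hardware and run):

.. code:: py

    # approximate steady-state throughput (tokens/sec) taken from the logs above
    bf16_tps = 6_300
    float8_rowwise_tps = 7_000
    float8_tensorwise_tps = 7_600

    print(f"rowwise speedup:    {float8_rowwise_tps / bf16_tps:.2f}x")     # ~1.11x
    print(f"tensorwise speedup: {float8_tensorwise_tps / bf16_tps:.2f}x")  # ~1.21x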

Picking a recipe
================

**TL;DR**: rowwise scaling is better for jobs that prioritize more accurate numerics and training stability, while tensorwise scaling is better for jobs that prioritize training throughput.

The higher throughput of tensorwise scaling comes at the cost of slightly higher quantization error (i.e., reduced numerical accuracy vs bfloat16) compared to rowwise scaling.
This is because rowwise scaling uses a more granular scaling factor (per row, instead of per tensor), which limits the impact of outliers that can cause other values to underflow during scaling.
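
The intuition can be seen in a small example. The following is a minimal sketch (not the torchao implementation) that quantizes a tensor containing one large outlier to float8 e4m3 using a single tensorwise scale versus per-row scales:

.. code:: py

    import torch

    FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8 e4m3

    x = torch.full((4, 8), 1e-4)  # mostly small values...
    x[0, 0] = 100.0               # ...plus one large outlier in row 0

    # tensorwise: a single scale for the whole tensor, driven by the global amax,
    # so the outlier squashes every other value toward zero
    tensor_scale = FP8_E4M3_MAX / x.abs().max()
    x_tensorwise = (x * tensor_scale).to(torch.float8_e4m3fn).to(torch.float32) / tensor_scale

    # rowwise: one scale per row, so the outlier only affects its own row
    row_scale = FP8_E4M3_MAX / x.abs().amax(dim=1, keepdim=True)
    x_rowwise = (x * row_scale).to(torch.float8_e4m3fn).to(torch.float32) / row_scale

    # reconstruction error on the rows that do not contain the outlier
    print((x[1:] - x_tensorwise[1:]).abs().max())  # large error: the small values underflow to 0
    print((x[1:] - x_rowwise[1:]).abs().max())     # ~0: each row keeps its own dynamic range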

Below you can see the loss curves comparing bfloat16, float8 tensorwise, and float8 rowwise training for Llama3 8b on 8xH100 GPUs:

.. image:: ../static/fp8-loss-curves.png
   :alt: Loss curves for training Llama3 8b on 8xH100s with torchtitan using bfloat16, float8 tensorwise, and float8 rowwise training.


Important notes
===============

* float8 training is currently only supported in torchtitan on 2+ GPUs, not single-GPU training.
* You must use :code:`--training.compile` to achieve high performance. torchao float8 training recipes are built natively on top of :code:`torch.compile`, so they work out of the box!


Pretraining with torchao directly
#################################

In this tutorial we'll pretrain a toy model using torchao APIs directly.

You can use this workflow to integrate torchao into your own custom pretraining code directly.

Prerequisites
================

1. (Recommended) Create a new virtual environment with conda or venv.
2. `Install torchao <https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation>`__.

You're now ready to integrate torchao into your training code directly!

Model conversion API
====================

The torchao API for converting your model to use float8 training is `convert_to_float8_training <https://github.com/pytorch/ao/blob/152a8e397e1383c55bf7b87a8eaa2b87ffb2c114/torchao/float8/float8_linear_utils.py#L84>`__. This API recursively converts :code:`nn.Linear` modules in your model to use `Float8Linear <https://github.com/pytorch/ao/blob/152a8e397e1383c55bf7b87a8eaa2b87ffb2c114/torchao/float8/float8_linear.py#L254>`__.
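
For example, here is a minimal sketch (assuming a CUDA device and a recent torchao install) that converts a small bfloat16 model and inspects which modules were swapped:

.. code:: py

    import torch
    from torch import nn

    from torchao.float8 import convert_to_float8_training

    model = nn.Sequential(
        nn.Linear(1024, 1024),
        nn.Linear(1024, 1024),
    ).bfloat16().cuda()

    # swaps eligible nn.Linear modules for Float8Linear in place
    convert_to_float8_training(model)

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            print(name, type(module).__name__)  # expect Float8Linear for both layers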

You can use the :code:`module_filter_fn` argument to determine which :code:`nn.Linear` layers should be swapped to use :code:`Float8Linear`.

You should refer to this `performance benchmark table <https://github.com/pytorch/ao/tree/152a8e397e1383c55bf7b87a8eaa2b87ffb2c114/torchao/float8#performance>`__ to understand
what kind of performance improvement over bfloat16 you can expect for a given GEMM size.
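
If you are not sure what GEMM shapes your model uses, the following is a minimal sketch (the :code:`print_linear_gemm_shapes` helper is hypothetical, not a torchao API) that prints the (M, K, N) shape of each :code:`nn.Linear` forward GEMM so you can compare against the benchmark table:

.. code:: py

    import torch
    from torch import nn

    def print_linear_gemm_shapes(model: nn.Module, sample_input: torch.Tensor) -> None:
        """Print the (M, K, N) GEMM shape of every nn.Linear in a single forward pass."""
        hooks = []

        def make_hook(name):
            def hook(module, inputs, output):
                # the forward GEMM is (M, K) @ (K, N): M = flattened batch/sequence size,
                # K = in_features, N = out_features
                m_dim = inputs[0].reshape(-1, module.in_features).shape[0]
                print(f"{name}: M={m_dim}, K={module.in_features}, N={module.out_features}")
            return hook

        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                hooks.append(module.register_forward_hook(make_hook(name)))
        with torch.no_grad():
            model(sample_input)
        for handle in hooks:
            handle.remove()

    # example: the toy model and input used in the snippet below
    toy = nn.Sequential(nn.Linear(2048, 4096), nn.Linear(4096, 128), nn.Linear(128, 1))
    print_linear_gemm_shapes(toy, torch.randn(4096, 2048))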

Below is a code snippet showing how to use it:

.. code:: py

    import torch
    from torch import nn
    import torch.nn.functional as F

    from torchao.float8 import convert_to_float8_training
    from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

    if not TORCH_VERSION_AT_LEAST_2_5:
        raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

    # create model and sample input
    m = nn.Sequential(
        nn.Linear(2048, 4096),
        nn.Linear(4096, 128),
        nn.Linear(128, 1),
    ).bfloat16().cuda()
    x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
    optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

    # optional: filter modules from being eligible for float8 conversion
    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # don't convert the module with fqn "1" (the second nn.Linear in this nn.Sequential)
        if fqn == "1":
            return False
        # don't convert linear modules with weight dimensions not divisible by 16
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    # convert eligible `torch.nn.Linear` modules to `Float8Linear`
    convert_to_float8_training(m, module_filter_fn=module_filter_fn)

    # enable torch.compile for competitive performance
    m = torch.compile(m)

    # toy training loop
    for _ in range(10):
        optimizer.zero_grad()
        output = m(x)
        # use fake labels for demonstration purposes
        fake_labels = torch.ones_like(output)
        loss = F.mse_loss(output, fake_labels)
        loss.backward()
        optimizer.step()

    # save a checkpoint containing the model and optimizer state
    torch.save({
        'model': m,
        'model_state_dict': m.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pth')
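
If you later want to resume from this checkpoint, one option is to rebuild the same model pipeline and load the saved state into it. The following is a minimal sketch under that assumption (it reuses the architecture and :code:`module_filter_fn` from the snippet above so the state dict keys line up):

.. code:: py

    import torch
    from torch import nn

    from torchao.float8 import convert_to_float8_training

    # same filter used during training
    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        if fqn == "1":
            return False
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    # rebuild the same architecture used during training
    m = nn.Sequential(
        nn.Linear(2048, 4096),
        nn.Linear(4096, 128),
        nn.Linear(128, 1),
    ).bfloat16().cuda()

    # re-apply the same conversion and compilation so the state dict keys match
    convert_to_float8_training(m, module_filter_fn=module_filter_fn)
    m = torch.compile(m)
    optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

    # the checkpoint stores a full module object, so weights_only=False is required
    checkpoint = torch.load('checkpoint.pth', weights_only=False)
    m.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])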

docs/static/fp8-loss-curves.png

135 KB