
Commit 934d11e

refresh MX readme (#1989)
1 parent a010e62 commit 934d11e

1 file changed: 75 additions, 58 deletions
# MX training and inference with native PyTorch

This is a workflow for end-to-end training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
in native PyTorch. This workflow is currently a prototype, and we are actively working on optimizing it on NVIDIA B200 hardware.

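For intuition, here is a rough, self-contained emulation of a single MX block per the OCP spec: 32 contiguous elements share one power-of-two scale, and each element is stored in a narrow dtype such as `torch.float8_e4m3fn`. This sketch is illustrative only; the real implementation lives in `MXTensor` (see the User API section below).

```python
import torch

# illustrative-only emulation of one mxfp8 (e4m3) block per the OCP MX spec:
# 32 contiguous elements share a single power-of-two scale
block = torch.randn(32)

# shared scale: 2^(floor(log2(amax)) - emax_elem), where emax_elem = 8 for e4m3 (max normal 448)
amax = block.abs().max()
scale = torch.pow(2.0, torch.floor(torch.log2(amax)) - 8)

# elements are stored in fp8 e4m3; clamp to the representable range before casting
elems_fp8 = (block / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

# dequantize: element values times the shared block scale
block_dq = elems_fp8.to(torch.float32) * scale
print("max abs quantization error:", (block - block_dq).abs().max().item())
```
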
## Overall status

| workflow | emulation | performance | accuracy |
| --- | --- | --- | --- |
| training with mxfp8 | ✅ | 🚧 [active development](https://github.com/pytorch/ao/issues/1768) | ✅ |
| inference (weight-only) with mxfp8, mxfp6, mxfp4 | ✅ | 🔲 | 🔲 |

We plan to add the following features in the near future:
* other inference workflows such as dynamic quantization
* a unified training to inference workflow

ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.</em>

# User API

## MX training

```python
import torch
# ... (the middle of this example is elided in this diff) ...
quantize_(m, config)

# training loop (not shown)
```

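The training example above is only partially visible in this diff. As a rough sketch of what the full flow looks like, assuming the `MXLinearConfig` config class from the torchao MX prototype (the exact import path and arguments may differ; check the repo):

```python
import torch
from torchao.quantization import quantize_
# assumption: config class and import path from the torchao MX prototype; verify against the repo
from torchao.prototype.mx_formats import MXLinearConfig

# a toy model whose nn.Linear layers will be swapped for MX training linears
m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()

# mxfp8 with the spec-default block size of 32
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
quantize_(m, config)

# training loop (not shown)
```
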
## MX inference

Note: currently only weight-only quantization is supported (weights are stored in MX and the matmul runs in high precision).

```python
import torch
# ... (the middle of this example is elided in this diff) ...
quantize_(m, config=config)

# do inference (not shown)
```

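The inference example is likewise only partially visible here. As a minimal illustration of the weight-only idea (weight stored in MX, matmul in high precision) using the `MXTensor` API described in the next section; this is a sketch of the concept, not necessarily how `quantize_` implements it:

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

# weight-only: store the weight in mxfp8, dequantize to bf16 at matmul time
w = torch.randn(32, 32, device='cuda', dtype=torch.bfloat16)
x = torch.randn(8, 32, device='cuda', dtype=torch.bfloat16)

w_mx = MXTensor.to_mx(w, torch.float8_e4m3fn)  # weight held in MX format
y = x @ w_mx.to_dtype(torch.bfloat16).t()      # matmul in high precision
```
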
## MXTensor

This class implements casts between high precision formats and MX formats in native PyTorch. Currently
only `torch.float32` and `torch.bfloat16` are supported as high precision formats.

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
# Note: MX int8 is not implemented yet
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4

x = torch.randn(32, 32, device='cuda')

# elem_dtype can be torch.float8_e4m3fn, torch.float8_e5m2, DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4
elem_dtype = torch.float8_e4m3fn

# high precision to MX, block size defaults to 32
x_mx = MXTensor.to_mx(x, elem_dtype)

# mx back to high precision
x_hp = x_mx.to_dtype(torch.float)
```

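As a quick sanity check on the cast above (not part of the original snippet), you can compare the round-tripped tensor against the original:

```python
# continuing from the snippet above: error introduced by the mxfp8 round trip
max_abs_err = (x - x_hp).abs().max()
print(f"max abs round-trip error: {max_abs_err.item():.4f}")
```
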
# performance

## mxfp8 gemm

On NVIDIA B200 machines, we use the cuBLAS mxfp8 gemm exposed via the `torch._scaled_mm` op.
We observe a speedup of **2x to 3x** vs the bf16 baseline on common shapes. To reproduce this
on supported hardware, you can run the following command:

```bash
# benchmark the cuBLAS mxfp8 gemm vs the bf16 baseline
python benchmarks/float8/bench_matmul.py --recipe mxfp8_cublas
# example output: https://gist.github.com/vkuzo/a1ddb782e6e1c2aef0c726b3df99efbc
```

## to_mx cast across dim0 and dim1

On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile),
and **up to 3.9 TB/s** for the dim1 cast (with a triton kernel). We are actively working on improving
the performance of this cast ([details](https://github.com/pytorch/ao/issues/1768)).

To reproduce this on supported hardware, you can run the following commands:

```bash
# dim0 cast with torch.compile
python benchmarks/mx_formats/cast_bench.py --mode dim0_mx --M 16384 --K 16384
# example output: https://gist.github.com/vkuzo/06aae58de9b8aae02c82adb00eb33197

# dim1 cast with a handwritten triton kernel
python benchmarks/mx_formats/cast_bench.py --mode dim1_mx_triton --M 16384 --K 16384
# example output: https://gist.github.com/vkuzo/7ac5fce44c9b90bfb9eae2a07b721cda
```

## performance tracker

Please see our [performance tracker](https://github.com/pytorch/ao/issues/1768) for the latest on MX training and inference performance!

# accuracy

## training

* LLaMa 3 8B pretraining on 4 GPUs for 500 iterations shows that loss convergence is not meaningfully degraded (code not in this repo)
* we match bitwise to other implementations of the OCP MX spec (code not in this repo), with a couple of edge cases left to resolve

## inference

Coming soon!

# testing

```bash
pytest test/prototype/mx_formats/
```
