This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
in native PyTorch. This is currently a prototype, and we are actively working on optimizing these workflows on NVIDIA B200 hardware.
## Overall status

| workflow | emulation | performance | accuracy |
| --- | --- | --- | --- |
| training with mxfp8 | ✅ | 🚧 [active development](https://github.com/pytorch/ao/issues/1768) | ✅ |

We plan to add the following features in the near future:
* other inference workflows such as dynamic quantization
* a unified training to inference workflow

ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.</em>

# User API
## MX training

```python
import torch
from torchao.quantization import quantize_
# NOTE: MXLinearConfig below is an example MX training config; see
# torchao.prototype.mx_formats for the full set of options in your version
from torchao.prototype.mx_formats import MXLinearConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
quantize_(m, config)

# training loop (not shown)
```
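
For completeness, here is a minimal sketch of what a training loop over the quantized model could look like. The optimizer, loss function, and random data below are illustrative assumptions and not part of the torchao API.

```python
# toy training loop: the data, target, loss, and optimizer are arbitrary choices
optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)
x = torch.randn(128, 32, device='cuda')
target = torch.randn(128, 32, device='cuda')

for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(m(x), target)
    loss.backward()
    optimizer.step()
```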

## MX inference

Note: currently only weight-only quantization is supported; weights are stored in MX and the matmul runs in high precision.

```python
import torch
from torchao.quantization import quantize_
# NOTE: MXInferenceLinearConfig below is assumed as the weight-only inference
# config; check torchao.prototype.mx_formats for the exact config name and
# options in your version
from torchao.prototype.mx_formats import MXInferenceLinearConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXInferenceLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
quantize_(m, config=config)

# do inference (not shown)
```
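
Likewise, a minimal sketch of running the quantized model; the input shape below is an arbitrary example matching the toy model above.

```python
# run the MX weight-only quantized model on example data
x = torch.randn(128, 32, device='cuda')
with torch.no_grad():
    y = m(x)
```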

## MXTensor

This provides casts between high precision and MX formats, implemented in native PyTorch. Currently
only `torch.float32` and `torch.bfloat16` are supported as high precision formats.

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
# Note: MX int8 is not implemented yet
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4

x = torch.randn(32, 32, device='cuda')

# elem_dtype can be torch.float8_e4m3fn, torch.float8_e5m2, DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4
elem_dtype = torch.float8_e4m3fn

# high precision to MX, block size defaults to 32
x_mx = MXTensor.to_mx(x, elem_dtype)

# mx back to high precision
x_hp = x_mx.to_dtype(torch.float)
```
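
As a quick numerics check, the same two calls can be used to emulate an MX matmul by quantizing both operands and dequantizing before the matmul. The shapes below are arbitrary; only `MXTensor.to_mx` and `to_dtype` from the example above are used.

```python
# emulate an MX matmul: quantize both operands, dequantize, then matmul in high precision
a = torch.randn(32, 64, device='cuda')
b = torch.randn(64, 32, device='cuda')
a_mx = MXTensor.to_mx(a, elem_dtype)
b_mx = MXTensor.to_mx(b, elem_dtype)
out_emulated = a_mx.to_dtype(torch.float) @ b_mx.to_dtype(torch.float)

# compare against the unquantized matmul to gauge quantization error
out_ref = a @ b
print((out_emulated - out_ref).abs().max())
```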

# performance