Skip to content

Commit 1c007a4

Browse files
committed
WIP NVfp4
stack-info: PR: #2408, branch: drisspg/stack/78
1 parent 101c039 commit 1c007a4

File tree

6 files changed

+98
-10
lines changed

6 files changed

+98
-10
lines changed

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,42 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
441441
assert sqnr >= SQNR_THRESHOLD, (
442442
f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
443443
)
444+
445+
446+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
447+
@pytest.mark.skipif(
448+
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
449+
)
450+
@pytest.mark.skipif(
451+
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
452+
)
453+
@pytest.mark.parametrize("bias", [True, False])
454+
@pytest.mark.parametrize("compile", [True, False])
455+
@torch.no_grad()
456+
@skip_if_rocm("ROCm float4 gemm require gfx950")
457+
def test_inference_subclass_nvfp4(bias: bool, compile: bool):
458+
"""
459+
Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
460+
"""
461+
m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
462+
m_mx = copy.deepcopy(m)
463+
464+
config = MXFPInferenceConfig(
465+
activation_dtype=torch.float4_e2m1fn_x2,
466+
weight_dtype=torch.float4_e2m1fn_x2,
467+
scale_dtype=torch.float8_e4m3fn,
468+
block_size=16,
469+
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
470+
)
471+
quantize_(m_mx, config=config)
472+
if compile:
473+
m_mx = torch.compile(m_mx, fullgraph=True)
474+
475+
x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
476+
y_ref = m(x)
477+
y_mx = m_mx(x)
478+
sqnr = compute_error(y_ref, y_mx)
479+
SQNR_THRESHOLD = 15.0 # Float4 threshold
480+
assert sqnr >= SQNR_THRESHOLD, (
481+
f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}"
482+
)

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
5757
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
5858
)
5959
elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
60-
assert block_size == 32, (
61-
f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
60+
assert block_size in [16, 32], (
61+
f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
6262
)
63-
valid_dtypes = [torch.float8_e4m3fn]
63+
valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
6464
assert elem_dtype in valid_dtypes, (
6565
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
6666
)

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
104104
w_elem_dtype,
105105
block_size,
106106
weight_hp.dtype,
107+
None, # scale_dtype
107108
False,
108109
gemm_kernel_choice,
109110
False,
@@ -133,6 +134,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
133134
grad_elem_dtype,
134135
block_size,
135136
grad_output_hp_r.dtype,
137+
None, # scale_dtype
136138
False,
137139
gemm_kernel_choice,
138140
False,
@@ -155,6 +157,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
155157
in_elem_dtype,
156158
block_size,
157159
input_hp_r.dtype,
160+
None, # scale_dtype
158161
False,
159162
gemm_kernel_choice,
160163
False,

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def _addmm_mx_dispatch(
9393
M, K, N = a.shape[0], a.shape[1], b.shape[1]
9494
assert a._data.is_contiguous()
9595
assert b._data.t().is_contiguous()
96-
assert a._block_size == 32, f"Invalid block size {a._block_size}"
97-
assert b._block_size == 32, f"Invalid block size {b._block_size}"
96+
assert a._block_size in [16, 32], f"Invalid block size {a._block_size}"
97+
assert b._block_size in [16, 32], f"Invalid block size {b._block_size}"
9898

9999
a_scale = a._scale_e8m0.view(M, K // a._block_size)
100100
b_scale = b._scale_e8m0.view(N, K // b._block_size)
@@ -176,6 +176,7 @@ def mx_t(func, types, args, kwargs):
176176
old._elem_dtype,
177177
old._block_size,
178178
old._orig_dtype,
179+
old._scale_dtype,
179180
old._use_fp4_custom_triton_dequant_kernel,
180181
old._gemm_kernel_choice,
181182
old._pack_fp6,
@@ -220,6 +221,7 @@ def mx_view_op(func, types, args, kwargs):
220221
args[0]._elem_dtype,
221222
args[0]._block_size,
222223
args[0]._orig_dtype,
224+
args[0]._scale_dtype,
223225
args[0]._use_fp4_custom_triton_dequant_kernel,
224226
args[0]._gemm_kernel_choice,
225227
args[0]._pack_fp6,
@@ -281,6 +283,7 @@ def mx_slice(func, types, args, kwargs):
281283
x._elem_dtype,
282284
x._block_size,
283285
x._orig_dtype,
286+
x._scale_dtype,
284287
x._use_fp4_custom_triton_dequant_kernel,
285288
x._gemm_kernel_choice,
286289
x._pack_fp6,

torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import types
88
from dataclasses import dataclass
9-
from typing import Optional
9+
from typing import Literal, Optional, Union
1010

1111
import torch
1212

@@ -27,6 +27,30 @@
2727
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
2828

2929

30+
def _validate_scale_dtype(
31+
block_size: int,
32+
weight_dtype: torch.dtype,
33+
activation_dtype: torch.dtype,
34+
scale_dtype: torch.dtype,
35+
):
36+
"""Validate that the scale dtype is one of the supported float8 types."""
37+
assert scale_dtype in [
38+
torch.float8_e8m0fnu,
39+
torch.float8_e4m3fn,
40+
], f"Unsupported scale_dtype {scale_dtype}, must be float8_e8m0fnu or float8_e4m3fn"
41+
if scale_dtype == torch.float8_e8m0fnu:
42+
_validate_elem_dtype(weight_dtype)
43+
_validate_elem_dtype(activation_dtype)
44+
return
45+
46+
assert (
47+
weight_dtype == activation_dtype and weight_dtype == torch.float4_e2m1fn_x2
48+
), (
49+
f"scale_dtype {scale_dtype} is only supported with weight_dtype {weight_dtype} and activation_dtype {activation_dtype}, got weight_dtype {weight_dtype} and activation_dtype {activation_dtype}"
50+
)
51+
assert block_size == 16, f"For NVFP4, block_size must be 16, got {block_size}"
52+
53+
3054
# Note: This API is extra prototype and will change in the future
3155
@dataclass
3256
class MXFPInferenceConfig(AOBaseConfig):
@@ -61,12 +85,16 @@ class MXFPInferenceConfig(AOBaseConfig):
6185
- MXTensor in torchao.prototype.mx_formats.mx_tensor
6286
"""
6387

64-
block_size: int = 32
88+
block_size: Union[Literal[32], Literal[16]] = 32
6589

66-
# Dtypes for Input and Weights
90+
# Dtypes for Input and Weights, supports Fp8 and Fp4 formats
6791
activation_dtype: torch.dtype = torch.float8_e4m3fn
6892
weight_dtype: torch.dtype = torch.float8_e4m3fn
6993

94+
# Supports float8_e4m3fn, float8_e8m0fnu
95+
# e8m0 for MX and e4m3 for NVFP4 on Cuda compatable devices
96+
scale_dtype: torch.dtype = torch.float8_e8m0fnu
97+
7098
# Which kernel to run for mm
7199
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
72100

@@ -82,6 +110,9 @@ def __post_init__(self):
82110
_validate_gemm_kernel_choice(
83111
self.gemm_kernel_choice, self.block_size, self.weight_dtype
84112
)
113+
_validate_scale_dtype(
114+
self.block_size, self.weight_dtype, self.activation_dtype, self.scale_dtype
115+
)
85116

86117

87118
def _linear_extra_repr(self):
@@ -92,6 +123,7 @@ def _input_activation_quant_func_mxfp(
92123
x: torch.Tensor,
93124
activation_dtype: torch.dtype,
94125
block_size: int,
126+
scale_dtype: Optional[torch.dtype] = None,
95127
scale: Optional[torch.Tensor] = None,
96128
):
97129
""" """
@@ -102,6 +134,7 @@ def _input_activation_quant_func_mxfp(
102134
x,
103135
activation_dtype,
104136
block_size=block_size,
137+
scale_dtype=scale_dtype,
105138
gemm_kernel_choice=None, # Get from weight
106139
pack_fp6=False, # TODO
107140
)
@@ -131,6 +164,7 @@ def _mx_inference_linear_transform(
131164
weight,
132165
weight_dtype,
133166
block_size=config.block_size,
167+
scale_dtype=config.scale_dtype,
134168
gemm_kernel_choice=config.gemm_kernel_choice,
135169
pack_fp6=False, # TODO
136170
)
@@ -139,6 +173,7 @@ def _mx_inference_linear_transform(
139173
input_quant_kwargs = {
140174
"block_size": config.block_size,
141175
"activation_dtype": activation_dtype,
176+
"scale_dtype": config.scale_dtype,
142177
"scale": None,
143178
}
144179

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
from enum import Enum, auto
21-
from typing import Callable, Dict, Union
21+
from typing import Callable, Dict, Optional, Union
2222

2323
import torch
2424

@@ -146,6 +146,7 @@ def to_mx(
146146
data_hp: torch.Tensor,
147147
elem_dtype: Union[torch.dtype, str],
148148
block_size: int,
149+
scale_dtype: Optional[torch.dtype] = None,
149150
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
150151
pack_fp6: bool = False,
151152
):
@@ -473,6 +474,7 @@ def __new__(
473474
elem_dtype,
474475
block_size,
475476
orig_dtype,
477+
scale_dtype,
476478
use_fp4_custom_triton_dequant_kernel,
477479
gemm_kernel_choice,
478480
pack_fp6,
@@ -544,6 +546,7 @@ def __new__(
544546
self._elem_dtype = elem_dtype
545547
self._block_size = block_size
546548
self._orig_dtype = orig_dtype
549+
self._scale_dtype = scale_dtype
547550
self._use_fp4_custom_triton_dequant_kernel = (
548551
use_fp4_custom_triton_dequant_kernel
549552
)
@@ -589,20 +592,22 @@ def to_mx(
589592
data_hp: torch.Tensor,
590593
elem_dtype: Union[torch.dtype, str],
591594
block_size: int = BLOCK_SIZE_DEFAULT,
595+
scale_dtype: Optional[torch.dtype] = None,
592596
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
593597
use_fp4_custom_triton_dequant_kernel: bool = False,
594598
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
595599
pack_fp6: bool = False,
596600
):
597601
scale_e8m0_biased, data_lp = to_mx(
598-
data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
602+
data_hp, elem_dtype, block_size, scale_dtype, scaling_mode, pack_fp6
599603
)
600604
return MXTensor(
601605
scale_e8m0_biased,
602606
data_lp,
603607
elem_dtype,
604608
block_size,
605609
data_hp.dtype,
610+
scale_dtype,
606611
use_fp4_custom_triton_dequant_kernel,
607612
gemm_kernel_choice,
608613
pack_fp6,
@@ -613,6 +618,7 @@ def __tensor_flatten__(self):
613618
"_elem_dtype": self._elem_dtype,
614619
"_block_size": self._block_size,
615620
"_orig_dtype": self._orig_dtype,
621+
"_scale_dtype": self._scale_dtype,
616622
"_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
617623
"_gemm_kernel_choice": self._gemm_kernel_choice,
618624
"_pack_fp6": self._pack_fp6,
@@ -632,6 +638,7 @@ def __tensor_unflatten__(
632638
metadata["_elem_dtype"],
633639
metadata["_block_size"],
634640
metadata["_orig_dtype"],
641+
metadata["_scale_dtype"],
635642
metadata["_use_fp4_custom_triton_dequant_kernel"],
636643
metadata["_gemm_kernel_choice"],
637644
metadata["_pack_fp6"],
@@ -664,6 +671,7 @@ def _same_metadata(cls, self: "MXTensor", src: "MXTensor") -> bool:
664671
and self._elem_dtype == src._elem_dtype
665672
and self._block_size == src._block_size
666673
and self._orig_dtype == src._orig_dtype
674+
and self._scale_dtype == src._scale_dtype
667675
and self._use_fp4_custom_triton_dequant_kernel
668676
== src._use_fp4_custom_triton_dequant_kernel
669677
and self._gemm_kernel_choice == src._gemm_kernel_choice

0 commit comments

Comments
 (0)