
Commit c58c5b0

WIP NVfp4
1 parent 101c039 commit c58c5b0

File tree

2 files changed: +73 -3 lines changed


test/prototype/mx_formats/test_mx_linear.py

Lines changed: 39 additions & 0 deletions
@@ -441,3 +441,42 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm requires gfx950")
+def test_inference_subclass_nvfp4(bias: bool, compile: bool):
+    """
+    Test the NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16.
+    """
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = MXFPInferenceConfig(
+        activation_dtype=torch.float4_e2m1fn_x2,
+        weight_dtype=torch.float4_e2m1fn_x2,
+        scale_dtype=torch.float8_e4m3fn,  # NVFP4 scale dtype
+        block_size=16,  # NVFP4 block size
+        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+    )
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    SQNR_THRESHOLD = 15.0  # Float4 threshold
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}"
+    )
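For intuition about what this recipe changes relative to the MX path, here is a minimal sketch (not part of this commit) of how an NVFP4-style per-block scale could be derived: 16-element blocks with one float8_e4m3fn scale per block, versus the 32-element blocks with power-of-two float8_e8m0fnu scales used by the MX formats. The constant F4_E2M1_MAX, the clamp, and the block layout below are assumptions for illustration only.

import torch

F4_E2M1_MAX = 6.0      # largest magnitude representable in float4 e2m1
NVFP4_BLOCK_SIZE = 16  # NVFP4 blocks are 16 elements (MXFP uses 32)

x = torch.randn(4, 64, dtype=torch.bfloat16)
blocks = x.reshape(-1, NVFP4_BLOCK_SIZE).to(torch.float32)

# One scale per 16-element block, stored as float8_e4m3fn rather than the
# power-of-two float8_e8m0fnu scale used by the MX formats.
scale = (blocks.abs().amax(dim=1, keepdim=True) / F4_E2M1_MAX).clamp(min=1e-6)
scale_e4m3 = scale.to(torch.float8_e4m3fn)

# Scaled values that would then be cast to float4 e2m1 element data.
scaled = blocks / scale_e4m3.to(torch.float32)
print(scale_e4m3.shape, scaled.shape)  # torch.Size([16, 1]) torch.Size([16, 16])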

torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 34 additions & 3 deletions
@@ -6,7 +6,7 @@
 
 import types
 from dataclasses import dataclass
-from typing import Optional
+from typing import Literal, Optional, Union
 
 import torch
 
@@ -27,6 +27,30 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
 
 
+def _validate_scale_dtype(
+    block_size: int,
+    weight_dtype: torch.dtype,
+    activation_dtype: torch.dtype,
+    scale_dtype: torch.dtype,
+):
+    """Validate that the scale dtype is one of the supported float8 types."""
+    assert scale_dtype in [
+        torch.float8_e8m0fnu,
+        torch.float8_e4m3fn,
+    ], f"Unsupported scale_dtype {scale_dtype}, must be float8_e8m0fnu or float8_e4m3fn"
+    if scale_dtype == torch.float8_e8m0fnu:
+        _validate_elem_dtype(weight_dtype)
+        _validate_elem_dtype(activation_dtype)
+        return
+
+    assert (
+        weight_dtype == activation_dtype and weight_dtype == torch.float4_e2m1fn_x2
+    ), (
+        f"scale_dtype {scale_dtype} is only supported with float4_e2m1fn_x2 weights and activations, got weight_dtype {weight_dtype} and activation_dtype {activation_dtype}"
+    )
+    assert block_size == 16, f"For NVFP4, block_size must be 16, got {block_size}"
+
+
 # Note: This API is extra prototype and will change in the future
 @dataclass
 class MXFPInferenceConfig(AOBaseConfig):
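A minimal sketch (not part of the commit) of how the new validator treats the two recipes; it assumes the dtypes and the _validate_elem_dtype helper above are in scope, and torch.float4_e2m1fn_x2 requires a recent PyTorch build.

import torch

# MX recipe: e8m0 scales only require the element dtypes to be valid MX dtypes.
_validate_scale_dtype(32, torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e8m0fnu)

# NVFP4 recipe: e4m3 scales require float4 elements and a block size of 16.
_validate_scale_dtype(
    16, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.float8_e4m3fn
)

# Mixing e4m3 scales with fp8 elements (or block_size != 16) raises AssertionError.
# _validate_scale_dtype(32, torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn)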
@@ -61,12 +85,16 @@ class MXFPInferenceConfig(AOBaseConfig):
     - MXTensor in torchao.prototype.mx_formats.mx_tensor
     """
 
-    block_size: int = 32
+    block_size: Union[Literal[32], Literal[16]] = 32
 
-    # Dtypes for Input and Weights
+    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
     activation_dtype: torch.dtype = torch.float8_e4m3fn
     weight_dtype: torch.dtype = torch.float8_e4m3fn
 
+    # Supports float8_e4m3fn, float8_e8m0fnu
+    # e8m0 for MX and e4m3 for NVFP4 on CUDA-compatible devices
+    scale_dtype: torch.dtype = torch.float8_e8m0fnu
+
     # Which kernel to run for mm
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
 
@@ -82,6 +110,9 @@ def __post_init__(self):
         _validate_gemm_kernel_choice(
             self.gemm_kernel_choice, self.block_size, self.weight_dtype
         )
+        _validate_scale_dtype(
+            self.block_size, self.weight_dtype, self.activation_dtype, self.scale_dtype
+        )
 
 
 def _linear_extra_repr(self):
