
Commit 7448f45

WIP NVfp4
stack-info: PR: #2408, branch: drisspg/stack/78
1 parent 101c039 commit 7448f45

5 files changed, +434 -6 lines changed


test/prototype/mx_formats/test_mx_linear.py

Lines changed: 37 additions & 1 deletion
@@ -25,7 +25,10 @@
     MXInferenceLinear,
     MXLinear,
 )
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+)
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -441,3 +444,36 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_inference_subclass_nvfp4(bias: bool, compile: bool):
+    """
+    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
+    """
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = NVFP4InferenceConfig()
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    SQNR_THRESHOLD = 15.0  # Float4 threshold
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}"
+    )
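
Note: `compute_error` here is used as an SQNR metric in dB. As a point of reference only, a minimal sketch of the comparison this test performs is below; the formula is the usual SQNR definition and an assumption about `torchao.quantization.utils.compute_error`, not something stated in this diff.

import torch

def sqnr_db(ref: torch.Tensor, quant: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: norm of the reference
    # output over the norm of the quantization error.
    return 20 * torch.log10(
        torch.linalg.norm(ref) / torch.linalg.norm(ref - quant)
    )

# The new test asserts sqnr_db(y_ref, y_mx) >= 15.0 for the fp4 recipe.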

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 3 deletions
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
     elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        assert block_size in [16, 32], (
+            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
         )
-        valid_dtypes = [torch.float8_e4m3fn]
+        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
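
Note: a minimal sketch of how the relaxed check behaves, assuming `_validate_gemm_kernel_choice` and `MXGemmKernelChoice` are both importable from `torchao.prototype.mx_formats.config` (the calls below are illustrative and not part of this commit; `torch.float4_e2m1fn_x2` needs a recent PyTorch).

import torch
from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    _validate_gemm_kernel_choice,
)

# Previously the cuBLAS branch only accepted block_size == 32 with fp8 data;
# with this change the NVFP4 combination also passes validation.
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 16, torch.float4_e2m1fn_x2)
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 32, torch.float8_e4m3fn)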

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             w_elem_dtype,
             block_size,
             weight_hp.dtype,
+            None,  # scale_dtype
             False,
             gemm_kernel_choice,
             False,
@@ -133,6 +134,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             grad_elem_dtype,
             block_size,
             grad_output_hp_r.dtype,
+            None,  # scale_dtype
             False,
             gemm_kernel_choice,
             False,
@@ -155,6 +157,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             in_elem_dtype,
             block_size,
             input_hp_r.dtype,
+            None,  # scale_dtype
             False,
             gemm_kernel_choice,
             False,

torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 102 additions & 2 deletions
@@ -20,6 +20,7 @@
     _validate_gemm_kernel_choice,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 from torchao.quantization.quant_api import to_linear_activation_quantized
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
@@ -63,7 +64,7 @@ class MXFPInferenceConfig(AOBaseConfig):
 
     block_size: int = 32
 
-    # Dtypes for Input and Weights
+    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
     activation_dtype: torch.dtype = torch.float8_e4m3fn
     weight_dtype: torch.dtype = torch.float8_e4m3fn
 
@@ -151,7 +152,106 @@ def _mx_inference_linear_transform(
     return module
 
 
+@dataclass
+class NVFP4InferenceConfig(AOBaseConfig):
+    """
+    NVIDIA FP4 (NVFP4) Inference Quantization Configuration
+
+    This is a specialized configuration for NVIDIA's FP4 format with UE4M3 scales.
+    It provides defaults optimized for NVFP4:
+    - Data: float4_e2m1fn_x2
+    - Scales: float8_e4m3fn (UE4M3)
+    - Block size: 16 (required for NVFP4)
+    - CUBLAS kernel (optimized for VEC16_UE4M3)
+    """
+
+    block_size: int = 16  # NVFP4 requires block size 16
+
+    # NVFP4 uses FP4 data
+    activation_dtype: torch.dtype = torch.float4_e2m1fn_x2
+    weight_dtype: torch.dtype = torch.float4_e2m1fn_x2
+
+    # NVFP4 uses E4M3 scales
+    scale_dtype: torch.dtype = torch.float8_e4m3fn
+
+    # CUBLAS is preferred for NVFP4 with VEC16_UE4M3 support
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
+
+    # Set some magic perf settings
+    set_inductor_config: bool = False
+
+    def __post_init__(self):
+        # Validate NVFP4 constraints
+        assert self.activation_dtype == torch.float4_e2m1fn_x2, (
+            f"NVFP4 requires activation_dtype=float4_e2m1fn_x2, got {self.activation_dtype}"
+        )
+        assert self.weight_dtype == torch.float4_e2m1fn_x2, (
+            f"NVFP4 requires weight_dtype=float4_e2m1fn_x2, got {self.weight_dtype}"
+        )
+        assert self.scale_dtype == torch.float8_e4m3fn, (
+            f"NVFP4 requires scale_dtype=float8_e4m3fn, got {self.scale_dtype}"
+        )
+        assert self.block_size == 16, (
+            f"NVFP4 requires block_size=16, got {self.block_size}"
+        )
+
+
+def _input_activation_quant_func_nvfp4(
+    x: torch.Tensor,
+    block_size: int = 16,
+    scale: Optional[torch.Tensor] = None,
+):
+    """NVFP4-specific activation quantization function"""
+    # TODO: scale for static quant
+    activation = NVFP4Tensor.to_nvfp4(
+        x,
+        block_size=block_size,
+    )
+    return activation
+
+
+@register_quantize_module_handler(NVFP4InferenceConfig)
+def _nvfp4_inference_linear_transform(
+    module: torch.nn.Module, config: NVFP4InferenceConfig
+):
+    """Quantization handler for NVFP4InferenceConfig"""
+    assert is_sm_at_least_100(), "NVFP4 is only supported on sm100+ machines"
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    weight = module.weight
+    assert weight.dtype == torch.bfloat16, (
+        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
+    )
+
+    # Convert weight to NVFP4 Tensor
+    quantized_weight = NVFP4Tensor.to_nvfp4(
+        weight,
+        block_size=config.block_size,
+    )
+
+    input_quant_func = _input_activation_quant_func_nvfp4
+    input_quant_kwargs = {
+        "block_size": config.block_size,
+        "scale": None,
+    }
+
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals(
-        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
+        [
+            MXTensor,
+            NVFP4Tensor,
+            MXGemmKernelChoice,
+            _input_activation_quant_func_mxfp,
+            _input_activation_quant_func_nvfp4,
+        ]
     )
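
Note: putting the pieces together, a minimal end-to-end usage sketch of the new config, mirroring the test added in this commit (requires an sm100+ GPU, a recent PyTorch, and a bf16 linear layer).

import torch
from torch import nn

from torchao.prototype.mx_formats.mx_subclass import NVFP4InferenceConfig
from torchao.quantization import quantize_

# bf16 weights are required by the handler's assert; NVFP4 itself needs sm100+.
linear = nn.Linear(32, 128, bias=True, dtype=torch.bfloat16, device="cuda")

# Replaces the weight with an NVFP4Tensor wrapped by to_linear_activation_quantized,
# so activations are quantized to NVFP4 on the fly at inference time.
quantize_(linear, config=NVFP4InferenceConfig())

x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = linear(x)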
