
Commit 47bcbb8

NVfp4
stack-info: PR: #2408, branch: drisspg/stack/78
1 parent 2898903 commit 47bcbb8

File tree

6 files changed: +885 −15 lines changed

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 109 additions & 1 deletion
@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -25,7 +26,11 @@
     MXInferenceLinear,
     MXLinear,
 )
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -441,3 +446,106 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize(
+    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+)
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_inference_subclass_nvfp4(bias: bool, compile: bool, mm_config: NVFP4MMConfig):
+    """
+    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
+    Tests both DYNAMIC and WEIGHT_ONLY mm_config modes
+    """
+    m = nn.Linear(64, 256, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = NVFP4InferenceConfig(mm_config=mm_config)
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
+        SQNR_THRESHOLD = 18.0
+    else:
+        SQNR_THRESHOLD = 15.0
+
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("use_gelu", [True, False])
+@pytest.mark.parametrize(
+    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+)
+@pytest.mark.parametrize("compile", [False])
+@pytest.mark.parametrize("bias", [True, False])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_nvfp4_matmul_with_amax(
+    use_gelu: bool, mm_config: NVFP4MMConfig, compile: bool, bias: bool
+):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    m, k, n = 64, 256, 128
+
+    # Create activation tensor
+    if use_gelu:
+        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+        A = torch.nn.functional.gelu(x)
+    else:
+        A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+    B = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
+    bias_tensor = torch.randn(n, dtype=torch.bfloat16, device="cuda") if bias else None
+
+    # Compute reference
+    C_ref = F.linear(A, B, bias_tensor)
+
+    a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
+    b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+    A_nvfp4 = NVFP4Tensor.to_nvfp4(
+        A,
+        per_tensor_scale=a_scale,
+        mm_config=mm_config,
+    )
+    B_nvfp4 = NVFP4Tensor.to_nvfp4(
+        B,
+        per_tensor_scale=b_scale,
+        mm_config=mm_config,
+    )
+
+    C_nvfp4 = F.linear(A_nvfp4, B_nvfp4, bias_tensor)
+    sqnr = compute_error(C_ref, C_nvfp4)
+
+    # Check quality threshold
+    SQNR_THRESHOLD = 16.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}"
+    )

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 66 additions & 0 deletions
@@ -14,6 +14,7 @@
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
+    F4_E2M1_MAX,
     SUPPORTED_ELEM_DTYPES,
 )
 from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
@@ -591,3 +592,68 @@ def to_f8(x):
     torch.testing.assert_close(
         data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0
     )
+
+
+@pytest.mark.parametrize(
+    "dtype,shape,use_per_tensor_scale",
+    [
+        (torch.bfloat16, (32, 64), False),
+        (torch.float32, (64, 128), False),
+        (torch.bfloat16, (128, 256), False),
+        (torch.bfloat16, (64, 128), True),
+    ],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    x = torch.randn(shape, dtype=dtype, device="cuda")
+    if use_per_tensor_scale:
+        tensor_amax = torch.max(torch.abs(x))
+        scale = per_tensor_amax_to_scale(tensor_amax)
+    else:
+        scale = None
+
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
+    x_reconstructed = x_nvfp4.to_dtype(dtype)
+
+    def assert_sqnr_gt_threshold(orig, new, threshold):
+        sqnr = compute_error(orig, new)
+        if torch.all(torch.isnan(sqnr)):
+            # if both operands are full of zeroes, sqnr is nan and this is ok
+            # test for this explicitly
+            assert torch.all(orig == 0) and torch.all(new == 0)
+        else:
+            assert sqnr >= threshold
+
+    reconstructed_amax = x_nvfp4.get_scales().view(shape[0], -1, 1) * F4_E2M1_MAX
+    max_abs = torch.amax(
+        torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1
+    ).unsqueeze(-1)
+
+    assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0)
+    assert_sqnr_gt_threshold(x, x_reconstructed, 8.0)
+
+    assert x.shape == x_reconstructed.shape, (
+        f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}"
+    )
+    assert x.dtype == x_reconstructed.dtype, (
+        f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
+    )
+
+    x_nvfp4_t = x_nvfp4.t()
+    x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
+    assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0)
+
+    assert x.t().shape == x_reconstructed_t.shape, (
+        f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}"
+    )
+    assert x.t().dtype == x_reconstructed_t.dtype, (
+        f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
+    )
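
For readers skimming the diff, the round trip this test exercises reduces to a few calls. The sketch below mirrors test_nvfp4_reconstruction using only names visible in this commit (to_nvfp4, to_dtype, per_tensor_amax_to_scale); it is an illustrative example, not part of the change, and leaves block_size at its default as the test does.

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor, per_tensor_amax_to_scale

x = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
# Optional per-tensor scale layered on top of the per-block E4M3 scales
scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x)))
x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
x_back = x_nvfp4.to_dtype(torch.bfloat16)  # dequantize; the test expects SQNR vs x above ~8 dB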

torchao/prototype/mx_formats/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -6,7 +6,10 @@
 )

 # Note: Prototype and subject to change
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+)

 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -18,4 +21,5 @@
     "MXLinearConfig",
     "MXLinearRecipeName",
     "MXFPInferenceConfig",
+    "NVFP4InferenceConfig",
 ]

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 3 deletions
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
     elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        assert block_size in [16, 32], (
+            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
         )
-        valid_dtypes = [torch.float8_e4m3fn]
+        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
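
In effect, the cuBLAS branch of _validate_gemm_kernel_choice now also accepts the NVFP4 combination (16-element blocks of packed fp4). A minimal sanity check of the relaxed validation, assuming the helper remains importable from torchao.prototype.mx_formats.config as the tests do for MXGemmKernelChoice, and PyTorch 2.8+ for the fp4 dtype:

import torch
from torchao.prototype.mx_formats.config import MXGemmKernelChoice, _validate_gemm_kernel_choice

# Before this change the first call asserted; the old fp8/32-block combination still validates.
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 16, torch.float4_e2m1fn_x2)
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 32, torch.float8_e4m3fn)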

torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 92 additions & 10 deletions
@@ -5,12 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 import types
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 import torch

-import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.prototype.mx_formats import (
     MXGemmKernelChoice,
@@ -20,11 +19,16 @@
     _validate_gemm_kernel_choice,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor
 from torchao.quantization.quant_api import to_linear_activation_quantized
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_8,
+    is_sm_at_least_100,
+)


 # Note: This API is extra prototype and will change in the future
@@ -63,16 +67,13 @@ class MXFPInferenceConfig(AOBaseConfig):

     block_size: int = 32

-    # Dtypes for Input and Weights
+    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
     activation_dtype: torch.dtype = torch.float8_e4m3fn
     weight_dtype: torch.dtype = torch.float8_e4m3fn

     # Which kernel to run for mm
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS

-    # Set some magic perf settings
-    set_inductor_config: bool = False
-
     def __post_init__(self):
         assert self.activation_dtype == self.weight_dtype, (
             "For now - we only support matching input/weight dtypes."
@@ -115,8 +116,6 @@ def _mx_inference_linear_transform(
     # TODO Sm120 has slightly more restrictive reqs
     # TODO handle AMD
     assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"
-    if config.set_inductor_config:
-        torchao.quantization.utils.recommended_inductor_config_setter()

     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -151,7 +150,90 @@ def _mx_inference_linear_transform(
     return module


+def _get_nvfp4_dtype():
+    """Factory function for NVFP4 dtype defaults."""
+    if not TORCH_VERSION_AT_LEAST_2_8:
+        raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
+    return torch.float4_e2m1fn_x2
+
+
+@dataclass
+class NVFP4InferenceConfig(AOBaseConfig):
+    """
+    NVIDIA FP4 (NVFP4) Inference Quantization Configuration
+
+    This is a specialized configuration for NVIDIA's FP4 format with UE4M3 scales.
+    It provides defaults optimized for NVFP4:
+    - Data: float4_e2m1fn_x2
+    - Scales: float8_e4m3fn (UE4M3)
+    - Block size: 16 (required for NVFP4)
+    - CUBLAS kernel (optimized for VEC16_UE4M3)
+    """
+
+    block_size: int = 16  # NVFP4 requires block size 16
+
+    # NVFP4 uses FP4 data
+    activation_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
+    weight_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
+
+    # NVFP4 uses E4M3 scales
+    scale_dtype: torch.dtype = torch.float8_e4m3fn
+
+    # CUBLAS is preferred for NVFP4 with VEC16_UE4M3 support
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
+
+    # Matrix multiplication configuration
+    mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC
+
+    def __post_init__(self):
+        # Validate NVFP4 constraints
+        if not TORCH_VERSION_AT_LEAST_2_8:
+            raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
+
+        assert self.activation_dtype == torch.float4_e2m1fn_x2, (
+            f"NVFP4 requires activation_dtype=float4_e2m1fn_x2, got {self.activation_dtype}"
+        )
+        assert self.weight_dtype == torch.float4_e2m1fn_x2, (
+            f"NVFP4 requires weight_dtype=float4_e2m1fn_x2, got {self.weight_dtype}"
+        )
+        assert self.scale_dtype == torch.float8_e4m3fn, (
+            f"NVFP4 requires scale_dtype=float8_e4m3fn, got {self.scale_dtype}"
+        )
+        assert self.block_size == 16, (
+            f"NVFP4 requires block_size=16, got {self.block_size}"
+        )
+
+
+@register_quantize_module_handler(NVFP4InferenceConfig)
+def _nvfp4_inference_linear_transform(
+    module: torch.nn.Module, config: NVFP4InferenceConfig
+):
+    """Quantization handler for NVFP4InferenceConfig"""
+    assert is_sm_at_least_100(), "NVFP4 is only supported on sm100+ machines"
+
+    weight = module.weight
+    assert weight.dtype == torch.bfloat16, (
+        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
+    )
+
+    quantized_weight = NVFP4Tensor.to_nvfp4(
+        weight,
+        block_size=config.block_size,
+        mm_config=config.mm_config,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals(
-        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
+        [
+            MXTensor,
+            NVFP4Tensor,
+            NVFP4MMConfig,
+            MXGemmKernelChoice,
+            _input_activation_quant_func_mxfp,
+        ]
     )
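
Taken together, the new config and its registered transform give a one-call path from a bf16 nn.Linear to NVFP4 inference. A minimal sketch mirroring test_inference_subclass_nvfp4, using the import paths from the tests in this commit (requires PyTorch 2.8+ and an sm100+ GPU); it is an example, not prescribed usage:

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.mx_subclass import NVFP4InferenceConfig, NVFP4MMConfig

model = nn.Linear(64, 256, bias=False, dtype=torch.bfloat16, device="cuda")

# DYNAMIC quantizes activations at runtime; NVFP4MMConfig.WEIGHT_ONLY keeps them in bf16
quantize_(model, config=NVFP4InferenceConfig(mm_config=NVFP4MMConfig.DYNAMIC))
model = torch.compile(model, fullgraph=True)

y = model(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))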
