Commit 4fe3daf

NVfp4

stack-info: PR: #2408, branch: drisspg/stack/78

1 parent 4e25496 commit 4fe3daf

File tree

6 files changed

+894 -14 lines changed

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 137 additions & 1 deletion

@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -25,7 +26,11 @@
     MXInferenceLinear,
     MXLinear,
 )
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -404,6 +409,7 @@ def test_inference_print_str():
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
+@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required")
 def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     """
     Smoke test for inference compile
@@ -441,3 +447,133 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize(
+    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+)
+@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_inference_subclass_nvfp4(
+    bias: bool, compile: bool, mm_config: NVFP4MMConfig, inpt_dtype: torch.dtype
+):
+    """
+    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
+    Tests both DYNAMIC and WEIGHT_ONLY mm_config modes
+    """
+    # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
+    if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
+        pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")
+
+    if bias and inpt_dtype == torch.float32:
+        pytest.xfail("Bias is not supported when module weight is in fp32")
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
+        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+    m = nn.Linear(64, 256, bias=bias, dtype=inpt_dtype, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = NVFP4InferenceConfig(mm_config=mm_config)
+    quantize_(m_mx, config=config)
+
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")
+
+    x = torch.randn(128, 64, device="cuda", dtype=inpt_dtype)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
+        SQNR_THRESHOLD = 18.0
+    else:
+        SQNR_THRESHOLD = 15.0
+
+    assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}"
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("use_gelu", [True, False])
+@pytest.mark.parametrize(
+    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+)
+@pytest.mark.parametrize("compile", [False])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_nvfp4_matmul_with_amax(
+    use_gelu: bool,
+    mm_config: NVFP4MMConfig,
+    compile: bool,
+    bias: bool,
+    inpt_dtype: torch.dtype,
+):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
+    if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
+        pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")
+
+    if bias and inpt_dtype == torch.float32:
+        pytest.xfail("Bias is not supported when module weight is in fp32")
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
+        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+
+    m, k, n = 64, 256, 128
+
+    # Create activation tensor
+    if use_gelu:
+        x = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
+        A = torch.nn.functional.gelu(x)
+    else:
+        A = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
+
+    B = torch.randn(n, k, dtype=inpt_dtype, device="cuda")
+    bias_tensor = torch.randn(n, dtype=inpt_dtype, device="cuda") if bias else None
+
+    # Compute reference
+    C_ref = F.linear(A, B, bias_tensor)
+
+    a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
+    b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+    A_nvfp4 = NVFP4Tensor.to_nvfp4(
+        A,
+        per_tensor_scale=a_scale,
+        mm_config=mm_config,
+    )
+    B_nvfp4 = NVFP4Tensor.to_nvfp4(
+        B,
+        per_tensor_scale=b_scale,
+        mm_config=mm_config,
+    )
+
+    func = torch.compile(F.linear, fullgraph=True) if compile else F.linear
+
+    C_nvfp4 = func(A_nvfp4, B_nvfp4, bias_tensor)
+    assert C_nvfp4.dtype == inpt_dtype, (
+        f"Got {C_nvfp4.dtype} for inpt_dtype={inpt_dtype}"
+    )
+
+    sqnr = compute_error(C_ref, C_nvfp4)
+    SQNR_THRESHOLD = 16.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}"
+    )
test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 66 additions & 0 deletions

@@ -14,6 +14,7 @@
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
+    F4_E2M1_MAX,
     SUPPORTED_ELEM_DTYPES,
 )
 from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
@@ -591,3 +592,68 @@ def to_f8(x):
     torch.testing.assert_close(
         data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0
     )
+
+
+@pytest.mark.parametrize(
+    "dtype,shape,use_per_tensor_scale",
+    [
+        (torch.bfloat16, (32, 64), False),
+        (torch.float32, (64, 128), False),
+        (torch.bfloat16, (128, 256), False),
+        (torch.bfloat16, (64, 128), True),
+    ],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    x = torch.randn(shape, dtype=dtype, device="cuda")
+    if use_per_tensor_scale:
+        tensor_amax = torch.max(torch.abs(x))
+        scale = per_tensor_amax_to_scale(tensor_amax)
+    else:
+        scale = None
+
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
+    x_reconstructed = x_nvfp4.to_dtype(dtype)
+
+    def assert_sqnr_gt_threshold(orig, new, threshold):
+        sqnr = compute_error(orig, new)
+        if torch.all(torch.isnan(sqnr)):
+            # if both operands are full of zeroes, sqnr is nan and this is ok
+            # test for this explicitly
+            assert torch.all(orig == 0) and torch.all(new == 0)
+        else:
+            assert sqnr >= threshold
+
+    reconstructed_amax = x_nvfp4.get_hp_scales().view(shape[0], -1, 1) * F4_E2M1_MAX
+    max_abs = torch.amax(
+        torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1
+    ).unsqueeze(-1)
+
+    assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0)
+    assert_sqnr_gt_threshold(x, x_reconstructed, 8.0)
+
+    assert x.shape == x_reconstructed.shape, (
+        f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}"
+    )
+    assert x.dtype == x_reconstructed.dtype, (
+        f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
+    )
+
+    x_nvfp4_t = x_nvfp4.t()
+    x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
+    assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0)
+
+    assert x.t().shape == x_reconstructed_t.shape, (
+        f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}"
+    )
+    assert x.t().dtype == x_reconstructed_t.dtype, (
+        f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
+    )
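
Outside of pytest, the same round trip can be exercised directly on the tensor subclass. A minimal sketch, assuming a CUDA device and the prototype nvfp4_tensor module added in this commit (only APIs used in the test above):

    import torch

    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        per_tensor_amax_to_scale,
    )

    # Quantize a bf16 tensor to NVFP4 (block size 16, fp8_e4m3fn block scales),
    # optionally with a per-tensor scale derived from the tensor's amax, then
    # dequantize and inspect the reconstruction error.
    x = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
    scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x)))
    x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
    x_hat = x_nvfp4.to_dtype(torch.bfloat16)
    print((x - x_hat).abs().max())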

torchao/prototype/mx_formats/__init__.py

Lines changed: 7 additions & 1 deletion

@@ -6,7 +6,11 @@
 )
 
 # Note: Prototype and subject to change
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 
 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -18,4 +22,6 @@
     "MXLinearConfig",
     "MXLinearRecipeName",
     "MXFPInferenceConfig",
+    "NVFP4InferenceConfig",
+    "NVFP4MMConfig",
 ]

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 3 deletions

@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
     elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        assert block_size in [16, 32], (
+            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
         )
-        valid_dtypes = [torch.float8_e4m3fn]
+        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )

torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 64 additions & 9 deletions

@@ -10,7 +10,6 @@
 
 import torch
 
-import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.prototype.mx_formats import (
     MXGemmKernelChoice,
@@ -20,13 +19,19 @@
     _validate_gemm_kernel_choice,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor
 from torchao.quantization.quant_api import to_linear_activation_quantized
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_8,
+    is_sm_at_least_100,
+)
 
 
+# TODO The naming for these configs is a little weird, rename before moving to public API
 # Note: This API is extra prototype and will change in the future
 @dataclass
 class MXFPInferenceConfig(AOBaseConfig):
@@ -63,16 +68,13 @@ class MXFPInferenceConfig(AOBaseConfig):
 
     block_size: int = 32
 
-    # Dtypes for Input and Weights
+    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
     activation_dtype: torch.dtype = torch.float8_e4m3fn
     weight_dtype: torch.dtype = torch.float8_e4m3fn
 
     # Which kernel to run for mm
    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
 
-    # Set some magic perf settings
-    set_inductor_config: bool = False
-
     def __post_init__(self):
         assert self.activation_dtype == self.weight_dtype, (
             "For now - we only support matching input/weight dtypes."
@@ -115,8 +117,6 @@ def _mx_inference_linear_transform(
     # TODO Sm120 has slightly more restrictive reqs
     # TODO handle AMD
     assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"
-    if config.set_inductor_config:
-        torchao.quantization.utils.recommended_inductor_config_setter()
 
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -151,7 +151,62 @@ def _mx_inference_linear_transform(
     return module
 
 
+@dataclass
+class NVFP4InferenceConfig(AOBaseConfig):
+    """
+    NVIDIA FP4 (NVFP4) Inference Quantization Configuration
+
+    This is a specialized configuration for NVIDIA's FP4 format.
+    All parameters are fixed in the NVFP4 implementation except mm_config:
+    - mm_config: NVFP4MMConfig, which can be set to DYNAMIC or WEIGHT_ONLY (emulated mm in high precision)
+    - Data: float4_e2m1fn_x2
+    - Scales: float8_e4m3fn
+    - Block size: 16 along the reduction dim
+    """
+
+    mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC
+
+    def __post_init__(self):
+        # Validate PyTorch version
+        if not TORCH_VERSION_AT_LEAST_2_8:
+            raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
+
+
+@register_quantize_module_handler(NVFP4InferenceConfig)
+def _nvfp4_inference_linear_transform(
+    module: torch.nn.Linear, config: NVFP4InferenceConfig
+):
+    """Quantization handler for NVFP4InferenceConfig"""
+    if config.mm_config == NVFP4MMConfig.DYNAMIC:
+        assert is_sm_at_least_100(), (
+            "NVFP4 DYNAMIC mode is only supported on sm100+ machines"
+        )
+
+    weight = module.weight
+
+    if module.bias is not None and weight.dtype == torch.float32:
+        raise RuntimeError(
+            "Bias is not supported when module weight is in fp32 (out_dtype=Float32). "
+            "Please use bfloat16 or float16 weights, or remove the bias from the linear layer."
+        )
+
+    quantized_weight = NVFP4Tensor.to_nvfp4(
+        weight,
+        mm_config=config.mm_config,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals(
-        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
+        [
+            MXTensor,
+            NVFP4Tensor,
+            NVFP4MMConfig,
+            MXGemmKernelChoice,
+            _input_activation_quant_func_mxfp,
+        ]
     )
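
Putting the pieces together, the new config is applied through the existing quantize_ entry point. A minimal usage sketch, assuming a CUDA device (sm100+ for NVFP4MMConfig.DYNAMIC) and PyTorch 2.8+, mirroring test_inference_subclass_nvfp4 above:

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig
    from torchao.quantization import quantize_

    # Replace the linear weight with an NVFP4Tensor; per the config docstring,
    # DYNAMIC runs the NVFP4 mm path (requires sm100+), while WEIGHT_ONLY keeps
    # the mm emulated in high precision.
    m = nn.Linear(64, 256, bias=False, dtype=torch.bfloat16, device="cuda")
    quantize_(m, config=NVFP4InferenceConfig(mm_config=NVFP4MMConfig.DYNAMIC))

    x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
    y = m(x)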
