Commit 788e593

NVfp4
stack-info: PR: #2408, branch: drisspg/stack/78
1 parent 2898903 commit 788e593

File tree: 6 files changed, +872 −15 lines

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 111 additions & 1 deletion
@@ -25,7 +25,10 @@
     MXInferenceLinear,
     MXLinear,
 )
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+)
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -441,3 +444,110 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_inference_subclass_nvfp4(bias: bool, compile: bool):
+    """
+    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
+    """
+    m = nn.Linear(64, 256, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = NVFP4InferenceConfig()
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    SQNR_THRESHOLD = 15.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}"
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("use_gelu", [True, False])
+@pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize("compile", [False])
+@pytest.mark.parametrize("bias", [True, False])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_nvfp4_matmul_with_amax(
+    use_gelu: bool, emulate: bool, compile: bool, bias: bool
+):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    m, k, n = 64, 256, 128
+
+    # Create activation tensor
+    if use_gelu:
+        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+        A = torch.nn.functional.gelu(x)
+    else:
+        A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+    B = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
+    bias_tensor = torch.randn(n, dtype=torch.bfloat16, device="cuda") if bias else None
+
+    # Compute reference
+    C_ref = torch.matmul(A, B.t())
+    if bias:
+        C_ref = C_ref + bias_tensor
+
+    a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
+    b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+    A_nvfp4 = NVFP4Tensor.to_nvfp4(A, per_tensor_scale=a_scale)
+    B_nvfp4 = NVFP4Tensor.to_nvfp4(B, per_tensor_scale=b_scale)
+
+    if emulate:
+        # Cast back to original dtype and compute
+        A_emulated = A_nvfp4.to_dtype(A.dtype)
+        B_emulated = B_nvfp4.to_dtype(B.dtype)
+        mm = torch.compile(torch.matmul, fullgraph=True) if compile else torch.matmul
+        C_emulated = mm(A_emulated, B_emulated.t())
+        if bias:
+            C_emulated = C_emulated + bias_tensor
+        sqnr = compute_error(C_ref, C_emulated)
+    else:
+        if bias:
+            linear_fn = (
+                torch.compile(torch.nn.functional.linear, fullgraph=True)
+                if compile
+                else torch.nn.functional.linear
+            )
+            C_nvfp4 = linear_fn(A_nvfp4, B_nvfp4, bias_tensor)
+        else:
+            mm = (
+                torch.compile(torch.matmul, fullgraph=True) if compile else torch.matmul
+            )
+            C_nvfp4 = mm(A_nvfp4, B_nvfp4.t())
+        sqnr = compute_error(C_ref, C_nvfp4)
+
+    # Check quality threshold
+    SQNR_THRESHOLD = 16.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, emulate={emulate}, compile={compile}, bias={bias}"
+    )
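
Condensed from the second test above, this is roughly the direct NVFP4Tensor path it exercises; a minimal sketch that assumes an sm100+ CUDA device for the real gemm (the to_dtype round trip serves as the emulated fallback) and uses only helpers already imported in the diff.

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor, per_tensor_amax_to_scale

# bf16 operands on the GPU, shapes as in the test
A = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda")
B = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# derive per-tensor scales from the global amax, then quantize to NVFP4
a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
A_nvfp4 = NVFP4Tensor.to_nvfp4(A, per_tensor_scale=a_scale)
B_nvfp4 = NVFP4Tensor.to_nvfp4(B, per_tensor_scale=b_scale)

# hardware path (sm100+): matmul directly on the NVFP4 subclasses
C = torch.matmul(A_nvfp4, B_nvfp4.t())

# emulated path: dequantize back to bf16 and use a plain matmul
C_emulated = torch.matmul(A_nvfp4.to_dtype(A.dtype), B_nvfp4.to_dtype(B.dtype).t())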

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 55 additions & 0 deletions
@@ -14,6 +14,7 @@
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
+    F4_E2M1_MAX,
     SUPPORTED_ELEM_DTYPES,
 )
 from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
@@ -591,3 +592,57 @@ def to_f8(x):
     torch.testing.assert_close(
         data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0
     )
+
+
+@pytest.mark.parametrize(
+    "dtype,shape,use_per_tensor_scale",
+    [
+        (torch.bfloat16, (32, 64), False),
+        (torch.float32, (64, 128), False),
+        (torch.bfloat16, (128, 256), False),
+        (torch.bfloat16, (64, 128), True),
+    ],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    x = torch.randn(shape, dtype=dtype, device="cuda")
+    if use_per_tensor_scale:
+        tensor_amax = torch.max(torch.abs(x))
+        scale = per_tensor_amax_to_scale(tensor_amax)
+    else:
+        scale = None
+
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
+    x_reconstructed = x_nvfp4.to_dtype(dtype)
+
+    def assert_sqnr_gt_threshold(orig, new, threshold):
+        sqnr = compute_error(orig, new)
+        if torch.all(torch.isnan(sqnr)):
+            # if both operands are full of zeroes, sqnr is nan and this is ok
+            # test for this explicitly
+            assert torch.all(orig == 0) and torch.all(new == 0)
+        else:
+            assert sqnr >= threshold
+
+    reconstructed_amax = x_nvfp4.get_scales().view(shape[0], -1, 1) * F4_E2M1_MAX
+    max_abs = torch.amax(
+        torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1
+    ).unsqueeze(-1)
+
+    assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0)
+    assert_sqnr_gt_threshold(x, x_reconstructed, 8.0)
+
+    assert x.shape == x_reconstructed.shape, (
+        f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}"
+    )
+    assert x.dtype == x_reconstructed.dtype, (
+        f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
+    )
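
As a quick sketch of what this reconstruction test checks (assuming a CUDA device and the same helpers the test imports): quantize to NVFP4, dequantize, and compare against the original with compute_error.

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor, per_tensor_amax_to_scale
from torchao.quantization.utils import compute_error

x = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")

# optional per-tensor scale from the global amax (the use_per_tensor_scale=True case)
scale = per_tensor_amax_to_scale(torch.max(torch.abs(x)))

x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)  # NVFP4 uses 16-element blocks
x_roundtrip = x_nvfp4.to_dtype(torch.bfloat16)  # dequantize back to bf16

sqnr = compute_error(x, x_roundtrip)  # the test asserts sqnr >= 8.0 for random data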

torchao/prototype/mx_formats/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -6,7 +6,10 @@
 )

 # Note: Prototype and subject to change
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+)

 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -18,4 +21,5 @@
     "MXLinearConfig",
     "MXLinearRecipeName",
     "MXFPInferenceConfig",
+    "NVFP4InferenceConfig",
 ]

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 3 deletions
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
     elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        assert block_size in [16, 32], (
+            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
         )
-        valid_dtypes = [torch.float8_e4m3fn]
+        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
        )
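
For context, the relaxed check means the cuBLAS path now validates the NVFP4 combination in addition to the existing MXFP8 one; a small sketch, assuming _validate_gemm_kernel_choice is imported from this config module (it is an internal helper, so the import path is an assumption).

import torch
from torchao.prototype.mx_formats import MXGemmKernelChoice
from torchao.prototype.mx_formats.config import _validate_gemm_kernel_choice  # assumed import path

# existing MXFP8 case: 32-element blocks of float8_e4m3fn
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 32, torch.float8_e4m3fn)

# newly accepted NVFP4 case: 16-element blocks of float4_e2m1fn_x2
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 16, torch.float4_e2m1fn_x2)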

torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 92 additions & 10 deletions
@@ -5,12 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 import types
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 import torch

-import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.prototype.mx_formats import (
     MXGemmKernelChoice,
@@ -20,11 +19,16 @@
     _validate_gemm_kernel_choice,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor
 from torchao.quantization.quant_api import to_linear_activation_quantized
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_8,
+    is_sm_at_least_100,
+)


 # Note: This API is extra prototype and will change in the future
@@ -63,16 +67,13 @@ class MXFPInferenceConfig(AOBaseConfig):

     block_size: int = 32

-    # Dtypes for Input and Weights
+    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
     activation_dtype: torch.dtype = torch.float8_e4m3fn
     weight_dtype: torch.dtype = torch.float8_e4m3fn

     # Which kernel to run for mm
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS

-    # Set some magic perf settings
-    set_inductor_config: bool = False
-
     def __post_init__(self):
         assert self.activation_dtype == self.weight_dtype, (
             "For now - we only support matching input/weight dtypes."
@@ -115,8 +116,6 @@ def _mx_inference_linear_transform(
     # TODO Sm120 has slightly more restrictive reqs
     # TODO handle AMD
     assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"
-    if config.set_inductor_config:
-        torchao.quantization.utils.recommended_inductor_config_setter()

     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -151,7 +150,90 @@ def _mx_inference_linear_transform(
     return module


+def _get_nvfp4_dtype():
+    """Factory function for NVFP4 dtype defaults."""
+    if not TORCH_VERSION_AT_LEAST_2_8:
+        raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
+    return torch.float4_e2m1fn_x2
+
+
+@dataclass
+class NVFP4InferenceConfig(AOBaseConfig):
+    """
+    NVIDIA FP4 (NVFP4) Inference Quantization Configuration
+
+    This is a specialized configuration for NVIDIA's FP4 format with UE4M3 scales.
+    It provides defaults optimized for NVFP4:
+    - Data: float4_e2m1fn_x2
+    - Scales: float8_e4m3fn (UE4M3)
+    - Block size: 16 (required for NVFP4)
+    - CUBLAS kernel (optimized for VEC16_UE4M3)
+    """
+
+    block_size: int = 16  # NVFP4 requires block size 16
+
+    # NVFP4 uses FP4 data
+    activation_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
+    weight_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
+
+    # NVFP4 uses E4M3 scales
+    scale_dtype: torch.dtype = torch.float8_e4m3fn
+
+    # CUBLAS is preferred for NVFP4 with VEC16_UE4M3 support
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
+
+    # Matrix multiplication configuration
+    mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC
+
+    def __post_init__(self):
+        # Validate NVFP4 constraints
+        if not TORCH_VERSION_AT_LEAST_2_8:
+            raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
+
+        assert self.activation_dtype == torch.float4_e2m1fn_x2, (
+            f"NVFP4 requires activation_dtype=float4_e2m1fn_x2, got {self.activation_dtype}"
+        )
+        assert self.weight_dtype == torch.float4_e2m1fn_x2, (
+            f"NVFP4 requires weight_dtype=float4_e2m1fn_x2, got {self.weight_dtype}"
+        )
+        assert self.scale_dtype == torch.float8_e4m3fn, (
+            f"NVFP4 requires scale_dtype=float8_e4m3fn, got {self.scale_dtype}"
+        )
+        assert self.block_size == 16, (
+            f"NVFP4 requires block_size=16, got {self.block_size}"
+        )
+
+
+@register_quantize_module_handler(NVFP4InferenceConfig)
+def _nvfp4_inference_linear_transform(
+    module: torch.nn.Module, config: NVFP4InferenceConfig
+):
+    """Quantization handler for NVFP4InferenceConfig"""
+    assert is_sm_at_least_100(), "NVFP4 is only supported on sm100+ machines"
+
+    weight = module.weight
+    assert weight.dtype == torch.bfloat16, (
+        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
+    )
+
+    quantized_weight = NVFP4Tensor.to_nvfp4(
+        weight,
+        block_size=config.block_size,
+        mm_config=config.mm_config,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals(
-        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
+        [
+            MXTensor,
+            NVFP4Tensor,
+            NVFP4MMConfig,
+            MXGemmKernelChoice,
+            _input_activation_quant_func_mxfp,
+        ]
     )