Commit d85d39a

NVfp4

stack-info: PR: #2408, branch: drisspg/stack/78

1 parent 101c039 commit d85d39a

6 files changed, +805 -9 lines changed

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 93 additions & 1 deletion
@@ -25,7 +25,10 @@
     MXInferenceLinear,
     MXLinear,
 )
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+)
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -441,3 +444,92 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_inference_subclass_nvfp4(bias: bool, compile: bool):
+    """
+    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
+    """
+    m = nn.Linear(64, 256, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = NVFP4InferenceConfig()
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    SQNR_THRESHOLD = 15.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}"
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("use_gelu", [True, False])
+@pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize("compile", [False])
+@torch.no_grad()
+def test_nvfp4_matmul_with_amax_emulate(use_gelu: bool, emulate: bool, compile: bool):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    m, k, n = 64, 256, 128
+
+    # Create activation tensor
+    if use_gelu:
+        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+        A = torch.nn.functional.gelu(x)
+    else:
+        A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+    # Create weight tensor (always Gaussian)
+    B = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
+
+    # Compute reference
+    C_ref = torch.matmul(A, B.t())
+
+    a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
+    b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+
+    # Quantize with per-tensor amax
+    A_nvfp4 = NVFP4Tensor.to_nvfp4(A, per_tensor_scale=a_scale)
+    B_nvfp4 = NVFP4Tensor.to_nvfp4(B, per_tensor_scale=b_scale)
+
+    if emulate:
+        # Cast back to original dtype and compute
+        A_emulated = A_nvfp4.to_dtype(A.dtype)
+        B_emulated = B_nvfp4.to_dtype(B.dtype)
+        mm = torch.compile(torch.matmul, fullgraph=True) if compile else torch.matmul
+        C_emulated = mm(A_emulated, B_emulated.t())
+        sqnr = compute_error(C_ref, C_emulated)
+    else:
+        # Use native nvfp4 matmul
+        mm = torch.compile(torch.matmul, fullgraph=True) if compile else torch.matmul
+        C_nvfp4 = mm(A_nvfp4, B_nvfp4.t())
+        sqnr = compute_error(C_ref, C_nvfp4)
+
+    # Check quality threshold
+    SQNR_THRESHOLD = 16.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, emulate={emulate}, compile={compile}"
+    )

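The new test_inference_subclass_nvfp4 above doubles as a usage recipe for the end-to-end flow. A minimal standalone sketch distilled from it, assuming the same environment the skip markers guard (PyTorch 2.8+, a CUDA device with compute capability >= 10.0, bf16 weights):

import torch
import torch.nn as nn

from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.quantization import quantize_

# bf16 linear layer; the NVFP4 transform asserts bf16 weights (see mx_subclass.py below)
m = nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
quantize_(m, config=NVFP4InferenceConfig())
m = torch.compile(m, fullgraph=True)  # optional, matching the compile=True test case

x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
y = m(x)
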
test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 51 additions & 0 deletions
@@ -14,6 +14,7 @@
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
+    F4_E2M1_MAX,
     SUPPORTED_ELEM_DTYPES,
 )
 from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
@@ -591,3 +592,53 @@ def to_f8(x):
     torch.testing.assert_close(
         data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0
     )
+
+
+@pytest.mark.parametrize(
+    "dtype,shape,use_per_tensor_scale",
+    [
+        (torch.bfloat16, (32, 64), False),
+        (torch.float32, (64, 128), False),
+        (torch.bfloat16, (128, 256), False),
+        (torch.bfloat16, (64, 128), True),
+    ],
+)
+def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    x = torch.randn(shape, dtype=dtype, device="cuda")
+    if use_per_tensor_scale:
+        tensor_amax = torch.max(torch.abs(x))
+        scale = per_tensor_amax_to_scale(tensor_amax)
+    else:
+        scale = None
+
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
+    x_reconstructed = x_nvfp4.to_dtype(dtype)
+
+    def assert_sqnr_gt_threshold(orig, new, threshold):
+        sqnr = compute_error(orig, new)
+        if torch.all(torch.isnan(sqnr)):
+            # if both operands are full of zeroes, sqnr is nan and this is ok
+            # test for this explicitly
+            assert torch.all(orig == 0) and torch.all(new == 0)
+        else:
+            assert sqnr >= threshold
+
+    reconstructed_amax = x_nvfp4.get_scales().view(shape[0], -1, 1) * F4_E2M1_MAX
+    max_abs = torch.amax(
+        torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1
+    ).unsqueeze(-1)
+
+    assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0)
+    assert_sqnr_gt_threshold(x, x_reconstructed, 8.0)
+
+    assert x.shape == x_reconstructed.shape, (
+        f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}"
+    )
+    assert x.dtype == x_reconstructed.dtype, (
+        f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
+    )

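For reference, the quantize/dequantize round trip that test_nvfp4_reconstruction exercises reduces to the sketch below. It assumes a CUDA device; passing per_tensor_scale=None corresponds to the non-per-tensor parametrizations in the test.

import torch

from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    per_tensor_amax_to_scale,
)

x = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
scale = per_tensor_amax_to_scale(torch.max(torch.abs(x)))  # optional per-tensor scale

x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)  # quantize to NVFP4
x_back = x_nvfp4.to_dtype(torch.bfloat16)                  # dequantize, as the SQNR checks do
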
torchao/prototype/mx_formats/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -6,7 +6,10 @@
 )

 # Note: Prototype and subject to change
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+)

 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -18,4 +21,5 @@
     "MXLinearConfig",
     "MXLinearRecipeName",
     "MXFPInferenceConfig",
+    "NVFP4InferenceConfig",
 ]

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 3 deletions
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
     elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        assert block_size in [16, 32], (
+            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
         )
-        valid_dtypes = [torch.float8_e4m3fn]
+        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )

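The relaxed checks above are what let the NVFP4 recipe (block_size=16, packed fp4 elements) pass validation for the cuBLAS kernel choice. An illustrative call against the private helper; an assumption here is that MXGemmKernelChoice and _validate_gemm_kernel_choice are both importable from config.py, as the hunk context suggests:

import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    _validate_gemm_kernel_choice,
)

# Previously only block_size=32 with float8_e4m3fn was accepted for CUBLAS;
# after this change the NVFP4 combination validates as well.
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 16, torch.float4_e2m1fn_x2)
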
torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 115 additions & 4 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import types
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 import torch
@@ -20,11 +20,16 @@
     _validate_gemm_kernel_choice,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 from torchao.quantization.quant_api import to_linear_activation_quantized
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_8,
+    is_sm_at_least_100,
+)


 # Note: This API is extra prototype and will change in the future
@@ -63,7 +68,7 @@ class MXFPInferenceConfig(AOBaseConfig):

     block_size: int = 32

-    # Dtypes for Input and Weights
+    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
     activation_dtype: torch.dtype = torch.float8_e4m3fn
     weight_dtype: torch.dtype = torch.float8_e4m3fn

@@ -151,7 +156,113 @@ def _mx_inference_linear_transform(
     return module


+def _get_nvfp4_dtype():
+    """Factory function for NVFP4 dtype defaults."""
+    if not TORCH_VERSION_AT_LEAST_2_8:
+        raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
+    return torch.float4_e2m1fn_x2
+
+
+@dataclass
+class NVFP4InferenceConfig(AOBaseConfig):
+    """
+    NVIDIA FP4 (NVFP4) Inference Quantization Configuration
+
+    This is a specialized configuration for NVIDIA's FP4 format with UE4M3 scales.
+    It provides defaults optimized for NVFP4:
+    - Data: float4_e2m1fn_x2
+    - Scales: float8_e4m3fn (UE4M3)
+    - Block size: 16 (required for NVFP4)
+    - CUBLAS kernel (optimized for VEC16_UE4M3)
+    """
+
+    block_size: int = 16  # NVFP4 requires block size 16
+
+    # NVFP4 uses FP4 data
+    activation_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
+    weight_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
+
+    # NVFP4 uses E4M3 scales
+    scale_dtype: torch.dtype = torch.float8_e4m3fn
+
+    # CUBLAS is preferred for NVFP4 with VEC16_UE4M3 support
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
+
+    def __post_init__(self):
+        # Validate NVFP4 constraints
+        if not TORCH_VERSION_AT_LEAST_2_8:
+            raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
+
+        assert self.activation_dtype == torch.float4_e2m1fn_x2, (
+            f"NVFP4 requires activation_dtype=float4_e2m1fn_x2, got {self.activation_dtype}"
+        )
+        assert self.weight_dtype == torch.float4_e2m1fn_x2, (
+            f"NVFP4 requires weight_dtype=float4_e2m1fn_x2, got {self.weight_dtype}"
+        )
+        assert self.scale_dtype == torch.float8_e4m3fn, (
+            f"NVFP4 requires scale_dtype=float8_e4m3fn, got {self.scale_dtype}"
+        )
+        assert self.block_size == 16, (
+            f"NVFP4 requires block_size=16, got {self.block_size}"
+        )
+
+
+def _input_activation_quant_func_nvfp4(
+    x: torch.Tensor,
+    block_size: int = 16,
+    scale: Optional[torch.Tensor] = None,
+):
+    """NVFP4-specific activation quantization function"""
+    # TODO: scale for static quant
+    activation = NVFP4Tensor.to_nvfp4(
+        x,
+        block_size=block_size,
+    )
+    return activation
+
+
+@register_quantize_module_handler(NVFP4InferenceConfig)
+def _nvfp4_inference_linear_transform(
+    module: torch.nn.Module, config: NVFP4InferenceConfig
+):
+    """Quantization handler for NVFP4InferenceConfig"""
+    assert is_sm_at_least_100(), "NVFP4 is only supported on sm100+ machines"
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    weight = module.weight
+    assert weight.dtype == torch.bfloat16, (
+        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
+    )
+
+    # Convert weight to NVFP4 Tensor
+    quantized_weight = NVFP4Tensor.to_nvfp4(
+        weight,
+        block_size=config.block_size,
+    )
+
+    input_quant_func = _input_activation_quant_func_nvfp4
+    input_quant_kwargs = {
+        "block_size": config.block_size,
+        "scale": None,
+    }
+
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals(
-        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
+        [
+            MXTensor,
+            NVFP4Tensor,
+            MXGemmKernelChoice,
+            _input_activation_quant_func_mxfp,
+            _input_activation_quant_func_nvfp4,
+        ]
     )

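Since __post_init__ pins every field, NVFP4InferenceConfig behaves as a fixed recipe: the defaults are the only accepted values, and overriding any of them fails fast. A small sketch of that behavior, based on the assertions above:

from torchao.prototype.mx_formats import NVFP4InferenceConfig

cfg = NVFP4InferenceConfig()  # block_size=16, fp4 data, e4m3 scales, CUBLAS kernel
print(cfg.block_size, cfg.scale_dtype, cfg.gemm_kernel_choice)

try:
    NVFP4InferenceConfig(block_size=32)  # rejected by __post_init__
except AssertionError as err:
    print(err)  # "NVFP4 requires block_size=16, got 32"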