Commit d5bded3

NVfp4
stack-info: PR: #2408, branch: drisspg/stack/78
1 parent 4e25496 commit d5bded3

6 files changed: +924 -15 lines changed
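
For context, a minimal end-to-end usage sketch of the NVFP4 inference path this commit adds, distilled from test_inference_subclass_nvfp4 below; it assumes a CUDA device with compute capability >= 10.0 and PyTorch 2.8+, as the new tests require.

    import copy

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig
    from torchao.quantization import quantize_

    # Quantize a copy of the module to NVFP4; DYNAMIC also quantizes activations
    # at runtime, WEIGHT_ONLY quantizes only the weight.
    m = nn.Linear(64, 256, bias=False, dtype=torch.bfloat16, device="cuda")
    m_nvfp4 = copy.deepcopy(m)
    quantize_(m_nvfp4, config=NVFP4InferenceConfig(mm_config=NVFP4MMConfig.DYNAMIC))

    x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
    y = m_nvfp4(x)  # output dtype matches the input dtype, per the new tests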

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 134 additions & 1 deletion
@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -25,7 +26,11 @@
     MXInferenceLinear,
     MXLinear,
 )
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -441,3 +446,131 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize(
+    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+)
+@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_inference_subclass_nvfp4(
+    bias: bool, compile: bool, mm_config: NVFP4MMConfig, inpt_dtype: torch.dtype
+):
+    """
+    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
+    Tests both DYNAMIC and WEIGHT_ONLY mm_config modes
+    """
+    if bias and inpt_dtype == torch.float32:
+        pytest.xfail("Bias is not supported when module weight is in fp32")
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
+        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+    m = nn.Linear(64, 256, bias=bias, dtype=inpt_dtype, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = NVFP4InferenceConfig(mm_config=mm_config)
+    quantize_(m_mx, config=config)
+
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")
+
+    x = torch.randn(128, 64, device="cuda", dtype=inpt_dtype)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
+        SQNR_THRESHOLD = 18.0
+    else:
+        SQNR_THRESHOLD = 15.0
+
+    assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}"
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("use_gelu", [True, False])
+@pytest.mark.parametrize(
+    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+)
+@pytest.mark.parametrize("compile", [False])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_nvfp4_matmul_with_amax(
+    use_gelu: bool,
+    mm_config: NVFP4MMConfig,
+    compile: bool,
+    bias: bool,
+    inpt_dtype: torch.dtype,
+):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    if bias and inpt_dtype == torch.float32:
+        pytest.xfail("Bias is not supported when module weight is in fp32")
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
+        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+
+    m, k, n = 64, 256, 128
+
+    # Create activation tensor
+    if use_gelu:
+        x = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
+        A = torch.nn.functional.gelu(x)
+    else:
+        A = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
+
+    B = torch.randn(n, k, dtype=inpt_dtype, device="cuda")
+    bias_tensor = torch.randn(n, dtype=inpt_dtype, device="cuda") if bias else None
+
+    # Compute reference
+    C_ref = F.linear(A, B, bias_tensor)
+
+    a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
+    b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+    A_nvfp4 = NVFP4Tensor.to_nvfp4(
+        A,
+        per_tensor_scale=a_scale,
+        mm_config=mm_config,
+    )
+    B_nvfp4 = NVFP4Tensor.to_nvfp4(
+        B,
+        per_tensor_scale=b_scale,
+        mm_config=mm_config,
+    )
+
+    func = torch.compile(F.linear, fullgraph=True) if compile else F.linear
+
+    C_nvfp4 = func(A_nvfp4, B_nvfp4, bias_tensor)
+    assert C_nvfp4.dtype == inpt_dtype, (
+        f"Got {C_nvfp4.dtype} for inpt_dtype={inpt_dtype}"
+    )
+
+    sqnr = compute_error(C_ref, C_nvfp4)
+    SQNR_THRESHOLD = 16.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}"
+    )

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 66 additions & 0 deletions
@@ -14,6 +14,7 @@
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
+    F4_E2M1_MAX,
     SUPPORTED_ELEM_DTYPES,
 )
 from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
@@ -591,3 +592,68 @@ def to_f8(x):
     torch.testing.assert_close(
         data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0
     )
+
+
+@pytest.mark.parametrize(
+    "dtype,shape,use_per_tensor_scale",
+    [
+        (torch.bfloat16, (32, 64), False),
+        (torch.float32, (64, 128), False),
+        (torch.bfloat16, (128, 256), False),
+        (torch.bfloat16, (64, 128), True),
+    ],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    x = torch.randn(shape, dtype=dtype, device="cuda")
+    if use_per_tensor_scale:
+        tensor_amax = torch.max(torch.abs(x))
+        scale = per_tensor_amax_to_scale(tensor_amax)
+    else:
+        scale = None
+
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
+    x_reconstructed = x_nvfp4.to_dtype(dtype)
+
+    def assert_sqnr_gt_threshold(orig, new, threshold):
+        sqnr = compute_error(orig, new)
+        if torch.all(torch.isnan(sqnr)):
+            # if both operands are full of zeroes, sqnr is nan and this is ok
+            # test for this explicitly
+            assert torch.all(orig == 0) and torch.all(new == 0)
+        else:
+            assert sqnr >= threshold
+
+    reconstructed_amax = x_nvfp4.get_scales().view(shape[0], -1, 1) * F4_E2M1_MAX
+    max_abs = torch.amax(
+        torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1
+    ).unsqueeze(-1)
+
+    assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0)
+    assert_sqnr_gt_threshold(x, x_reconstructed, 8.0)
+
+    assert x.shape == x_reconstructed.shape, (
+        f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}"
+    )
+    assert x.dtype == x_reconstructed.dtype, (
+        f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
+    )
+
+    x_nvfp4_t = x_nvfp4.t()
+    x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
+    assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0)
+
+    assert x.t().shape == x_reconstructed_t.shape, (
+        f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}"
+    )
+    assert x.t().dtype == x_reconstructed_t.dtype, (
+        f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
+    )

torchao/prototype/mx_formats/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -6,7 +6,11 @@
 )
 
 # Note: Prototype and subject to change
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 
 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -18,4 +22,6 @@
     "MXLinearConfig",
     "MXLinearRecipeName",
     "MXFPInferenceConfig",
+    "NVFP4InferenceConfig",
+    "NVFP4MMConfig",
 ]

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 3 deletions
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
     elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        assert block_size in [16, 32], (
+            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
         )
-        valid_dtypes = [torch.float8_e4m3fn]
+        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
         assert elem_dtype in valid_dtypes, (
            f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
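
A small sketch of what the relaxed cuBLAS check accepts after this change; _validate_gemm_kernel_choice is a private helper, called directly here only for illustration.

    import torch

    from torchao.prototype.mx_formats.config import (
        MXGemmKernelChoice,
        _validate_gemm_kernel_choice,
    )

    # Previously only block_size=32 with float8_e4m3fn passed for the cuBLAS kernel choice.
    _validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 32, torch.float8_e4m3fn)

    # NVFP4 uses block_size=16 with the packed float4 dtype; both now pass validation.
    _validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 16, torch.float4_e2m1fn_x2)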
