@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -25,7 +26,11 @@
     MXInferenceLinear,
     MXLinear,
 )
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -441,3 +446,131 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize(
+    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+)
+@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm requires gfx950")
+def test_inference_subclass_nvfp4(
+    bias: bool, compile: bool, mm_config: NVFP4MMConfig, inpt_dtype: torch.dtype
+):
+    """
+    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
+    Tests both DYNAMIC and WEIGHT_ONLY mm_config modes
+    """
+    if bias and inpt_dtype == torch.float32:
+        pytest.xfail("Bias is not supported when module weight is in fp32")
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
+        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+    m = nn.Linear(64, 256, bias=bias, dtype=inpt_dtype, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = NVFP4InferenceConfig(mm_config=mm_config)
+    quantize_(m_mx, config=config)
+
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")
+
+    x = torch.randn(128, 64, device="cuda", dtype=inpt_dtype)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
+        SQNR_THRESHOLD = 18.0
+    else:
+        SQNR_THRESHOLD = 15.0
+
+    assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}"
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("use_gelu", [True, False])
+@pytest.mark.parametrize(
+    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+)
+@pytest.mark.parametrize("compile", [False])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm requires gfx950")
+def test_nvfp4_matmul_with_amax(
+    use_gelu: bool,
+    mm_config: NVFP4MMConfig,
+    compile: bool,
+    bias: bool,
+    inpt_dtype: torch.dtype,
+):
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+    )
+
+    if bias and inpt_dtype == torch.float32:
+        pytest.xfail("Bias is not supported when module weight is in fp32")
+
+    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
+        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+
+    m, k, n = 64, 256, 128
+
+    # Create activation tensor
+    if use_gelu:
+        x = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
+        A = torch.nn.functional.gelu(x)
+    else:
+        A = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
+
+    B = torch.randn(n, k, dtype=inpt_dtype, device="cuda")
+    bias_tensor = torch.randn(n, dtype=inpt_dtype, device="cuda") if bias else None
+
+    # Compute reference
+    C_ref = F.linear(A, B, bias_tensor)
+
+    a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
+    b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+    A_nvfp4 = NVFP4Tensor.to_nvfp4(
+        A,
+        per_tensor_scale=a_scale,
+        mm_config=mm_config,
+    )
+    B_nvfp4 = NVFP4Tensor.to_nvfp4(
+        B,
+        per_tensor_scale=b_scale,
+        mm_config=mm_config,
+    )
+
+    func = torch.compile(F.linear, fullgraph=True) if compile else F.linear
+
+    C_nvfp4 = func(A_nvfp4, B_nvfp4, bias_tensor)
+    assert C_nvfp4.dtype == inpt_dtype, (
+        f"Got {C_nvfp4.dtype} for inpt_dtype={inpt_dtype}"
+    )
+
+    sqnr = compute_error(C_ref, C_nvfp4)
+    SQNR_THRESHOLD = 16.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}"
+    )
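
For reference, a minimal sketch of the inference flow that test_inference_subclass_nvfp4 exercises, using only APIs that appear in this diff (NVFP4InferenceConfig, NVFP4MMConfig, quantize_). The layer shape, dtype, and use of the DYNAMIC mode are illustrative choices, and the same hardware gating as the tests applies (CUDA with SM 10.0+, PyTorch 2.8+):

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_subclass import NVFP4InferenceConfig, NVFP4MMConfig
from torchao.quantization import quantize_

# Illustrative bf16 linear layer on CUDA (shape borrowed from the test above).
model = nn.Linear(64, 256, bias=False, dtype=torch.bfloat16, device="cuda")

# Quantize the linear weight to NVFP4; mm_config selects DYNAMIC or WEIGHT_ONLY mode.
quantize_(model, config=NVFP4InferenceConfig(mm_config=NVFP4MMConfig.DYNAMIC))

# Optional compile step, mirroring the test (WEIGHT_ONLY + compile is skipped there).
model = torch.compile(model, fullgraph=True, backend="aot_eager")

with torch.no_grad():
    y = model(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))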