
Commit 034f892

WIP NVfp4
stack-info: PR: #2408, branch: drisspg/stack/78
1 parent 101c039 commit 034f892

File tree

6 files changed: +597 -54 lines changed


test/prototype/mx_formats/test_mx_linear.py

Lines changed: 37 additions & 1 deletion
@@ -25,7 +25,10 @@
     MXInferenceLinear,
     MXLinear,
 )
-from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
+from torchao.prototype.mx_formats.mx_subclass import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+)
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -441,3 +444,36 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_inference_subclass_nvfp4(bias: bool, compile: bool):
+    """
+    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
+    """
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+
+    config = NVFP4InferenceConfig()
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    SQNR_THRESHOLD = 15.0  # Float4 threshold
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}"
+    )
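
A minimal end-to-end sketch of the flow this new test exercises, pulled out of pytest for readability. It assumes the same prototype APIs imported above (NVFP4InferenceConfig, quantize_, compute_error) and a CUDA device with compute capability >= 10.0; the shapes and the final print are illustrative only:

    import copy
    import torch
    import torch.nn as nn
    from torchao.prototype.mx_formats.mx_subclass import NVFP4InferenceConfig
    from torchao.quantization import quantize_
    from torchao.quantization.utils import compute_error

    # bfloat16 reference linear layer and a copy to quantize
    m_ref = nn.Linear(32, 128, bias=True, dtype=torch.bfloat16, device="cuda")
    m_nvfp4 = copy.deepcopy(m_ref)

    # swap the weight for the NVFP4 subclass (FP4 data, E4M3 scales, block_size=16)
    quantize_(m_nvfp4, config=NVFP4InferenceConfig())

    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
    sqnr = compute_error(m_ref(x), m_nvfp4(x))  # SQNR; the test requires >= 15.0
    print(f"NVFP4 sqnr: {sqnr}")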

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 3 deletions
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
     elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        assert block_size in [16, 32], (
+            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
         )
-        valid_dtypes = [torch.float8_e4m3fn]
+        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
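
For reference, a short sketch of what the relaxed check now accepts on the cuBLAS path. It assumes _validate_gemm_kernel_choice and MXGemmKernelChoice can be imported from torchao.prototype.mx_formats.config (the hunk above lives in that module) and a PyTorch build that exposes torch.float4_e2m1fn_x2; the standalone calls are illustrative, not how torchao invokes the validator:

    import torch
    from torchao.prototype.mx_formats.config import (
        MXGemmKernelChoice,
        _validate_gemm_kernel_choice,
    )

    # Previously the only cuBLAS combination: e4m3 elements with block_size=32
    _validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 32, torch.float8_e4m3fn)

    # Newly accepted for NVFP4: packed fp4 elements with block_size=16
    _validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 16, torch.float4_e2m1fn_x2)

    # Still rejected: any block size outside [16, 32] raises an AssertionError
    # _validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 64, torch.float8_e4m3fn)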

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             w_elem_dtype,
             block_size,
             weight_hp.dtype,
+            None,  # scale_dtype
             False,
             gemm_kernel_choice,
             False,
@@ -133,6 +134,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             grad_elem_dtype,
             block_size,
             grad_output_hp_r.dtype,
+            None,  # scale_dtype
             False,
             gemm_kernel_choice,
             False,
@@ -155,6 +157,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             in_elem_dtype,
             block_size,
             input_hp_r.dtype,
+            None,  # scale_dtype
             False,
             gemm_kernel_choice,
             False,

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 139 additions & 46 deletions
@@ -33,6 +33,7 @@
 )
 from torchao.prototype.mx_formats.mx_tensor import (  # noqa: E501
     MXTensor,
+    NVFP4Tensor,
     tensor_size_hp_to_fp4x2,
     tensor_size_hpx3_to_fp6x4,
 )
@@ -93,8 +94,8 @@ def _addmm_mx_dispatch(
     M, K, N = a.shape[0], a.shape[1], b.shape[1]
     assert a._data.is_contiguous()
     assert b._data.t().is_contiguous()
-    assert a._block_size == 32, f"Invalid block size {a._block_size}"
-    assert b._block_size == 32, f"Invalid block size {b._block_size}"
+    assert a._block_size in [16, 32], f"Invalid block size {a._block_size}"
+    assert b._block_size in [16, 32], f"Invalid block size {b._block_size}"

     a_scale = a._scale_e8m0.view(M, K // a._block_size)
     b_scale = b._scale_e8m0.view(N, K // b._block_size)
@@ -144,42 +145,97 @@ def _addmm_mx_dispatch(
     return res


+def _addmm_nvfp4_dispatch(
+    a: NVFP4Tensor, b: NVFP4Tensor, aten_op, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Core implementation for NVFP4Tensor operations
+    Uses E4M3 scales and always uses CUBLAS for FP4 operations
+    """
+    # NVFP4 operations with E4M3 scales
+    M, K, N = a.shape[0], a.shape[1], b.shape[1]
+    assert a._data.is_contiguous()
+    assert b._data.t().is_contiguous()
+    assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
+    assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
+
+    # NVFP4 uses E4M3 scales, not E8M0
+    a_scale = a._scale_e4m3.view(M, K // a._block_size)
+    b_scale = b._scale_e4m3.view(N, K // b._block_size)
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
+
+    # NVFP4 always uses CUBLAS with VEC16_UE4M3 scale mode
+    res = torch._scaled_mm(
+        a._data,
+        b._data,
+        a_scale_block.view(torch.float8_e4m3fn),
+        b_scale_block.view(torch.float8_e4m3fn),
+        bias=bias,
+        out_dtype=torch.bfloat16,
+    )
+
+    return res
+
+
 @implements([aten.mm.default, aten.matmul.default])
 def mx_mm(func, types, args, kwargs):
     a = args[0]
     b = args[1]
-    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)

-    return _addmm_mx_dispatch(a, b, func)
+    # Handle both MXTensor and NVFP4Tensor
+    if isinstance(a, MXTensor) and isinstance(b, MXTensor):
+        return _addmm_mx_dispatch(a, b, func)
+    elif isinstance(a, NVFP4Tensor) and isinstance(b, NVFP4Tensor):
+        return _addmm_nvfp4_dispatch(a, b, func)
+    else:
+        raise ValueError(f"Unsupported tensor types: {type(a)}, {type(b)}")


 @implements([aten.addmm.default])
 def mx_addmm(func, types, args, kwargs):
-    assert (
-        isinstance(args[0], torch.Tensor)
-        and isinstance(args[1], MXTensor)
-        and isinstance(args[2], MXTensor)
-    )
     bias = args[0]
     a = args[1]
     b = args[2]
-    return _addmm_mx_dispatch(a, b, func, bias=bias)
+
+    assert isinstance(bias, torch.Tensor), (
+        f"Bias must be torch.Tensor, got {type(bias)}"
+    )
+
+    # Handle both MXTensor and NVFP4Tensor
+    if isinstance(a, MXTensor) and isinstance(b, MXTensor):
+        return _addmm_mx_dispatch(a, b, func, bias=bias)
+    elif isinstance(a, NVFP4Tensor) and isinstance(b, NVFP4Tensor):
+        return _addmm_nvfp4_dispatch(a, b, func, bias=bias)
+    else:
+        raise ValueError(f"Unsupported tensor types: {type(a)}, {type(b)}")


 @implements([aten.t.default])
 def mx_t(func, types, args, kwargs):
     # For now, only transpose(input, 0, 1) is supported.
     old = args[0]
-    new = MXTensor(
-        old._scale_e8m0,
-        old._data.t(),
-        old._elem_dtype,
-        old._block_size,
-        old._orig_dtype,
-        old._use_fp4_custom_triton_dequant_kernel,
-        old._gemm_kernel_choice,
-        old._pack_fp6,
-    )
+
+    if isinstance(old, MXTensor):
+        new = MXTensor(
+            old._scale_e8m0,
+            old._data.t(),
+            old._elem_dtype,
+            old._block_size,
+            old._orig_dtype,
+            old._use_fp4_custom_triton_dequant_kernel,
+            old._gemm_kernel_choice,
+            old._pack_fp6,
+        )
+    elif isinstance(old, NVFP4Tensor):
+        new = NVFP4Tensor(
+            old._scale_e4m3,
+            old._data.t(),
+            old._block_size,
+            old._orig_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(old)}")
     return new


@@ -205,25 +261,43 @@ def unwrap(x):

 @implements([aten.view.default])
 def mx_view_op(func, types, args, kwargs):
-    data = args[0]._data
+    tensor = args[0]
+    data = tensor._data
     new_size = args[1]
-    if args[0]._elem_dtype == torch.float4_e2m1fn_x2:
-        # special case fp4 as we pack two elements per byte
+
+    if isinstance(tensor, MXTensor):
+        if tensor._elem_dtype == torch.float4_e2m1fn_x2:
+            # special case fp4 as we pack two elements per byte
+            new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
+        elif (
+            tensor._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and tensor._pack_fp6
+        ):
+            # special case fp6 as we pack 4 elements in 3 bytes
+            new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
+
+        new_data = func(data, new_size, *args[2:], **kwargs)
+        return MXTensor(
+            tensor._scale_e8m0,
+            new_data,
+            tensor._elem_dtype,
+            tensor._block_size,
+            tensor._orig_dtype,
+            tensor._use_fp4_custom_triton_dequant_kernel,
+            tensor._gemm_kernel_choice,
+            tensor._pack_fp6,
+        )
+    elif isinstance(tensor, NVFP4Tensor):
+        # NVFP4 is always fp4 packed
         new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
-    elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
-        # special case fp6 as we pack 4 elements in 3 bytes
-        new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
-    new_data = func(data, new_size, *args[2:], **kwargs)
-    return MXTensor(
-        args[0]._scale_e8m0,
-        new_data,
-        args[0]._elem_dtype,
-        args[0]._block_size,
-        args[0]._orig_dtype,
-        args[0]._use_fp4_custom_triton_dequant_kernel,
-        args[0]._gemm_kernel_choice,
-        args[0]._pack_fp6,
-    )
+        new_data = func(data, new_size, *args[2:], **kwargs)
+        return NVFP4Tensor(
+            tensor._scale_e4m3,
+            new_data,
+            tensor._block_size,
+            tensor._orig_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(tensor)}")


 @implements([aten.slice.Tensor])
@@ -235,8 +309,15 @@ def mx_slice(func, types, args, kwargs):

     M, K = x.shape[0], x.shape[1]

-    # TODO why doesn't scale have shape?
-    scale_shaped = x._scale_e8m0.view(M, K // x._block_size)
+    # Handle different scale tensors for different tensor types
+    if isinstance(x, MXTensor):
+        scale_tensor = x._scale_e8m0
+    elif isinstance(x, NVFP4Tensor):
+        scale_tensor = x._scale_e4m3
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(x)}")
+
+    scale_shaped = scale_tensor.view(M, K // x._block_size)

     if dim == 0:
         # Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now
@@ -267,15 +348,14 @@ def mx_slice(func, types, args, kwargs):
             scale_shaped, 1, start_block, end_block, step
         ).flatten()
     else:
+        tensor_name = "MXTensor/NVFP4Tensor"
         raise ValueError(
-            f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}"
+            f"{tensor_name} only supports slicing along dimensions 0 and 1, got dim={dim}"
         )

-    return return_and_correct_aliasing(
-        func,
-        args,
-        kwargs,
-        MXTensor(
+    # Create appropriate tensor type
+    if isinstance(x, MXTensor):
+        result_tensor = MXTensor(
             sliced_scale,
             sliced_data,
             x._elem_dtype,
@@ -284,7 +364,20 @@ def mx_slice(func, types, args, kwargs):
             x._use_fp4_custom_triton_dequant_kernel,
             x._gemm_kernel_choice,
             x._pack_fp6,
-        ),
+        )
+    else:  # NVFP4Tensor
+        result_tensor = NVFP4Tensor(
+            sliced_scale,
+            sliced_data,
+            x._block_size,
+            x._orig_dtype,
+        )
+
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        result_tensor,
     )

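For context on why the NVFP4 path above pins block_size to 16 and views its scales as float8_e4m3fn: in the NVFP4 layout each group of 16 FP4 (E2M1) elements shares one E4M3 scale, whereas the MX path uses E8M0 scales over groups of 32. Below is a rough, self-contained sketch of the per-block dequantization semantics; the helper and the E2M1 value table are illustrative and not part of torchao, and they ignore how the scales are chosen at quantization time:

    import torch

    # The eight magnitudes representable by FP4 E2M1 (sign is a separate bit)
    E2M1_MAGNITUDES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    def dequant_nvfp4_block(sign, mag_idx, scale_e4m3):
        """Decode one 16-element NVFP4 block back to float32.

        sign:       (16,) tensor of +1.0 / -1.0
        mag_idx:    (16,) long tensor indexing into E2M1_MAGNITUDES
        scale_e4m3: scalar torch.float8_e4m3fn scale shared by the whole block
        """
        elems = sign * E2M1_MAGNITUDES[mag_idx]       # decode the FP4 elements
        return elems * scale_e4m3.to(torch.float32)   # apply the shared E4M3 block scale

    # Example block: random signs and codes, block scale of 0.25
    sign = torch.randint(0, 2, (16,)).float() * 2 - 1
    mag_idx = torch.randint(0, 8, (16,))
    scale = torch.tensor(0.25).to(torch.float8_e4m3fn)
    print(dequant_nvfp4_block(sign, mag_idx, scale))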