Commit d2d4383

add-to-benchmarks
1 parent efdd0b1

benchmarks/float8/bench_matmul.py

Lines changed: 38 additions & 7 deletions
@@ -64,10 +64,14 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas", "mxfp4_cutlass"), (
-        "unsupported"
-    )
-    use_fp4 = recipe == "mxfp4_cutlass"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "mxfp4_cutlass",
+        "nvfp4",
+    ), "unsupported"
+    use_fp4 = recipe in ("mxfp4_cutlass", "nvfp4")

     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
@@ -118,11 +122,20 @@ def run(
     A_hp = torch.randn(M, K, device=device)
     B_hp_t = torch.randn(N, K, device=device)

-    if use_fp4:
+    if recipe == "mxfp4_cutlass":
         _, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
         _, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
         B = Bt.contiguous().T
         peak_tops = fp4_peak_tops
+    elif recipe == "nvfp4":
+        from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize
+
+        # Quantize tensors to nvfp4 format - get blockwise scales
+        A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
+        B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
+        A = A_data.view(torch.float4_e2m1fn_x2)
+        B = B_data.view(torch.float4_e2m1fn_x2).T
+        peak_tops = fp4_peak_tops
     else:
         # raw float8 matmul (upper bound for what we can achieve in eager mode)
         # TODO(future): add e5m2
@@ -140,6 +153,10 @@ def run(
     elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe == "nvfp4":
+        # Use the blockwise scales from nvfp4_quantize
+        scale_a = A_scales.view(torch.float8_e4m3fn)
+        scale_b = B_scales.view(torch.float8_e4m3fn)
     else:
         assert False, f"unknown recipe {recipe}"

@@ -155,7 +172,17 @@ def do_matmul_mxfp4(A, B):
         nonlocal scale_b
         return mx_fp4_bf16(A, B, scale_a, scale_b)

-    do_matmul = do_matmul_mxfp4 if use_fp4 else do_matmul_fp8
+    def do_matmul_nvfp4(A, B):
+        nonlocal scale_a
+        nonlocal scale_b
+        return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)
+
+    if recipe == "mxfp4_cutlass":
+        do_matmul = do_matmul_mxfp4
+    elif recipe == "nvfp4":
+        do_matmul = do_matmul_nvfp4
+    else:
+        do_matmul = do_matmul_fp8

     time_sec, tops_sec, pct_top_peak = do_benchmarks(
         tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
@@ -164,7 +191,11 @@ def do_matmul_mxfp4(A, B):
         f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
     )

-    del A, B, scale_a, scale_b
+    del A, B
+    if scale_a is not None:
+        del scale_a
+    if scale_b is not None:
+        del scale_b

     results.append(
         [
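
For context, the nvfp4 hunks above compose into the end-to-end flow below. This is a minimal sketch assembled from the diff, not code taken verbatim from the commit: the shapes and the bfloat16 output dtype are illustrative stand-ins (the benchmark uses its own M, K, N arguments and its dtype variable), and it assumes a CUDA build of PyTorch recent enough to expose torch.float4_e2m1fn_x2 and accept fp4 inputs to torch._scaled_mm, plus torchao's prototype nvfp4 quantizer.

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

M, K, N = 1024, 2048, 4096  # illustrative shapes

A_hp = torch.randn(M, K, device="cuda")
B_hp_t = torch.randn(N, K, device="cuda")

# nvfp4_quantize returns per-block scales (block_size=16) plus the packed
# fp4 payload for each input.
A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)

# Reinterpret the packed payloads as float4 pairs; B is transposed so the
# contraction (K) dimensions line up for A @ B.
A = A_data.view(torch.float4_e2m1fn_x2)
B = B_data.view(torch.float4_e2m1fn_x2).T

# nvfp4 scales are fp8 e4m3 over blocks of 16, unlike the mx recipes in this
# file, which use fp8 e8m0 scales over blocks of 32.
scale_a = A_scales.view(torch.float8_e4m3fn)
scale_b = B_scales.view(torch.float8_e4m3fn)

out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16)

That per-recipe difference in scale dtype and block size is why the commit threads recipe-specific branches through quantization, scale setup, and matmul dispatch rather than reusing the existing use_fp4 flag alone.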
