
Commit 6e2ade6

add-to-benchmarks

stack-info: PR: #2427, branch: drisspg/stack/79

1 parent d5bded3


benchmarks/float8/bench_matmul.py

Lines changed: 37 additions & 7 deletions
@@ -64,10 +64,14 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas", "mxfp4_cutlass"), (
-        "unsupported"
-    )
-    use_fp4 = recipe == "mxfp4_cutlass"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "mxfp4_cutlass",
+        "nvfp4",
+    ), "unsupported"
+    use_fp4 = recipe in ("mxfp4_cutlass", "nvfp4")
 
     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
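
With the allowlist widened, the new recipe can be exercised directly. A hedged usage sketch; the import path and CLI wiring of run() are not shown in this diff, so treat them as assumptions:

# Hypothetical smoke test of the widened recipe check; the import path
# is an assumption, since only run()'s body appears in this diff.
from bench_matmul import run

run(recipe="nvfp4")          # newly accepted by the assert
run(recipe="mxfp4_cutlass")  # still accepted; both set use_fp4 = True
# run(recipe="fp4_naive")    # any other string now fails with "unsupported"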
@@ -118,11 +122,19 @@ def run(
         A_hp = torch.randn(M, K, device=device)
         B_hp_t = torch.randn(N, K, device=device)
 
-        if use_fp4:
+        if recipe == "mxfp4_cutlass":
             _, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
             _, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
             B = Bt.contiguous().T
             peak_tops = fp4_peak_tops
+        elif recipe == "nvfp4":
+            from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize
+
+            A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
+            B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
+            A = A_data.view(torch.float4_e2m1fn_x2)
+            B = B_data.view(torch.float4_e2m1fn_x2).T
+            peak_tops = fp4_peak_tops
         else:
             # raw float8 matmul (upper bound for what we can achieve in eager mode)
             # TODO(future): add e5m2
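
For intuition, here is a minimal sketch of what NVFP4-style blockwise quantization computes: one float8_e4m3fn scale per 16 contiguous elements, with values compressed into the float4 e2m1 range. This is an illustration only; the real nvfp4_quantize also packs two fp4 values per byte and applies a per-tensor scale, so every detail below is an assumption:

import torch

F4_E2M1_MAX = 6.0  # largest finite magnitude representable in float4 e2m1

def nvfp4_quantize_sketch(x: torch.Tensor, block_size: int = 16):
    # Split the last dimension into blocks of `block_size` elements.
    rows, cols = x.shape
    blocks = x.reshape(rows, cols // block_size, block_size)
    # One scale per block: map the block's absmax onto the fp4 range,
    # then round-trip through float8_e4m3fn, the nvfp4 scale dtype.
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    scales = (amax / F4_E2M1_MAX).to(torch.float8_e4m3fn)
    # Quantize: divide by the rounded scale and clamp to the fp4 range.
    # (The real kernel also packs pairs of 4-bit values into bytes.)
    descale = scales.to(torch.float32).clamp_min(1e-12)
    data = (blocks / descale).clamp(-F4_E2M1_MAX, F4_E2M1_MAX)
    return scales.reshape(rows, cols // block_size), data.reshape(rows, cols)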
@@ -140,6 +152,10 @@ def run(
         elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
             scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
             scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        elif recipe == "nvfp4":
+            # Use the blockwise scales from nvfp4_quantize
+            scale_a = A_scales.view(torch.float8_e4m3fn)
+            scale_b = B_scales.view(torch.float8_e4m3fn)
         else:
             assert False, f"unknown recipe {recipe}"
 
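Note the contrast in scale layouts between the branches: the mx recipes carry one e8m0 scale per 32 elements of K, while nvfp4 carries one e4m3 scale per 16. A hedged sanity check, assuming a plain unswizzled (M, K // block) layout; the actual cuBLAS path may expect a swizzled arrangement:

# Illustrative only; assumes unswizzled blockwise scale layouts.
assert scale_a.dtype == torch.float8_e4m3fn  # nvfp4: e4m3 scales
assert scale_a.numel() == M * (K // 16)      # one scale per 16 elements of K
# vs. the mx recipes above: torch.float8_e8m0fnu, one scale per 32 elements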
@@ -155,7 +171,17 @@ def do_matmul_mxfp4(A, B):
             nonlocal scale_b
             return mx_fp4_bf16(A, B, scale_a, scale_b)
 
-        do_matmul = do_matmul_mxfp4 if use_fp4 else do_matmul_fp8
+        def do_matmul_nvfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)
+
+        if recipe == "mxfp4_cutlass":
+            do_matmul = do_matmul_mxfp4
+        elif recipe == "nvfp4":
+            do_matmul = do_matmul_nvfp4
+        else:
+            do_matmul = do_matmul_fp8
 
         time_sec, tops_sec, pct_top_peak = do_benchmarks(
             tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
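
Putting the pieces together, a self-contained sketch of the nvfp4 matmul call. It assumes Blackwell-class hardware and a PyTorch build whose torch._scaled_mm accepts float4_e2m1fn_x2 operands with e4m3 blockwise scales; the dummy data, shapes, and scale layout are illustrative, not taken from the benchmark:

import torch

M, K, N = 256, 128, 256
device = "cuda"

# Packed fp4: float4_e2m1fn_x2 stores two 4-bit values per byte, so the
# packed K dimension is K // 2 bytes wide.
A = torch.randint(0, 256, (M, K // 2), device=device, dtype=torch.uint8)
A = A.view(torch.float4_e2m1fn_x2)
B = torch.randint(0, 256, (N, K // 2), device=device, dtype=torch.uint8)
B = B.view(torch.float4_e2m1fn_x2).T  # second operand passed transposed

# One e4m3 scale per 16-element block along K (layout assumed unswizzled).
scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)

out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # (M, N), in bfloat16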
@@ -164,7 +190,11 @@ def do_matmul_mxfp4(A, B):
             f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
         )
 
-        del A, B, scale_a, scale_b
+        del A, B
+        if scale_a is not None:
+            del scale_a
+        if scale_b is not None:
+            del scale_b
 
         results.append(
             [