@@ -64,10 +64,14 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas", "mxfp4_cutlass"), (
-        "unsupported"
-    )
-    use_fp4 = recipe == "mxfp4_cutlass"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "mxfp4_cutlass",
+        "nvfp4",
+    ), "unsupported"
+    use_fp4 = recipe in ("mxfp4_cutlass", "nvfp4")

     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
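
For readers skimming the recipe strings, a rough legend (my gloss, pieced together from the recipe names and the block sizes used later in this diff, not something the PR itself states):

    RECIPE_NOTES = {
        "tensorwise": "float8 with one scale per tensor",
        "rowwise": "float8 with one scale per row",
        "mxfp8_cublas": "MX float8, e8m0 scale per 32-element block, cuBLAS",
        "mxfp4_cutlass": "MX float4, e8m0 scale per 32-element block, CUTLASS",
        "nvfp4": "NVIDIA float4, e4m3 scale per 16-element block",
    }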
@@ -118,11 +122,20 @@ def run(
         A_hp = torch.randn(M, K, device=device)
         B_hp_t = torch.randn(N, K, device=device)

-        if use_fp4:
+        if recipe == "mxfp4_cutlass":
             _, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
             _, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
             B = Bt.contiguous().T
             peak_tops = fp4_peak_tops
+        elif recipe == "nvfp4":
+            from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize
+
+            # Quantize tensors to nvfp4 format - get blockwise scales
+            A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
+            B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
+            A = A_data.view(torch.float4_e2m1fn_x2)
+            B = B_data.view(torch.float4_e2m1fn_x2).T
+            peak_tops = fp4_peak_tops
         else:
             # raw float8 matmul (upper bound for what we can achieve in eager mode)
             # TODO(future): add e5m2
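
For reviewers unfamiliar with the nvfp4 path, a minimal standalone sketch of the quantize-then-view step above. The return convention of nvfp4_quantize is taken from this diff; the shapes in the comments are my assumptions based on fp4 packing (two e2m1 values per byte) and block_size=16:

    import torch
    from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

    x = torch.randn(128, 64, device="cuda")
    scales, data = nvfp4_quantize(x, block_size=16)
    # two fp4 values pack into one byte, so roughly (128, 32) packed elements
    x_fp4 = data.view(torch.float4_e2m1fn_x2)
    # one e4m3 scale per 16-element block along K (layout/padding may differ)
    x_scales = scales.view(torch.float8_e4m3fn)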
@@ -140,6 +153,10 @@ def run(
         elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
             scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
             scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        elif recipe == "nvfp4":
+            # Use the blockwise scales from nvfp4_quantize
+            scale_a = A_scales.view(torch.float8_e4m3fn)
+            scale_b = B_scales.view(torch.float8_e4m3fn)
         else:
             assert False, f"unknown recipe {recipe}"

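
Note the asymmetry: the mx recipes synthesize dummy unit scales, which is fine for a speed-only benchmark, while the nvfp4 branch reuses the real scales produced by nvfp4_quantize. A small sketch of the two scale encodings (the dtype names are stock PyTorch; the values are only illustrative):

    # e8m0 is exponent-only: unsigned, power-of-two scales (one per 32 elements)
    s_mx = torch.ones(16, dtype=torch.float8_e8m0fnu)
    # e4m3 has a sign and 3 mantissa bits, so non-power-of-two scales are representable
    s_nv = torch.tensor([1.5, 0.75]).to(torch.float8_e4m3fn)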
@@ -155,7 +172,17 @@ def do_matmul_mxfp4(A, B):
             nonlocal scale_b
             return mx_fp4_bf16(A, B, scale_a, scale_b)

-        do_matmul = do_matmul_mxfp4 if use_fp4 else do_matmul_fp8
+        def do_matmul_nvfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)
+
+        if recipe == "mxfp4_cutlass":
+            do_matmul = do_matmul_mxfp4
+        elif recipe == "nvfp4":
+            do_matmul = do_matmul_nvfp4
+        else:
+            do_matmul = do_matmul_fp8

         time_sec, tops_sec, pct_top_peak = do_benchmarks(
             tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
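
For context on what do_benchmarks consumes: a generic sketch of timing such a matmul callable on GPU (a common pattern, not the repo's do_benchmarks helper, whose internals are not shown in this diff):

    import time

    def time_matmul(do_matmul, A, B, n_iter=100):
        for _ in range(3):          # warmup so one-time setup cost is excluded
            do_matmul(A, B)
        torch.cuda.synchronize()    # CUDA launches are async; sync before timing
        t0 = time.perf_counter()
        for _ in range(n_iter):
            do_matmul(A, B)
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) / n_iter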
@@ -164,7 +191,11 @@ def do_matmul_mxfp4(A, B):
             f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
         )

-        del A, B, scale_a, scale_b
+        del A, B
+        if scale_a is not None:
+            del scale_a
+        if scale_b is not None:
+            del scale_b

         results.append(
             [