@@ -64,10 +64,14 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas", "mxfp4_cutlass"), (
-        "unsupported"
-    )
-    use_fp4 = recipe == "mxfp4_cutlass"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "mxfp4_cutlass",
+        "nvfp4",
+    ), "unsupported"
+    use_fp4 = recipe in ("mxfp4_cutlass", "nvfp4")
 
     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
@@ -118,11 +122,19 @@ def run(
         A_hp = torch.randn(M, K, device=device)
         B_hp_t = torch.randn(N, K, device=device)
 
-        if use_fp4:
+        if recipe == "mxfp4_cutlass":
             _, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
             _, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
             B = Bt.contiguous().T
             peak_tops = fp4_peak_tops
+        elif recipe == "nvfp4":
+            from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize
+
+            A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
+            B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
+            A = A_data.view(torch.float4_e2m1fn_x2)
+            B = B_data.view(torch.float4_e2m1fn_x2).T
+            peak_tops = fp4_peak_tops
         else:
             # raw float8 matmul (upper bound for what we can achieve in eager mode)
             # TODO(future): add e5m2
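
For reference, a minimal standalone sketch of the nvfp4 quantization step introduced above. The import path, `block_size=16`, and the (scales, packed data) return order are taken directly from the diff; the printed shapes are assumptions based on the NVFP4 layout (one `float8_e4m3fn` scale per 16-element block, two fp4 values packed per byte), not verified output.

```python
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

M, K = 1024, 2048
A_hp = torch.randn(M, K, device="cuda")

# returns (blockwise scales, packed fp4 data), as used in the diff
A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
A = A_data.view(torch.float4_e2m1fn_x2)

print(A_scales.view(torch.float8_e4m3fn).shape)  # assumed (M, K // 16)
print(A.shape)                                   # assumed (M, K // 2): 2 fp4 values per byte
```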
@@ -140,6 +152,10 @@ def run(
         elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
             scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
             scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        elif recipe == "nvfp4":
+            # Use the blockwise scales from nvfp4_quantize
+            scale_a = A_scales.view(torch.float8_e4m3fn)
+            scale_b = B_scales.view(torch.float8_e4m3fn)
         else:
             assert False, f"unknown recipe {recipe}"
 
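
One contrast worth noting in this hunk: the mx recipes use placeholder `float8_e8m0fnu` (power-of-two) scales over 32-element blocks, while the nvfp4 branch reuses the real `float8_e4m3fn` blockwise scales produced by `nvfp4_quantize`. A small sketch of the two layouts; the nvfp4 scale shape is an assumption derived from `block_size=16`:

```python
import torch

M, K = 1024, 2048
device = "cuda"

# mxfp8 / mxfp4: one e8m0 scale per 32-element block (verbatim from the diff)
scale_a_mx = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)

# nvfp4: one e4m3 scale per 16-element block (shape assumed from block_size=16)
scale_a_nv = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
```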
@@ -155,7 +171,17 @@ def do_matmul_mxfp4(A, B):
             nonlocal scale_b
             return mx_fp4_bf16(A, B, scale_a, scale_b)
 
-        do_matmul = do_matmul_mxfp4 if use_fp4 else do_matmul_fp8
+        def do_matmul_nvfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)
+
+        if recipe == "mxfp4_cutlass":
+            do_matmul = do_matmul_mxfp4
+        elif recipe == "nvfp4":
+            do_matmul = do_matmul_nvfp4
+        else:
+            do_matmul = do_matmul_fp8
 
         time_sec, tops_sec, pct_top_peak = do_benchmarks(
             tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
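
Putting the pieces together, an end-to-end sketch of the nvfp4 path as dispatched in this hunk. This is a minimal illustration under the diff's own conventions (packed operands, column-major B via `.T`, blockwise e4m3 scales, `out_dtype` passed through), not a definitive reference for `torch._scaled_mm`:

```python
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

M, K, N = 1024, 2048, 4096
device, dtype = "cuda", torch.bfloat16

A_scales, A_data = nvfp4_quantize(torch.randn(M, K, device=device), block_size=16)
B_scales, B_data = nvfp4_quantize(torch.randn(N, K, device=device), block_size=16)

A = A_data.view(torch.float4_e2m1fn_x2)
B = B_data.view(torch.float4_e2m1fn_x2).T  # column-major B, as in the diff
scale_a = A_scales.view(torch.float8_e4m3fn)
scale_b = B_scales.view(torch.float8_e4m3fn)

out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)
print(out.shape, out.dtype)  # expected (M, N), torch.bfloat16
```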
@@ -164,7 +190,11 @@ def do_matmul_mxfp4(A, B):
             f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
         )
 
-        del A, B, scale_a, scale_b
+        del A, B
+        if scale_a is not None:
+            del scale_a
+        if scale_b is not None:
+            del scale_b
 
         results.append(
             [
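
Finally, a hedged usage note: assuming the benchmark keeps a CLI entry point wrapping `run()` (not shown in this diff), the new path would be selected like so:

```python
# hypothetical invocation; run()'s other arguments keep their defaults
run(recipe="nvfp4")
```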