)
from torchao.prototype.mx_formats.mx_tensor import (  # noqa: E501
    MXTensor,
+    NVFP4Tensor,
    tensor_size_hp_to_fp4x2,
    tensor_size_hpx3_to_fp6x4,
)
@@ -93,8 +94,8 @@ def _addmm_mx_dispatch(
    M, K, N = a.shape[0], a.shape[1], b.shape[1]
    assert a._data.is_contiguous()
    assert b._data.t().is_contiguous()
-    assert a._block_size == 32, f"Invalid block size {a._block_size}"
-    assert b._block_size == 32, f"Invalid block size {b._block_size}"
+    assert a._block_size in [16, 32], f"Invalid block size {a._block_size}"
+    assert b._block_size in [16, 32], f"Invalid block size {b._block_size}"

    a_scale = a._scale_e8m0.view(M, K // a._block_size)
    b_scale = b._scale_e8m0.view(N, K // b._block_size)
@@ -144,42 +145,97 @@ def _addmm_mx_dispatch(
    return res


+def _addmm_nvfp4_dispatch(
+    a: NVFP4Tensor, b: NVFP4Tensor, aten_op, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Core implementation for NVFP4Tensor operations.
+    Uses E4M3 scales and always uses CUBLAS for FP4 operations.
+    """
+    # NVFP4 operations with E4M3 scales
+    M, K, N = a.shape[0], a.shape[1], b.shape[1]
+    assert a._data.is_contiguous()
+    assert b._data.t().is_contiguous()
+    assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
+    assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
+
+    # NVFP4 uses E4M3 scales, not E8M0
+    a_scale = a._scale_e4m3.view(M, K // a._block_size)
+    b_scale = b._scale_e4m3.view(N, K // b._block_size)
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
+
+    # NVFP4 always uses CUBLAS with the VEC16_UE4M3 scale mode
+    res = torch._scaled_mm(
+        a._data,
+        b._data,
+        a_scale_block.view(torch.float8_e4m3fn),
+        b_scale_block.view(torch.float8_e4m3fn),
+        bias=bias,
+        out_dtype=torch.bfloat16,
+    )
+
+    return res
+
+
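For orientation, a minimal sketch (not from this PR) of the scale bookkeeping `_addmm_nvfp4_dispatch` relies on: each 16-element block along K carries one E4M3 scale, so the scale views are (M, K // 16) and (N, K // 16) before `to_blocked` rearranges them into the tiled layout `torch._scaled_mm` consumes. The 128x4 tile padding below is an assumption about the blocked-scale format, not code shown in this diff.

# Hedged sketch: shape arithmetic for NVFP4 scales (assumes 128x4 scale tiles).
M, K, N, BLOCK = 128, 64, 256, 16

a_scale_shape = (M, K // BLOCK)  # (128, 4): one E4M3 scale per 16-value block
b_scale_shape = (N, K // BLOCK)  # (256, 4)

def blocked_numel(rows, cols):
    # to_blocked is assumed to pad to 128-row x 4-column tiles for cuBLAS
    return ((rows + 127) // 128) * 128 * (((cols + 3) // 4) * 4)

assert blocked_numel(*a_scale_shape) == 128 * 4  # no padding needed at these sizes
assert blocked_numel(*b_scale_shape) == 256 * 4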
@implements([aten.mm.default, aten.matmul.default])
def mx_mm(func, types, args, kwargs):
    a = args[0]
    b = args[1]
-    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)

-    return _addmm_mx_dispatch(a, b, func)
+    # Handle both MXTensor and NVFP4Tensor
+    if isinstance(a, MXTensor) and isinstance(b, MXTensor):
+        return _addmm_mx_dispatch(a, b, func)
+    elif isinstance(a, NVFP4Tensor) and isinstance(b, NVFP4Tensor):
+        return _addmm_nvfp4_dispatch(a, b, func)
+    else:
+        raise ValueError(f"Unsupported tensor types: {type(a)}, {type(b)}")


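A hedged usage sketch of how this dispatch is reached; it assumes an `NVFP4Tensor.to_nvfp4` quantization entry point in mx_tensor.py, which this diff does not show.

import torch
from torchao.prototype.mx_formats.mx_tensor import NVFP4Tensor

a_hp = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
b_hp = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)

a = NVFP4Tensor.to_nvfp4(a_hp)          # quantize along K with block_size=16
b = NVFP4Tensor.to_nvfp4(b_hp.t()).t()  # quantize rows, transpose back via mx_t

out = torch.mm(a, b)  # __torch_dispatch__ routes to mx_mm -> _addmm_nvfp4_dispatch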
@implements([aten.addmm.default])
def mx_addmm(func, types, args, kwargs):
-    assert (
-        isinstance(args[0], torch.Tensor)
-        and isinstance(args[1], MXTensor)
-        and isinstance(args[2], MXTensor)
-    )
    bias = args[0]
    a = args[1]
    b = args[2]
-    return _addmm_mx_dispatch(a, b, func, bias=bias)
+
+    assert isinstance(bias, torch.Tensor), (
+        f"Bias must be torch.Tensor, got {type(bias)}"
+    )
+
+    # Handle both MXTensor and NVFP4Tensor
+    if isinstance(a, MXTensor) and isinstance(b, MXTensor):
+        return _addmm_mx_dispatch(a, b, func, bias=bias)
+    elif isinstance(a, NVFP4Tensor) and isinstance(b, NVFP4Tensor):
+        return _addmm_nvfp4_dispatch(a, b, func, bias=bias)
+    else:
+        raise ValueError(f"Unsupported tensor types: {type(a)}, {type(b)}")


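Note that `torch.addmm` lowers to `aten.addmm.default` with the bias first, which is why `args[0]` is treated as the bias above. Continuing the hedged sketch (reusing `a` and `b` from the earlier example):

bias = torch.randn(256, device="cuda", dtype=torch.bfloat16)
out = torch.addmm(bias, a, b)  # args = (bias, a, b) -> mx_addmm -> bias=bias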
@implements([aten.t.default])
def mx_t(func, types, args, kwargs):
    # For now, only transpose(input, 0, 1) is supported.
    old = args[0]
-    new = MXTensor(
-        old._scale_e8m0,
-        old._data.t(),
-        old._elem_dtype,
-        old._block_size,
-        old._orig_dtype,
-        old._use_fp4_custom_triton_dequant_kernel,
-        old._gemm_kernel_choice,
-        old._pack_fp6,
-    )
+
+    if isinstance(old, MXTensor):
+        new = MXTensor(
+            old._scale_e8m0,
+            old._data.t(),
+            old._elem_dtype,
+            old._block_size,
+            old._orig_dtype,
+            old._use_fp4_custom_triton_dequant_kernel,
+            old._gemm_kernel_choice,
+            old._pack_fp6,
+        )
+    elif isinstance(old, NVFP4Tensor):
+        new = NVFP4Tensor(
+            old._scale_e4m3,
+            old._data.t(),
+            old._block_size,
+            old._orig_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(old)}")
    return new

@@ -205,25 +261,43 @@ def unwrap(x):

@implements([aten.view.default])
def mx_view_op(func, types, args, kwargs):
-    data = args[0]._data
+    tensor = args[0]
+    data = tensor._data
    new_size = args[1]
-    if args[0]._elem_dtype == torch.float4_e2m1fn_x2:
-        # special case fp4 as we pack two elements per byte
+
+    if isinstance(tensor, MXTensor):
+        if tensor._elem_dtype == torch.float4_e2m1fn_x2:
+            # special case fp4 as we pack two elements per byte
+            new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
+        elif (
+            tensor._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and tensor._pack_fp6
+        ):
+            # special case fp6 as we pack 4 elements in 3 bytes
+            new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
+
+        new_data = func(data, new_size, *args[2:], **kwargs)
+        return MXTensor(
+            tensor._scale_e8m0,
+            new_data,
+            tensor._elem_dtype,
+            tensor._block_size,
+            tensor._orig_dtype,
+            tensor._use_fp4_custom_triton_dequant_kernel,
+            tensor._gemm_kernel_choice,
+            tensor._pack_fp6,
+        )
+    elif isinstance(tensor, NVFP4Tensor):
+        # NVFP4 is always fp4 packed
        new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
-    elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
-        # special case fp6 as we pack 4 elements in 3 bytes
-        new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
-    new_data = func(data, new_size, *args[2:], **kwargs)
-    return MXTensor(
-        args[0]._scale_e8m0,
-        new_data,
-        args[0]._elem_dtype,
-        args[0]._block_size,
-        args[0]._orig_dtype,
-        args[0]._use_fp4_custom_triton_dequant_kernel,
-        args[0]._gemm_kernel_choice,
-        args[0]._pack_fp6,
-    )
+        new_data = func(data, new_size, *args[2:], **kwargs)
+        return NVFP4Tensor(
+            tensor._scale_e4m3,
+            new_data,
+            tensor._block_size,
+            tensor._orig_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(tensor)}")
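For context, a plausible reconstruction of what `tensor_size_hp_to_fp4x2` does (an illustration, not this PR's code): two fp4 values pack into one byte, so the packed dimension of the requested view size is halved.

def size_hp_to_fp4x2_sketch(orig_size, is_contiguous):
    # hypothetical stand-in for tensor_size_hp_to_fp4x2
    new_size = list(orig_size)
    if is_contiguous:
        new_size[-1] //= 2  # row-major data packs along the last dim
    else:
        new_size[0] //= 2   # transposed data packs along the first dim
    return new_size

assert size_hp_to_fp4x2_sketch([128, 64], True) == [128, 32]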


@implements([aten.slice.Tensor])
@@ -235,8 +309,15 @@ def mx_slice(func, types, args, kwargs):

    M, K = x.shape[0], x.shape[1]

-    # TODO why doesn't scale have shape?
-    scale_shaped = x._scale_e8m0.view(M, K // x._block_size)
+    # Handle different scale tensors for different tensor types
+    if isinstance(x, MXTensor):
+        scale_tensor = x._scale_e8m0
+    elif isinstance(x, NVFP4Tensor):
+        scale_tensor = x._scale_e4m3
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(x)}")
+
+    scale_shaped = scale_tensor.view(M, K // x._block_size)

    if dim == 0:
        # Slicing along the first dimension (rows); TODO: assuming that dim 1 is the reduction dim for now
@@ -267,15 +348,14 @@ def mx_slice(func, types, args, kwargs):
            scale_shaped, 1, start_block, end_block, step
        ).flatten()
    else:
+        tensor_name = "MXTensor/NVFP4Tensor"
        raise ValueError(
-            f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}"
+            f"{tensor_name} only supports slicing along dimensions 0 and 1, got dim={dim}"
        )

-    return return_and_correct_aliasing(
-        func,
-        args,
-        kwargs,
-        MXTensor(
+    # Create appropriate tensor type
+    if isinstance(x, MXTensor):
+        result_tensor = MXTensor(
            sliced_scale,
            sliced_data,
            x._elem_dtype,
@@ -284,7 +364,20 @@ def mx_slice(func, types, args, kwargs):
            x._use_fp4_custom_triton_dequant_kernel,
            x._gemm_kernel_choice,
            x._pack_fp6,
-        ),
+        )
+    else:  # NVFP4Tensor
+        result_tensor = NVFP4Tensor(
+            sliced_scale,
+            sliced_data,
+            x._block_size,
+            x._orig_dtype,
+        )
+
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        result_tensor,
    )

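To make the dim-1 slice bookkeeping concrete, a hedged sketch of the block arithmetic (illustrative values only; the actual boundary checks live earlier in `mx_slice`). Data columns are sliced in element units, scale columns in block units.

block_size = 16          # NVFP4; MXTensor uses 32 (or 16 after this PR)
start, end = 32, 96      # element offsets along K, multiples of block_size

start_block = start // block_size  # 2
end_block = end // block_size      # 6
# sliced_scale = aten.slice.Tensor(scale_shaped, 1, start_block, end_block).flatten()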