 )
 from torchao.prototype.mx_formats.mx_tensor import (  # noqa: E501
     MXTensor,
+    NVFP4Tensor,
     tensor_size_hp_to_fp4x2,
     tensor_size_hpx3_to_fp6x4,
 )
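
For orientation: the only thing this hunk changes is that `NVFP4Tensor` becomes importable in `mx_ops`. The sketch below restates the positional constructor order visible at the `mx_t` and `mx_view_op` call sites later in this diff; the placeholder dtypes, sizes, and device are assumptions, not part of the change.

```python
# Sketch of the NVFP4Tensor constructor order as used at the mx_t /
# mx_view_op call sites in this diff: (scale_e4m3, packed_data,
# block_size, orig_dtype). Placeholder tensors are illustrative only;
# exact dtypes/devices may differ in the real tensor subclass.
import torch
from torchao.prototype.mx_formats.mx_tensor import NVFP4Tensor

M, K, block_size = 128, 64, 16
scale_e4m3 = torch.empty(M * (K // block_size), dtype=torch.float8_e4m3fn)
packed_data = torch.empty(M, K // 2, dtype=torch.uint8)  # two fp4 values per byte
t = NVFP4Tensor(scale_e4m3, packed_data, block_size, torch.bfloat16)
```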
@@ -93,8 +94,8 @@ def _addmm_mx_dispatch(
     M, K, N = a.shape[0], a.shape[1], b.shape[1]
     assert a._data.is_contiguous()
     assert b._data.t().is_contiguous()
-    assert a._block_size == 32, f"Invalid block size {a._block_size}"
-    assert b._block_size == 32, f"Invalid block size {b._block_size}"
+    assert a._block_size in [16, 32], f"Invalid block size {a._block_size}"
+    assert b._block_size in [16, 32], f"Invalid block size {b._block_size}"
 
     a_scale = a._scale_e8m0.view(M, K // a._block_size)
     b_scale = b._scale_e8m0.view(N, K // b._block_size)
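
The relaxed assert reflects the two block granularities that now flow through this path: MX formats scale every 32 elements along K, NVFP4 every 16. A small arithmetic sketch (illustrative numbers) of the scale views on the lines above:

```python
# One scale per block_size elements along the reduction dim K, so the
# flat scale tensor views to (M, K // block_size). Illustrative numbers.
M, K = 128, 64
for block_size in (32, 16):          # MX formats vs. NVFP4
    print(block_size, (M, K // block_size))
# 32 (128, 2)  -> 2 E8M0 scales per row for MX
# 16 (128, 4)  -> 4 E4M3 scales per row for NVFP4
```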
@@ -144,42 +145,97 @@ def _addmm_mx_dispatch(
     return res
 
 
+def _addmm_nvfp4_dispatch(
+    a: NVFP4Tensor, b: NVFP4Tensor, aten_op, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Core implementation for NVFP4Tensor matmul operations.
+    Uses E4M3 scales and always dispatches to cuBLAS for FP4 operations.
+    """
+    M, K, N = a.shape[0], a.shape[1], b.shape[1]
+    assert a._data.is_contiguous()
+    assert b._data.t().is_contiguous()
+    assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
+    assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
+
+    # NVFP4 uses E4M3 scales, not E8M0
+    a_scale = a._scale_e4m3.view(M, K // a._block_size)
+    b_scale = b._scale_e4m3.view(N, K // b._block_size)
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
+
+    # NVFP4 always uses cuBLAS with the VEC16_UE4M3 scale mode
+    res = torch._scaled_mm(
+        a._data,
+        b._data,
+        a_scale_block.view(torch.float8_e4m3fn),
+        b_scale_block.view(torch.float8_e4m3fn),
+        bias=bias,
+        out_dtype=torch.bfloat16,
+    )
+
+    return res
+
+
 @implements([aten.mm.default, aten.matmul.default])
 def mx_mm(func, types, args, kwargs):
     a = args[0]
     b = args[1]
-    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
 
-    return _addmm_mx_dispatch(a, b, func)
+    # Handle both MXTensor and NVFP4Tensor
+    if isinstance(a, MXTensor) and isinstance(b, MXTensor):
+        return _addmm_mx_dispatch(a, b, func)
+    elif isinstance(a, NVFP4Tensor) and isinstance(b, NVFP4Tensor):
+        return _addmm_nvfp4_dispatch(a, b, func)
+    else:
+        raise ValueError(f"Unsupported tensor types: {type(a)}, {type(b)}")
 
 @implements([aten.addmm.default])
 def mx_addmm(func, types, args, kwargs):
-    assert (
-        isinstance(args[0], torch.Tensor)
-        and isinstance(args[1], MXTensor)
-        and isinstance(args[2], MXTensor)
-    )
     bias = args[0]
     a = args[1]
     b = args[2]
-    return _addmm_mx_dispatch(a, b, func, bias=bias)
+
+    assert isinstance(bias, torch.Tensor), (
+        f"Bias must be torch.Tensor, got {type(bias)}"
+    )
+
+    # Handle both MXTensor and NVFP4Tensor
+    if isinstance(a, MXTensor) and isinstance(b, MXTensor):
+        return _addmm_mx_dispatch(a, b, func, bias=bias)
+    elif isinstance(a, NVFP4Tensor) and isinstance(b, NVFP4Tensor):
+        return _addmm_nvfp4_dispatch(a, b, func, bias=bias)
+    else:
+        raise ValueError(f"Unsupported tensor types: {type(a)}, {type(b)}")
 
 
 @implements([aten.t.default])
 def mx_t(func, types, args, kwargs):
     # For now, only transpose(input, 0, 1) is supported.
     old = args[0]
-    new = MXTensor(
-        old._scale_e8m0,
-        old._data.t(),
-        old._elem_dtype,
-        old._block_size,
-        old._orig_dtype,
-        old._use_fp4_custom_triton_dequant_kernel,
-        old._gemm_kernel_choice,
-        old._pack_fp6,
-    )
+
+    if isinstance(old, MXTensor):
+        new = MXTensor(
+            old._scale_e8m0,
+            old._data.t(),
+            old._elem_dtype,
+            old._block_size,
+            old._orig_dtype,
+            old._use_fp4_custom_triton_dequant_kernel,
+            old._gemm_kernel_choice,
+            old._pack_fp6,
+        )
+    elif isinstance(old, NVFP4Tensor):
+        new = NVFP4Tensor(
+            old._scale_e4m3,
+            old._data.t(),
+            old._block_size,
+            old._orig_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(old)}")
     return new
 
 
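
Taken together, the handlers above mean a plain matmul on two NVFP4 operands routes through `mx_mm` into `_addmm_nvfp4_dispatch` and lands on `torch._scaled_mm` with E4M3 scale blocks. A hedged end-to-end sketch; `to_nvfp4` is a hypothetical stand-in, since constructing the operands is outside this diff:

```python
# Hypothetical sketch: `to_nvfp4` stands in for whatever produces
# NVFP4Tensor instances; it is not defined in this diff.
a_hp = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
b_hp = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)
a = to_nvfp4(a_hp)            # activations: a._data contiguous
b = to_nvfp4(b_hp.t()).t()    # weights laid out so b._data.t() is contiguous
out = torch.matmul(a, b)      # __torch_dispatch__ -> mx_mm -> _addmm_nvfp4_dispatch
assert out.dtype == torch.bfloat16  # out_dtype is hardcoded in the dispatch above
```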
@@ -205,25 +261,43 @@ def unwrap(x):
 
 @implements([aten.view.default])
 def mx_view_op(func, types, args, kwargs):
-    data = args[0]._data
+    tensor = args[0]
+    data = tensor._data
     new_size = args[1]
-    if args[0]._elem_dtype == torch.float4_e2m1fn_x2:
-        # special case fp4 as we pack two elements per byte
+
+    if isinstance(tensor, MXTensor):
+        if tensor._elem_dtype == torch.float4_e2m1fn_x2:
+            # special case fp4 as we pack two elements per byte
+            new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
+        elif (
+            tensor._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and tensor._pack_fp6
+        ):
+            # special case fp6 as we pack 4 elements in 3 bytes
+            new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
+
+        new_data = func(data, new_size, *args[2:], **kwargs)
+        return MXTensor(
+            tensor._scale_e8m0,
+            new_data,
+            tensor._elem_dtype,
+            tensor._block_size,
+            tensor._orig_dtype,
+            tensor._use_fp4_custom_triton_dequant_kernel,
+            tensor._gemm_kernel_choice,
+            tensor._pack_fp6,
+        )
+    elif isinstance(tensor, NVFP4Tensor):
+        # NVFP4 is always fp4 packed
         new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
-    elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
-        # special case fp6 as we pack 4 elements in 3 bytes
-        new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
-    new_data = func(data, new_size, *args[2:], **kwargs)
-    return MXTensor(
-        args[0]._scale_e8m0,
-        new_data,
-        args[0]._elem_dtype,
-        args[0]._block_size,
-        args[0]._orig_dtype,
-        args[0]._use_fp4_custom_triton_dequant_kernel,
-        args[0]._gemm_kernel_choice,
-        args[0]._pack_fp6,
-    )
+        new_data = func(data, new_size, *args[2:], **kwargs)
+        return NVFP4Tensor(
+            tensor._scale_e4m3,
+            new_data,
+            tensor._block_size,
+            tensor._orig_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(tensor)}")
 
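
The new branch keeps the MX behavior and adds an unconditional fp4 repack for NVFP4, since NVFP4 payloads are always packed two per byte. What `tensor_size_hp_to_fp4x2` does to the requested size can be sketched for the contiguous case (an illustrative reimplementation, not the torchao function itself):

```python
# Illustrative only: the real tensor_size_hp_to_fp4x2 lives in
# torchao.prototype.mx_formats.mx_tensor and also handles the
# non-contiguous case. Two fp4 values pack into one byte, so the
# innermost dim of the byte tensor is half the high-precision size.
def size_hp_to_fp4x2_sketch(size):
    size = list(size)
    size[-1] //= 2
    return size

print(size_hp_to_fp4x2_sketch([32, 64]))  # [32, 32] packed bytes
```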
 
 @implements([aten.slice.Tensor])
@@ -235,8 +309,15 @@ def mx_slice(func, types, args, kwargs):
 
     M, K = x.shape[0], x.shape[1]
 
-    # TODO why doesn't scale have shape?
-    scale_shaped = x._scale_e8m0.view(M, K // x._block_size)
+    # Handle different scale tensors for different tensor types
+    if isinstance(x, MXTensor):
+        scale_tensor = x._scale_e8m0
+    elif isinstance(x, NVFP4Tensor):
+        scale_tensor = x._scale_e4m3
+    else:
+        raise ValueError(f"Unsupported tensor type: {type(x)}")
+
+    scale_shaped = scale_tensor.view(M, K // x._block_size)
 
     if dim == 0:
         # Slicing along the first dimension (rows); TODO: assuming that dim 1 is the reduction dim for now
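
The branch above is the only per-type difference in `mx_slice`: both layouts view their flat scale tensor to `(M, K // block_size)`; they differ only in attribute name, scale dtype, and block size. In illustrative numbers:

```python
# Shapes the scale view above produces, per tensor type (illustrative
# sizes; the attribute/block-size pairing is taken from this diff).
M, K = 128, 64
mx_scale_shape = (M, K // 32)     # MXTensor:    _scale_e8m0, block_size=32
nvfp4_scale_shape = (M, K // 16)  # NVFP4Tensor: _scale_e4m3, block_size=16
```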
@@ -267,15 +348,14 @@ def mx_slice(func, types, args, kwargs):
             scale_shaped, 1, start_block, end_block, step
         ).flatten()
     else:
+        tensor_name = "MXTensor/NVFP4Tensor"
         raise ValueError(
-            f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}"
+            f"{tensor_name} only supports slicing along dimensions 0 and 1, got dim={dim}"
         )
 
-    return return_and_correct_aliasing(
-        func,
-        args,
-        kwargs,
-        MXTensor(
+    # Create appropriate tensor type
+    if isinstance(x, MXTensor):
+        result_tensor = MXTensor(
             sliced_scale,
             sliced_data,
             x._elem_dtype,
@@ -284,7 +364,20 @@ def mx_slice(func, types, args, kwargs):
             x._use_fp4_custom_triton_dequant_kernel,
             x._gemm_kernel_choice,
             x._pack_fp6,
-        ),
+        )
+    else:  # NVFP4Tensor
+        result_tensor = NVFP4Tensor(
+            sliced_scale,
+            sliced_data,
+            x._block_size,
+            x._orig_dtype,
+        )
+
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        result_tensor,
     )
 
 
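Finally, a sketch of the invariant the slice handler maintains: slicing data rows must slice the corresponding scale rows so that every surviving row keeps its per-block scales. `x` is assumed to be an existing NVFP4Tensor; nothing below is new API:

```python
# Assumes `x` is an NVFP4Tensor of logical shape (128, 64), block_size=16.
# x._data:       (128, 32) packed bytes (two fp4 values per byte)
# scale_shaped:  (128, 4)  E4M3 scales, one per 16 elements of K
y = x[0:32]  # dispatches aten.slice.Tensor with dim=0 -> mx_slice
# mx_slice slices rows of both: y._data == x._data[0:32], and y's scales
# are scale_shaped[0:32].flatten(), so each row keeps its own 4 scales.
```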