|
20 | 20 |     _validate_gemm_kernel_choice,
|
21 | 21 | )
|
22 | 22 | from torchao.prototype.mx_formats.mx_tensor import MXTensor
|
| 23 | +from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
23 | 24 | from torchao.quantization.quant_api import to_linear_activation_quantized
|
24 | 25 | from torchao.quantization.transform_module import (
|
25 | 26 |     register_quantize_module_handler,
|
@@ -63,7 +64,7 @@ class MXFPInferenceConfig(AOBaseConfig):
|
63 | 64 |
|
64 | 65 |     block_size: int = 32
|
65 | 66 |
|
66 | | -    # Dtypes for Input and Weights
| 67 | +    # Dtypes for Input and Weights; supports FP8 and FP4 formats
67 | 68 |     activation_dtype: torch.dtype = torch.float8_e4m3fn

68 | 69 |     weight_dtype: torch.dtype = torch.float8_e4m3fn
|
69 | 70 |
|
@@ -151,7 +152,106 @@ def _mx_inference_linear_transform(
|
151 | 152 |     return module
|
152 | 153 |
|
153 | 154 |
|
| 155 | +@dataclass
| 156 | +class NVFP4InferenceConfig(AOBaseConfig):
| 157 | +    """
| 158 | +    NVIDIA FP4 (NVFP4) Inference Quantization Configuration
| 159 | +
| 160 | +    This is a specialized configuration for NVIDIA's FP4 format with UE4M3 scales.
| 161 | +    It provides defaults optimized for NVFP4:
| 162 | +    - Data: float4_e2m1fn_x2
| 163 | +    - Scales: float8_e4m3fn (UE4M3)
| 164 | +    - Block size: 16 (required for NVFP4)
| 165 | +    - CUBLAS kernel (optimized for VEC16_UE4M3)
| 166 | +    """
| 167 | +
| 168 | +    block_size: int = 16  # NVFP4 requires block size 16
| 169 | +
| 170 | +    # NVFP4 uses FP4 data
| 171 | +    activation_dtype: torch.dtype = torch.float4_e2m1fn_x2
| 172 | +    weight_dtype: torch.dtype = torch.float4_e2m1fn_x2
| 173 | +
| 174 | +    # NVFP4 uses E4M3 scales
| 175 | +    scale_dtype: torch.dtype = torch.float8_e4m3fn
| 176 | +
| 177 | +    # CUBLAS is preferred for NVFP4 with VEC16_UE4M3 support
| 178 | +    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
| 179 | +
| 180 | +    # If True, apply torchao's recommended inductor config settings
| 181 | +    set_inductor_config: bool = False
| 182 | +
| 183 | +    def __post_init__(self):
| 184 | +        # Validate NVFP4 constraints
| 185 | +        assert self.activation_dtype == torch.float4_e2m1fn_x2, (
| 186 | +            f"NVFP4 requires activation_dtype=float4_e2m1fn_x2, got {self.activation_dtype}"
| 187 | +        )
| 188 | +        assert self.weight_dtype == torch.float4_e2m1fn_x2, (
| 189 | +            f"NVFP4 requires weight_dtype=float4_e2m1fn_x2, got {self.weight_dtype}"
| 190 | +        )
| 191 | +        assert self.scale_dtype == torch.float8_e4m3fn, (
| 192 | +            f"NVFP4 requires scale_dtype=float8_e4m3fn, got {self.scale_dtype}"
| 193 | +        )
| 194 | +        assert self.block_size == 16, (
| 195 | +            f"NVFP4 requires block_size=16, got {self.block_size}"
| 196 | +        )
| 197 | +
| 198 | +
| 199 | +def _input_activation_quant_func_nvfp4(
| 200 | +    x: torch.Tensor,
| 201 | +    block_size: int = 16,
| 202 | +    scale: Optional[torch.Tensor] = None,
| 203 | +):
| 204 | +    """NVFP4-specific activation quantization function"""
| 205 | +    # TODO: scale for static quant
| 206 | +    activation = NVFP4Tensor.to_nvfp4(
| 207 | +        x,
| 208 | +        block_size=block_size,
| 209 | +    )
| 210 | +    return activation
| 211 | +
| 212 | +
| 213 | +@register_quantize_module_handler(NVFP4InferenceConfig)
| 214 | +def _nvfp4_inference_linear_transform(
| 215 | +    module: torch.nn.Module, config: NVFP4InferenceConfig
| 216 | +):
| 217 | +    """Quantization handler for NVFP4InferenceConfig"""
| 218 | +    assert is_sm_at_least_100(), "NVFP4 is only supported on sm100+ machines"
| 219 | +    if config.set_inductor_config:
| 220 | +        torchao.quantization.utils.recommended_inductor_config_setter()
| 221 | +
| 222 | +    weight = module.weight
| 223 | +    assert weight.dtype == torch.bfloat16, (
| 224 | +        f"Only bf16 weights are supported for now, got {weight.dtype}"
| 225 | +    )
| 226 | +
| 227 | +    # Convert the weight to an NVFP4 tensor
| 228 | +    quantized_weight = NVFP4Tensor.to_nvfp4(
| 229 | +        weight,
| 230 | +        block_size=config.block_size,
| 231 | +    )
| 232 | +
| 233 | +    input_quant_func = _input_activation_quant_func_nvfp4
| 234 | +    input_quant_kwargs = {
| 235 | +        "block_size": config.block_size,
| 236 | +        "scale": None,
| 237 | +    }
| 238 | +
| 239 | +    quantized_weight = to_linear_activation_quantized(
| 240 | +        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
| 241 | +    )
| 242 | +
| 243 | +    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
| 244 | +    module.extra_repr = types.MethodType(_linear_extra_repr, module)
| 245 | +    return module
| 246 | +
| 247 | +
154 | 248 | if TORCH_VERSION_AT_LEAST_2_5:
|
155 | 249 |     torch.serialization.add_safe_globals(
|
156 | | -        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
| 250 | +        [
| 251 | +            MXTensor,
| 252 | +            NVFP4Tensor,
| 253 | +            MXGemmKernelChoice,
| 254 | +            _input_activation_quant_func_mxfp,
| 255 | +            _input_activation_quant_func_nvfp4,
| 256 | +        ]
157 | 257 |     )
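
Taken together, the diff registers `NVFP4InferenceConfig` with torchao's `quantize_` API via `@register_quantize_module_handler`. A minimal usage sketch follows; the model shape is hypothetical, and it assumes `NVFP4InferenceConfig` is importable from this module, an sm100+ GPU, and bf16 weights (the last two are asserted by the handler above):

```python
import torch
from torchao.quantization import quantize_

# Hypothetical toy module; NVFP4 requires an sm100+ GPU and bf16 weights.
model = torch.nn.Linear(128, 256, bias=False).to(torch.bfloat16).cuda()

# quantize_ dispatches to _nvfp4_inference_linear_transform through the
# @register_quantize_module_handler registration in the diff above.
quantize_(model, NVFP4InferenceConfig())

x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
y = model(x)  # activations are quantized to NVFP4 on the fly
```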
|
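For intuition on the defaults the config enforces (E2M1 data, E4M3 scales, blocks of 16), here is an illustrative sketch of NVFP4-style per-block scaling; this is an assumption-laden toy, not the library's packing code, which lives inside `NVFP4Tensor.to_nvfp4`:

```python
import torch

FP4_E2M1_MAX = 6.0  # largest magnitude representable in FP4 E2M1

def nvfp4_block_scales(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Illustrative per-block scales so each block of 16 values fits in [-6, 6]."""
    blocks = x.reshape(-1, block_size)  # toy layout; requires numel % 16 == 0
    amax = blocks.abs().amax(dim=1)     # one amax per block of 16 elements
    # NVFP4 stores per-block scales in float8_e4m3fn (UE4M3)
    return (amax / FP4_E2M1_MAX).to(torch.float8_e4m3fn)
```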