You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
f"scale_dtype {scale_dtype} is only supported with weight_dtype {weight_dtype} and activation_dtype {activation_dtype}, got weight_dtype {weight_dtype} and activation_dtype {activation_dtype}"
50
+
)
51
+
assertblock_size==16, f"For NVFP4, block_size must be 16, got {block_size}"
52
+
53
+
30
54
# Note: This API is a highly experimental prototype and will change in the future
31
55
@dataclass
32
56
classMXFPInferenceConfig(AOBaseConfig):
@@ -61,12 +85,16 @@ class MXFPInferenceConfig(AOBaseConfig):
61
85
- MXTensor in torchao.prototype.mx_formats.mx_tensor
62
86
"""
63
87
64
-
block_size: int=32
88
+
block_size: Union[Literal[32], Literal[16]]=32
65
89
66
-
# Dtypes for Input and Weights
90
+
# Dtypes for input and weights; supports FP8 and FP4 formats
67
91
activation_dtype: torch.dtype=torch.float8_e4m3fn
68
92
weight_dtype: torch.dtype=torch.float8_e4m3fn
69
93
94
+
# Supports float8_e4m3fn, float8_e8m0fnu
95
+
# e8m0 for MX and e4m3 for NVFP4 on CUDA-compatible devices
0 commit comments