@@ -33,6 +33,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 
 try:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
+        get_fp8_constants as get_fp8_constants,
         matmul_fp8_row as triton_fp8_row,
     )
 
@@ -52,7 +53,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import scale_fp8_row
 
     HAS_CUBLAS = True
-except ImportError:
+except (ImportError, IOError, AttributeError):
     HAS_CUBLAS = False
 
 
@@ -79,7 +80,8 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     (16384, 8192, 13312),
 ]
 
-E4M3_MAX_POS: float = torch.finfo(torch.float8_e4m3fn).max
+FP8_DTYPE, _, _, _ = get_fp8_constants()
+E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max
 EPS: float = 1e-12
 FP16_MAX_POS: float = torch.finfo(torch.float16).max
 
@@ -91,7 +93,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if x.dtype is torch.float16:
         scale = torch.clamp(scale, max=FP16_MAX_POS)
     xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to(
-        torch.float8_e4m3fn
+        FP8_DTYPE
    )
     return xq, scale.reciprocal().to(torch.float32)
 
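For context, the change routes the benchmark's FP8 dtype through get_fp8_constants() instead of hardcoding torch.float8_e4m3fn, so the quantization constants follow whatever FP8 format fbgemm_gpu selects for the current hardware, and the broadened except clause presumably guards against environments where importing the Triton FP8 kernels fails for reasons other than a plain ImportError. Below is a minimal sketch of how the pieces fit together; the per-row scale derivation and the fallback dtype are assumptions added for illustration, since only the clamp-and-cast portion of fp8_row_quantize appears in the diff.

from typing import Tuple

import torch

try:
    # Assumed unpacking, mirroring the diff: the first returned element is the torch FP8 dtype.
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import get_fp8_constants

    FP8_DTYPE, _, _, _ = get_fp8_constants()
except (ImportError, IOError, AttributeError):
    # Fallback for illustration only; the benchmark itself just sets HAS_CUBLAS = False here.
    FP8_DTYPE = torch.float8_e4m3fn

E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max
EPS: float = 1e-12
FP16_MAX_POS: float = torch.finfo(torch.float16).max


def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Assumed scale derivation: map each row's max magnitude onto the FP8 maximum.
    scale = E4M3_MAX_POS / torch.clamp(x.abs().amax(dim=1), min=EPS)
    if x.dtype is torch.float16:
        scale = torch.clamp(scale, max=FP16_MAX_POS)
    # Clamp to the representable FP8 range, then cast to the platform-selected FP8 dtype.
    xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to(
        FP8_DTYPE
    )
    # The reciprocal scale is what downstream kernels use to dequantize the result.
    return xq, scale.reciprocal().to(torch.float32)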