Commit f2932b7

htyu authored and facebook-github-bot committed
Enable fp8 rowwise on AMDGPU (#2483)
Summary:
Pull Request resolved: #2483

Reviewed By: karthik-man

Differential Revision: D63726031

fbshipit-source-id: dc410e503f918d83362fb38005ac4a6db5dc1e68
1 parent 4445aa2 commit f2932b7
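
Context for the change: fbgemm_gpu's get_fp8_constants() returns the fp8 torch dtype appropriate for the current platform, so the benchmark no longer hard-codes torch.float8_e4m3fn, which MI300-class AMD GPUs do not support natively (ROCm uses the "fnuz" variant). Below is a minimal sketch of the selection logic this relies on; pick_fp8_dtype is a hypothetical stand-in for the real helper, not fbgemm_gpu's implementation.

import torch


def pick_fp8_dtype() -> torch.dtype:
    # NVIDIA GPUs use the OCP e4m3 format (float8_e4m3fn); AMD GPUs such as
    # MI300 use the fnuz variant, which drops negative zero and has a
    # smaller representable maximum.
    if torch.version.hip is not None:
        return torch.float8_e4m3fnuz  # max is 240.0
    return torch.float8_e4m3fn  # max is 448.0


FP8_DTYPE = pick_fp8_dtype()
# Deriving the clamp bound from the selected dtype keeps it correct on
# both vendors, which is exactly what the diff below does.
E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max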

File tree

1 file changed: +5 -3 lines changed


torchbenchmark/operators/fp8_gemm_rowwise/operator.py

Lines changed: 5 additions & 3 deletions
@@ -33,6 +33,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 
 try:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
+        get_fp8_constants as get_fp8_constants,
         matmul_fp8_row as triton_fp8_row,
     )
 
@@ -52,7 +53,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import scale_fp8_row
 
     HAS_CUBLAS = True
-except ImportError:
+except (ImportError, IOError, AttributeError):
     HAS_CUBLAS = False
 
 
@@ -79,7 +80,8 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     (16384, 8192, 13312),
 ]
 
-E4M3_MAX_POS: float = torch.finfo(torch.float8_e4m3fn).max
+FP8_DTYPE, _, _, _ = get_fp8_constants()
+E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max
 EPS: float = 1e-12
 FP16_MAX_POS: float = torch.finfo(torch.float16).max
 
@@ -91,7 +93,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if x.dtype is torch.float16:
         scale = torch.clamp(scale, max=FP16_MAX_POS)
     xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to(
-        torch.float8_e4m3fn
+        FP8_DTYPE
     )
     return xq, scale.reciprocal().to(torch.float32)
 