Commit aec0821

[float8] Add fnuz fp8 dtypes to Float8Layout (#2351)
This should give us AMD perf on vLLM. With Phi-4-mini-instruct on MI300x and TorchAO FP8 rowwise quant on the MLP, I see the following, which is about a 5% speedup:

```
Avg latency: 1.080369415456274 seconds
10% percentile latency: 1.075335633114446 seconds
25% percentile latency: 1.0811904482543468 seconds
50% percentile latency: 1.082176529977005 seconds
75% percentile latency: 1.0826280051842332 seconds
90% percentile latency: 1.0831242799758911 seconds
99% percentile latency: 1.0836151059856638 seconds
```

For comparison, here is the baseline Phi-4-mini-instruct on MI300x:

```
Avg latency: 1.148340248184589 seconds
10% percentile latency: 1.1391733552212826 seconds
25% percentile latency: 1.14905939399614 seconds
50% percentile latency: 1.150204271019902 seconds
75% percentile latency: 1.1523984443047084 seconds
90% percentile latency: 1.1536207939614542 seconds
99% percentile latency: 1.1548575214319863 seconds
```

Previously, these checks were failing on the unsigned-zero (fnuz) ROCm fp8 dtypes, causing us to call `.dequantize()` and then do a bfloat16 mm, which was slower than the bf16 baseline (~2s).
1 parent a250418 commit aec0821
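For reference, this is roughly how TorchAO FP8 rowwise quant is applied to a model; a minimal sketch assuming the public `quantize_` API with `float8_dynamic_activation_float8_weight` and `PerRow` granularity (the toy module and its dimensions are illustrative, not the actual vLLM integration used for the numbers above):

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.granularity import PerRow

# Toy stand-in for the MLP block that was quantized in Phi-4-mini-instruct;
# the layer sizes here are illustrative.
mlp = nn.Sequential(
    nn.Linear(3072, 8192),
    nn.SiLU(),
    nn.Linear(8192, 3072),
).to(torch.bfloat16).to("cuda")

# Dynamic fp8 activations + fp8 weights with per-row scales. On MI300x the
# fp8 tensors carry the fnuz dtypes, which the dispatch checks in this commit
# now accept, keeping the fast fp8 mm path instead of dequantize + bf16 mm.
quantize_(mlp, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```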

1 file changed: +2 −2 lines changed
torchao/dtypes/floatx/float8_layout.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -363,7 +363,7 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
     return (
         isinstance(aqt, AffineQuantizedTensor)
         and isinstance(aqt._layout, Float8Layout)
-        and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+        and _is_float8_type(aqt.tensor_impl.dtype)
         and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
     )

@@ -442,7 +442,7 @@ def _linear_fp_act_fp8_weight_check(
         # weight is float8 quantized affine quantized tensor
         isinstance(weight_tensor, AffineQuantizedTensor)
         and isinstance(weight_tensor._layout, Float8Layout)
-        and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+        and _is_float8_type(weight_tensor.tensor_impl.dtype)
         and (
             weight_tensor.shape == weight_tensor.block_size
             or _is_rowwise_scaled(weight_tensor)
```