
Commit 4d77255

Add round_scales_to_power_of_2 option for float quantization
This adds support for rounding scaling factors down to the nearest power of 2 for float quantization, following the pattern established in Float8LinearConfig.

Key changes:
- Add a round_scales_to_power_of_2 parameter to all float quantization configs
- Update choose_qparams_affine_floatx and to_scaled_tc_floatx to apply power-of-2 rounding
- Thread the parameter through all relevant function calls in quant_api.py
- Maintain backward compatibility with a default value of False

This helps reduce quantization error by avoiding rounding error when multiplying and dividing by scaling factors, and keeps quantization consistent between the forward and backward passes.

stack-info: PR: #2323, branch: drisspg/stack/67
1 parent 282d04f commit 4d77255
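
A minimal usage sketch of the new flag (not part of the diff). It mirrors the dynamic-activation case exercised by the new test and assumes a CUDA GPU with compute capability >= 8.9, plus that PerRow is importable from torchao.quantization as in the library's docs:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Toy module for illustration; any nn.Linear-bearing model works the same way.
model = torch.nn.Linear(64, 32, bias=False).to("cuda", torch.bfloat16)

# Opt in: scaling factors are rounded down to the nearest power of 2.
# Leaving the flag at its default (False) keeps the previous behavior.
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(), round_scales_to_power_of_2=True
)
quantize_(model, config)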

File tree: 7 files changed, +158 -17 lines changed

.github/workflows/float8_test.yml

Lines changed: 4 additions & 4 deletions

@@ -25,14 +25,14 @@ jobs:
         include:
           - name: SM-89
             runs-on: linux.g6.4xlarge.experimental.nvidia.gpu
-            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
+            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu128'
             gpu-arch-type: "cuda"
-            gpu-arch-version: "12.6"
+            gpu-arch-version: "12.8"
           - name: H100
             runs-on: linux.aws.h100
-            torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126'
+            torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128'
             gpu-arch-type: "cuda"
-            gpu-arch-version: "12.4"
+            gpu-arch-version: "12.8"
     permissions:
       id-token: write
       contents: read

test/dtypes/test_affine_quantized_float.py

Lines changed: 97 additions & 0 deletions

@@ -28,7 +28,10 @@
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -675,6 +678,100 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)

+    def _get_weight_scale_and_impl(self, quantized_model, config_type):
+        """Helper to extract weight scale and impl based on config type"""
+        if config_type == "weight_only":
+            weight_impl = quantized_model.weight.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+        else:
+            weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+
+    def _verify_power_of_2_scales(self, scale):
+        """Helper to verify scales are powers of 2"""
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Scales should be powers of 2")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "config_factory,config_type,min_sm,min_error",
+        [
+            (
+                lambda: Float8WeightOnlyConfig(round_scales_to_power_of_2=True),
+                "weight_only",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8StaticActivationFloat8WeightConfig(
+                    scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                    granularity=PerTensor(),
+                    round_scales_to_power_of_2=True,
+                ),
+                "static",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8SemiSparseWeightConfig(
+                    round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                90,
+                10.0,
+            ),
+        ],
+    )
+    def test_power_of_2_scaling_configs(
+        self, config_factory, config_type, min_sm, min_error
+    ):
+        if min_sm == 90 and not is_sm_at_least_90():
+            self.skipTest("Requires GPU with compute capability >= 9.0")
+
+        device = "cuda"
+        dtype = torch.bfloat16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        config = config_factory()
+        if isinstance(
+            config, Float8DynamicActivationFloat8SemiSparseWeightConfig
+        ) and not is_sm_version(9, 0):
+            self.skipTest("Float8SemiSparse requires compute capability == 9.0")
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        scale, _ = self._get_weight_scale_and_impl(quantized_model, config_type)
+        self._verify_power_of_2_scales(scale)
+
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        self.assertEqual(ref_output.shape, quant_output.shape)
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, min_error, f"Quantization SQNR too low: {error}")
+

 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 9 additions & 2 deletions

@@ -457,13 +457,17 @@ def from_hp_to_floatx(
         target_dtype: torch.dtype,
         _layout: Layout,
         scale_dtype: Optional[torch.dtype] = None,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
             scale = _choose_qparams_affine_float8(
-                input_float, float8_dtype=target_dtype, block_size=block_size
+                input_float,
+                float8_dtype=target_dtype,
+                block_size=block_size,
+                round_scales_to_power_of_2=round_scales_to_power_of_2,
             )
             data = _quantize_affine_float8(input_float, scale, target_dtype)
             data, scale, zero_point = _layout.post_process(
@@ -530,6 +534,7 @@ def from_hp_to_fpx(
         cls,
         input_float: torch.Tensor,
         _layout: Layout,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
         from torchao.dtypes.floatx import FloatxTensorCoreLayout
@@ -545,7 +550,9 @@ def from_hp_to_fpx(

         ebits, mbits = _layout.ebits, _layout.mbits
         # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = _choose_qparams_affine_floatx(input_float, ebits, mbits)
+        scale = _choose_qparams_affine_floatx(
+            input_float, ebits, mbits, round_scales_to_power_of_2
+        )
         floatx_unpacked = _quantize_affine_floatx(input_float, scale, ebits, mbits)
         floatx_packed, scale, _ = _layout.post_process(
             floatx_unpacked, scale, None, block_size

torchao/dtypes/floatx/floatx_tensor_core_layout.py

Lines changed: 4 additions & 1 deletion

@@ -22,6 +22,7 @@
     AQTTensorImpl,
     Layout,
 )
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -214,7 +215,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:


 def to_scaled_tc_floatx(
-    tensor: Tensor, ebits: int, mbits: int
+    tensor: Tensor, ebits: int, mbits: int, round_scales_to_power_of_2: bool = False
 ) -> Tuple[Tensor, Tensor]:
     # _n_ones() is not compatible with torch.compile() due to << operator
     # https://github.com/pytorch/pytorch/issues/119152
@@ -230,6 +231,8 @@ def to_scaled_tc_floatx(
     dtype = tensor.dtype
     tensor = tensor.float()
     scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+    if round_scales_to_power_of_2:
+        scale = _round_scale_down_to_power_of_2(scale.float())
     tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
     tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits)
     return tensor_tc_floatx, scale.to(dtype)
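
A quick sketch of the new keyword on to_scaled_tc_floatx (illustrative only, not part of the diff): the ebits=3/mbits=2 format and the 64x64 shape are arbitrary choices assumed to satisfy the tensor-core packing layout, and the import path is the module edited above. With the flag on, the returned per-row scales should be exact powers of 2.

import torch
from torchao.dtypes.floatx.floatx_tensor_core_layout import to_scaled_tc_floatx

weight = torch.randn(64, 64)  # per-axis scaling happens over axis=1
packed, scale = to_scaled_tc_floatx(
    weight, ebits=3, mbits=2, round_scales_to_power_of_2=True
)

# Every per-row scale is rounded down to an exact power of 2.
log2_scale = torch.log2(scale.float())
assert torch.equal(log2_scale, torch.round(log2_scale))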

torchao/float8/float8_utils.py

Lines changed: 2 additions & 1 deletion

@@ -236,6 +236,7 @@ def pad_tensor_for_matmul(
     return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))


-def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+def _round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
+    """Rounds the scale down to the nearest power of 2."""
     assert scale.dtype == torch.float32, "scale must be float32 tensor"
     return torch.exp2(torch.floor(torch.log2(scale)))
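
A worked example of the rounding rule (sample values chosen for illustration, not from the diff): exp2(floor(log2(x))) maps each positive scale to the largest power of 2 that does not exceed it.

import torch
from torchao.float8.float8_utils import _round_scale_down_to_power_of_2

scales = torch.tensor([0.7, 1.0, 3.5, 1e-4], dtype=torch.float32)
print(_round_scale_down_to_power_of_2(scales))
# Expected: 0.7 -> 0.5 (2**-1), 1.0 -> 1.0 (2**0), 3.5 -> 2.0 (2**1),
#           1e-4 -> ~6.1035e-05 (2**-14)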

torchao/quantization/quant_api.py

Lines changed: 26 additions & 2 deletions

@@ -1281,26 +1281,30 @@ def _int4_symm_cutlass_quant(x: torch.Tensor) -> torch.Tensor:
 def _float8_cutlass_quant(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )


 def _float8_cutlass_quant_sparse(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> (torch.Tensor, torch.Tensor):
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=CutlassSemiSparseLayout(),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )


@@ -1410,13 +1414,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

     Note:
         The actual matmul will be computed in original precision of the weight tensor.
     """

     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False


 # for BC
@@ -1433,6 +1439,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         target_dtype=config.weight_dtype,
         scale_dtype=None,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
     return new_weight

@@ -1461,6 +1468,7 @@ def _input_activation_quant_func_fp8(
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
+    round_scales_to_power_of_2: bool = False,
 ):
     """This function is used to quantize the input activation tensor for an aqt_float variant. If scale
     is not provided it will be dynamically calculate the scales otherwise it will use the provided scale.
@@ -1481,6 +1489,7 @@ def _input_activation_quant_func_fp8(
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
     else:
         assert isinstance(activation_granularity, PerTensor), (
@@ -1538,6 +1547,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

     """

@@ -1546,6 +1556,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False

     def __post_init__(self):
         if self.mm_config is None:
@@ -1589,12 +1600,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )

     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }

     quantized_weight = to_linear_activation_quantized(
@@ -1634,11 +1647,13 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
         `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment.
         `activation_dtype`: data type for quantized activation tensor.
         `weight_dtype`: data type for quantized weight tensor.
+        `round_scales_to_power_of_2`: If True, round scaling factors down to the nearest power of 2.
     """

     layout: Layout = CutlassSemiSparseLayout()
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
+    round_scales_to_power_of_2: bool = False


 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1657,11 +1672,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
             f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
         )

-    weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
+    weight = _float8_cutlass_quant_sparse(
+        weight, weight_dtype, config.round_scales_to_power_of_2
+    )
     weight = to_linear_activation_quantized(
         weight,
         _float8_cutlass_quant,
-        quant_kwargs={"target_dtype": activation_dtype},
+        quant_kwargs={
+            "target_dtype": activation_dtype,
+            "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
+        },
     )

     module.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -1680,6 +1700,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
     """

     scale: torch.Tensor
@@ -1690,6 +1711,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False

     def __post_init__(self):
         if self.mm_config is None:
@@ -1733,12 +1755,14 @@ def _float8_static_activation_float8_weight_transform(
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )

     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }

     quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
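
For the weight-only path, a minimal end-to-end sketch (not part of the diff; assumes a CUDA GPU with compute capability >= 8.9). The scale check and the tensor_impl access path mirror the weight_only branch of the new test helper:

import copy
import torch
from torchao.quantization import Float8WeightOnlyConfig, quantize_

model = torch.nn.Linear(64, 32, bias=False).to("cuda", torch.bfloat16)
quantized = copy.deepcopy(model)
quantize_(quantized, Float8WeightOnlyConfig(round_scales_to_power_of_2=True))

# Weight scales should all be exact powers of 2 (same check as the test helper).
scale = quantized.weight.tensor_impl.scale.float()
log2_scale = torch.log2(scale)
assert torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)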
