
Commit 4bb19d6

Add round_scales_to_power_of_2 option for float quantization
This adds support for rounding scaling factors down to the nearest power of 2 for float quantization, following the pattern established in Float8LinearConfig.

Key changes:
- Add a round_scales_to_power_of_2 parameter to all float quantization configs
- Update choose_qparams_affine_floatx and to_scaled_tc_floatx to apply power-of-2 rounding
- Thread the parameter through all relevant function calls in quant_api.py
- Maintain backward compatibility with a default value of False

This helps reduce quantization error by avoiding rounding error when multiplying/dividing by scaling factors, and ensures consistent quantization between the forward and backward passes.

stack-info: PR: #2323, branch: drisspg/stack/67
1 parent 0a81ae8 commit 4bb19d6
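For context, a minimal usage sketch of the new option (the model, shapes, and device are illustrative; assumes a CUDA GPU with compute capability >= 8.9, and that PerRow is importable from torchao.quantization as in the tests below):

    import copy

    import torch
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
        quantize_,
    )

    model = torch.nn.Linear(64, 32, bias=False).to("cuda", torch.bfloat16)
    quantized_model = copy.deepcopy(model)
    # Scales are computed as usual, then rounded down to the nearest power of 2.
    quantize_(
        quantized_model,
        Float8DynamicActivationFloat8WeightConfig(
            granularity=PerRow(), round_scales_to_power_of_2=True
        ),
    )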

File tree

6 files changed: +153 −14 lines

test/dtypes/test_affine_quantized_float.py

Lines changed: 96 additions & 0 deletions
@@ -28,7 +28,10 @@
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -674,6 +677,99 @@ def test_preprocess_scale_3d_reshape(self):
         result = preprocess_scale(scale_4d, (2, 2, 2, 8))
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
+    def _get_weight_scale_and_impl(self, quantized_model, config_type):
+        """Helper to extract weight scale and impl based on config type"""
+        if config_type == "weight_only":
+            weight_impl = quantized_model.weight.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+        else:
+            weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+
+    def _verify_power_of_2_scales(self, scale):
+        """Helper to verify scales are powers of 2"""
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Scales should be powers of 2")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "config_factory,config_type,min_sm,min_error",
+        [
+            (
+                lambda: Float8WeightOnlyConfig(round_scales_to_power_of_2=True),
+                "weight_only",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8StaticActivationFloat8WeightConfig(
+                    scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                    granularity=PerTensor(),
+                    round_scales_to_power_of_2=True,
+                ),
+                "static",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8SemiSparseWeightConfig(
+                    round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                90,
+                10.0,
+            ),
+        ],
+    )
+    def test_power_of_2_scaling_configs(
+        self, config_factory, config_type, min_sm, min_error
+    ):
+        if min_sm == 90 and not is_sm_at_least_90():
+            self.skipTest("Requires GPU with compute capability >= 9.0")
+
+        device = "cuda"
+        dtype = torch.bfloat16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        config = config_factory()
+        if isinstance(
+            config, Float8DynamicActivationFloat8SemiSparseWeightConfig
+        ) and not is_sm_version(9, 0):
+            self.skipTest("Float8SemiSparse requires compute capability == 9.0")
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        scale, _ = self._get_weight_scale_and_impl(quantized_model, config_type)
+        self._verify_power_of_2_scales(scale)
+
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        self.assertEqual(ref_output.shape, quant_output.shape)
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, min_error, f"Quantization SQNR too low: {error}")


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
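Here compute_error returns the SQNR in dB, so min_error is a dB floor on output quality. The power-of-2 assertion reduces to checking that log2(scale) is integral; a self-contained sketch of the same predicate (the helper name is ours, not torchao's):

    import torch

    def is_power_of_2(scale: torch.Tensor, atol: float = 1e-6) -> bool:
        # A positive float is an exact power of 2 iff its log2 is an integer.
        log2_scale = torch.log2(scale.float())
        return torch.allclose(log2_scale, torch.round(log2_scale), atol=atol)

    assert is_power_of_2(torch.tensor([0.5, 1.0, 64.0]))
    assert not is_power_of_2(torch.tensor([0.75, 3.0]))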

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 9 additions & 2 deletions
@@ -457,13 +457,17 @@ def from_hp_to_floatx(
         target_dtype: torch.dtype,
         _layout: Layout,
         scale_dtype: Optional[torch.dtype] = None,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
             scale = choose_qparams_affine_float8(
-                input_float, float8_dtype=target_dtype, block_size=block_size
+                input_float,
+                float8_dtype=target_dtype,
+                block_size=block_size,
+                round_scales_to_power_of_2=round_scales_to_power_of_2,
             )
             data = quantize_affine_float8(input_float, scale, target_dtype)
             data, scale, zero_point = _layout.post_process(
@@ -530,6 +534,7 @@ def from_hp_to_fpx(
         cls,
         input_float: torch.Tensor,
         _layout: Layout,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
         from torchao.dtypes.floatx import FloatxTensorCoreLayout
@@ -545,7 +550,9 @@

         ebits, mbits = _layout.ebits, _layout.mbits
         # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
+        scale = choose_qparams_affine_floatx(
+            input_float, ebits, mbits, round_scales_to_power_of_2
+        )
         floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
         floatx_packed, scale, _ = _layout.post_process(
             floatx_unpacked, scale, None, block_size
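Conceptually, the new parameter only post-processes the scale chosen by choose_qparams_affine_float8 / choose_qparams_affine_floatx. A simplified per-tensor sketch (a hypothetical helper, not the torchao implementation; rounding the scale down can push a few quantized values past the fp8 maximum, which the quantize step's clamp then handles):

    import torch

    def choose_float8_scale(x: torch.Tensor, round_to_pow2: bool = False) -> torch.Tensor:
        # Simplified amax-based per-tensor scale for float8_e4m3fn.
        amax = x.abs().amax().clamp(min=1e-12).float()
        scale = amax / torch.finfo(torch.float8_e4m3fn).max
        if round_to_pow2:
            # Same rounding as _round_scale_down_to_power_of_2 below.
            scale = torch.exp2(torch.floor(torch.log2(scale)))
        return scale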

torchao/dtypes/floatx/floatx_tensor_core_layout.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
     AQTTensorImpl,
     Layout,
 )
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -214,7 +215,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:


 def to_scaled_tc_floatx(
-    tensor: Tensor, ebits: int, mbits: int
+    tensor: Tensor, ebits: int, mbits: int, round_scales_to_power_of_2: bool = False
 ) -> Tuple[Tensor, Tensor]:
     # _n_ones() is not compatible with torch.compile() due to << operator
     # https://github.com/pytorch/pytorch/issues/119152
@@ -230,6 +231,8 @@ def to_scaled_tc_floatx(
     dtype = tensor.dtype
     tensor = tensor.float()
     scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+    if round_scales_to_power_of_2:
+        scale = _round_scale_down_to_power_of_2(scale.float())
     tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
     tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits)
     return tensor_tc_floatx, scale.to(dtype)
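A hedged call sketch for the updated signature (ebits=3, mbits=2 is the fp6_e3m2 format; the shape is illustrative and assumed compatible with the tensor-core packing):

    import torch
    from torchao.dtypes.floatx.floatx_tensor_core_layout import to_scaled_tc_floatx

    tensor = torch.randn(256, 64, dtype=torch.half)
    packed, scale = to_scaled_tc_floatx(tensor, 3, 2, round_scales_to_power_of_2=True)
    # With the flag set, each per-row scale is an exact power of 2.
    pow2 = torch.exp2(torch.floor(torch.log2(scale.float())))
    assert torch.equal(scale.float(), pow2)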

torchao/float8/float8_utils.py

Lines changed: 2 additions & 1 deletion
@@ -236,6 +236,7 @@ def pad_tensor_for_matmul(
     return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))


-def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+def _round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
+    """Rounds the scale down to the nearest power of 2."""
     assert scale.dtype == torch.float32, "scale must be float32 tensor"
     return torch.exp2(torch.floor(torch.log2(scale)))
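A quick worked example of the helper (values illustrative): rounding is floor-based in log space, so exact powers of 2 pass through unchanged and everything else rounds down.

    import torch
    from torchao.float8.float8_utils import _round_scale_down_to_power_of_2

    scale = torch.tensor([0.75, 1.0, 3.0, 100.0], dtype=torch.float32)
    # floor(log2(0.75)) = -1 -> 0.5; floor(log2(3)) = 1 -> 2.0; floor(log2(100)) = 6 -> 64.0
    print(_round_scale_down_to_power_of_2(scale))  # tensor([ 0.5,  1.0,  2.0, 64.0])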

torchao/quantization/quant_api.py

Lines changed: 26 additions & 2 deletions
@@ -1281,26 +1281,30 @@ def _int4_symm_cutlass_quant(x: torch.Tensor) -> torch.Tensor:
 def _float8_cutlass_quant(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )


 def _float8_cutlass_quant_sparse(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> (torch.Tensor, torch.Tensor):
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=CutlassSemiSparseLayout(),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )


@@ -1410,13 +1414,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

     Note:
         The actual matmul will be computed in original precision of the weight tensor.
     """

     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False


 # for BC
@@ -1433,6 +1439,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         target_dtype=config.weight_dtype,
         scale_dtype=None,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
     return new_weight

@@ -1461,6 +1468,7 @@ def _input_activation_quant_func_fp8(
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
+    round_scales_to_power_of_2: bool = False,
 ):
     """This function is used to quantize the input activation tensor for an aqt_float variant. If scale
     is not provided it will be dynamically calculate the scales otherwise it will use the provided scale.
@@ -1481,6 +1489,7 @@ def _input_activation_quant_func_fp8(
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
     else:
         assert isinstance(activation_granularity, PerTensor), (
@@ -1538,6 +1547,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

     """

@@ -1546,6 +1556,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False

     def __post_init__(self):
         if self.mm_config is None:
@@ -1589,12 +1600,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )

     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }

     quantized_weight = to_linear_activation_quantized(
@@ -1634,11 +1647,13 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
         `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment.
         `activation_dtype`: data type for quantized activation tensor.
         `weight_dtype`: data type for quantized weight tensor.
+        `round_scales_to_power_of_2`: If True, round scaling factors down to the nearest power of 2.
     """

     layout: Layout = CutlassSemiSparseLayout()
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
+    round_scales_to_power_of_2: bool = False


 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1657,11 +1672,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
         f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
     )

-    weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
+    weight = _float8_cutlass_quant_sparse(
+        weight, weight_dtype, config.round_scales_to_power_of_2
+    )
     weight = to_linear_activation_quantized(
         weight,
         _float8_cutlass_quant,
-        quant_kwargs={"target_dtype": activation_dtype},
+        quant_kwargs={
+            "target_dtype": activation_dtype,
+            "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
+        },
     )

     module.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -1680,6 +1700,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
     """

     scale: torch.Tensor
@@ -1690,6 +1711,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False

     def __post_init__(self):
         if self.mm_config is None:
@@ -1733,12 +1755,14 @@ def _float8_static_activation_float8_weight_transform(
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )

     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }

     quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
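The flag is exposed uniformly across the float8 configs touched here and defaults to False, so existing behavior is unchanged. A hedged construction sketch mirroring the test parametrization (scale value illustrative; PerTensor import assumed from torchao.quantization):

    import torch
    from torchao.quantization import (
        Float8StaticActivationFloat8WeightConfig,
        Float8WeightOnlyConfig,
        PerTensor,
    )

    weight_only = Float8WeightOnlyConfig(round_scales_to_power_of_2=True)
    static = Float8StaticActivationFloat8WeightConfig(
        scale=torch.tensor(1.0, dtype=torch.float32),
        granularity=PerTensor(),
        round_scales_to_power_of_2=True,
    )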
