
Commit 195187b

Add round_scales_to_power_of_2 option for float quantization
This adds support for rounding scaling factors down to the nearest power of 2 for float quantization, following the pattern established in Float8LinearConfig.

Key changes:
- Add round_scales_to_power_of_2 parameter to all float quantization configs
- Update choose_qparams_affine_floatx and to_scaled_tc_floatx functions to apply power of 2 rounding
- Thread the parameter through all relevant function calls in quant_api.py
- Maintain backward compatibility with default value of False

This helps reduce quantization error by avoiding rounding errors when multiplying/dividing by scaling factors, and ensures consistent quantization between forward and backward passes.

stack-info: PR: #2323, branch: drisspg/stack/67
1 parent 9cd5851 commit 195187b
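
As a rough usage sketch (mirroring the parametrized test added below, not code from the commit itself), the new flag is simply passed through a quantization config; this assumes a CUDA GPU with compute capability >= 8.9 and that PerRow is importable from torchao.quantization as in the test file:

    import copy

    import torch
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
        quantize_,
    )

    # bfloat16 linear layer; float8 quantization requires a recent CUDA GPU
    model = torch.nn.Linear(64, 32, bias=False).to("cuda", torch.bfloat16)
    quantized = copy.deepcopy(model)

    # Opt in: weight and activation scales are rounded down to the nearest power of 2
    quantize_(
        quantized,
        Float8DynamicActivationFloat8WeightConfig(
            granularity=PerRow(), round_scales_to_power_of_2=True
        ),
    )

    out = quantized(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))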

File tree: 6 files changed, +154 -14 lines changed

test/dtypes/test_affine_quantized_float.py
Lines changed: 97 additions & 0 deletions

@@ -28,7 +28,10 @@
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -630,6 +633,100 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         error = compute_error(ref_output, quant_output)
         self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
 
+    def _get_weight_scale_and_impl(self, quantized_model, config_type):
+        """Helper to extract weight scale and impl based on config type"""
+        if config_type == "weight_only":
+            weight_impl = quantized_model.weight.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+        else:
+            weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+
+    def _verify_power_of_2_scales(self, scale):
+        """Helper to verify scales are powers of 2"""
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Scales should be powers of 2")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "config_factory,config_type,min_sm,min_error",
+        [
+            (
+                lambda: Float8WeightOnlyConfig(round_scales_to_power_of_2=True),
+                "weight_only",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8StaticActivationFloat8WeightConfig(
+                    scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                    granularity=PerTensor(),
+                    round_scales_to_power_of_2=True,
+                ),
+                "static",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8SemiSparseWeightConfig(
+                    round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                90,
+                10.0,
+            ),
+        ],
+    )
+    def test_power_of_2_scaling_configs(
+        self, config_factory, config_type, min_sm, min_error
+    ):
+        if min_sm == 90 and not is_sm_at_least_90():
+            self.skipTest("Requires GPU with compute capability >= 9.0")
+
+        device = "cuda"
+        dtype = torch.bfloat16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        config = config_factory()
+        if isinstance(
+            config, Float8DynamicActivationFloat8SemiSparseWeightConfig
+        ) and not is_sm_version(9, 0):
+            self.skipTest("Float8SemiSparse requires compute capability == 9.0")
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        scale, _ = self._get_weight_scale_and_impl(quantized_model, config_type)
+        self._verify_power_of_2_scales(scale)
+
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        self.assertEqual(ref_output.shape, quant_output.shape)
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, min_error, f"Quantization SQNR too low: {error}")
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

torchao/dtypes/affine_quantized_tensor.py
Lines changed: 9 additions & 2 deletions

@@ -457,13 +457,17 @@ def from_hp_to_floatx(
         target_dtype: torch.dtype,
         _layout: Layout,
         scale_dtype: Optional[torch.dtype] = None,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
             scale = choose_qparams_affine_float8(
-                input_float, float8_dtype=target_dtype, block_size=block_size
+                input_float,
+                float8_dtype=target_dtype,
+                block_size=block_size,
+                round_scales_to_power_of_2=round_scales_to_power_of_2,
             )
             data = quantize_affine_float8(input_float, scale, target_dtype)
             data, scale, zero_point = _layout.post_process(
@@ -530,6 +534,7 @@ def from_hp_to_fpx(
         cls,
         input_float: torch.Tensor,
         _layout: Layout,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
         from torchao.dtypes.floatx import FloatxTensorCoreLayout
@@ -545,7 +550,9 @@ def from_hp_to_fpx(
 
         ebits, mbits = _layout.ebits, _layout.mbits
         # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
+        scale = choose_qparams_affine_floatx(
+            input_float, ebits, mbits, round_scales_to_power_of_2
+        )
         floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
         floatx_packed, scale, _ = _layout.post_process(
             floatx_unpacked, scale, None, block_size
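
The same keyword is exposed on the to_affine_quantized_floatx entry point that wraps from_hp_to_floatx. A hedged sketch of a direct call, modeled on the _float8_cutlass_quant helper updated later in this commit (the (1, 64) block size stands in for _get_per_token_block_size on a 16x64 input):

    import torch
    from torchao.dtypes import Float8Layout, to_affine_quantized_floatx

    x = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")
    fp8_weight = to_affine_quantized_floatx(
        x,
        block_size=(1, 64),                    # per-row / per-token blocks
        scale_dtype=torch.float32,
        target_dtype=torch.float8_e4m3fn,
        _layout=Float8Layout(mm_config=None),
        round_scales_to_power_of_2=True,       # new keyword from this commit
    )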

torchao/dtypes/floatx/floatx_tensor_core_layout.py
Lines changed: 4 additions & 1 deletion

@@ -22,6 +22,7 @@
     AQTTensorImpl,
     Layout,
 )
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -214,7 +215,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
 
 
 def to_scaled_tc_floatx(
-    tensor: Tensor, ebits: int, mbits: int
+    tensor: Tensor, ebits: int, mbits: int, round_scales_to_power_of_2: bool = False
 ) -> Tuple[Tensor, Tensor]:
     # _n_ones() is not compatible with torch.compile() due to << operator
     # https://github.com/pytorch/pytorch/issues/119152
@@ -230,6 +231,8 @@ def to_scaled_tc_floatx(
     dtype = tensor.dtype
     tensor = tensor.float()
     scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+    if round_scales_to_power_of_2:
+        scale = _round_scale_down_to_power_of_2(scale.float())
     tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
     tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits)
     return tensor_tc_floatx, scale.to(dtype)
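
With the extra argument, to_scaled_tc_floatx can round its per-row scales before packing. A small sketch under the assumption that a CPU tensor is acceptable for packing and that a 64x64 shape satisfies the tensor-core layout's divisibility constraints:

    import torch
    from torchao.dtypes.floatx.floatx_tensor_core_layout import to_scaled_tc_floatx

    w = torch.randn(64, 64, dtype=torch.float32)
    # fp6 (ebits=3, mbits=2); scales are computed per row (axis=1) and rounded down to powers of 2
    packed, scale = to_scaled_tc_floatx(w, ebits=3, mbits=2, round_scales_to_power_of_2=True)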

torchao/float8/float8_utils.py
Lines changed: 2 additions & 1 deletion

@@ -236,6 +236,7 @@ def pad_tensor_for_matmul(
     return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
 
 
-def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+def _round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
+    """Rounds the scale down to the nearest power of 2."""
     assert scale.dtype == torch.float32, "scale must be float32 tensor"
     return torch.exp2(torch.floor(torch.log2(scale)))
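
The rounding is exp2(floor(log2(scale))), so the scale always moves down to the nearest power of 2. A quick standalone illustration (plain torch, not calling the private helper):

    import torch

    scale = torch.tensor([0.75, 1.0, 3.0], dtype=torch.float32)
    rounded = torch.exp2(torch.floor(torch.log2(scale)))
    print(rounded)  # tensor([0.5000, 1.0000, 2.0000])

Because a power-of-2 scale only shifts the floating-point exponent, multiplying or dividing by it introduces no mantissa rounding error, which is the error reduction the commit message refers to.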

torchao/quantization/quant_api.py
Lines changed: 26 additions & 2 deletions

@@ -1270,26 +1270,30 @@ def _int4_symm_cutlass_quant(x: torch.Tensor) -> torch.Tensor:
 def _float8_cutlass_quant(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
 
 
 def _float8_cutlass_quant_sparse(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> (torch.Tensor, torch.Tensor):
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=CutlassSemiSparseLayout(),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
 
 
@@ -1399,13 +1403,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
 
     Note:
         The actual matmul will be computed in original precision of the weight tensor.
     """
 
     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
 
 # for BC
@@ -1422,6 +1428,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         target_dtype=config.weight_dtype,
         scale_dtype=None,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
     return new_weight
 
@@ -1450,6 +1457,7 @@ def _input_activation_quant_func_fp8(
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
+    round_scales_to_power_of_2: bool = False,
 ):
     """This function is used to quantize the input activation tensor for an aqt_float variant. If scale
     is not provided it will be dynamically calculate the scales otherwise it will use the provided scale.
@@ -1470,6 +1478,7 @@ def _input_activation_quant_func_fp8(
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
     else:
         assert isinstance(activation_granularity, PerTensor), (
@@ -1527,6 +1536,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
 
     """
 
@@ -1535,6 +1545,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1578,12 +1589,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
 
     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }
 
     quantized_weight = to_linear_activation_quantized(
@@ -1623,11 +1636,13 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
         `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment.
         `activation_dtype`: data type for quantized activation tensor.
         `weight_dtype`: data type for quantized weight tensor.
+        `round_scales_to_power_of_2`: If True, round scaling factors down to the nearest power of 2.
     """
 
     layout: Layout = CutlassSemiSparseLayout()
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
+    round_scales_to_power_of_2: bool = False
 
 
 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1646,11 +1661,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
         f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
     )
 
-    weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
+    weight = _float8_cutlass_quant_sparse(
+        weight, weight_dtype, config.round_scales_to_power_of_2
+    )
     weight = to_linear_activation_quantized(
         weight,
         _float8_cutlass_quant,
-        quant_kwargs={"target_dtype": activation_dtype},
+        quant_kwargs={
+            "target_dtype": activation_dtype,
+            "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
+        },
     )
 
     module.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -1669,6 +1689,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
     """
 
     scale: torch.Tensor
@@ -1679,6 +1700,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1722,12 +1744,14 @@ def _float8_static_activation_float8_weight_transform(
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
 
     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }
 
     quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
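
For the static-activation path the caller still supplies the activation scale; a sketch mirroring the parametrized test above (the scale value of 1.0 is only a placeholder, and PerTensor is assumed importable from torchao.quantization):

    import torch
    from torchao.quantization import (
        Float8StaticActivationFloat8WeightConfig,
        PerTensor,
        quantize_,
    )

    model = torch.nn.Linear(64, 32, bias=False).to("cuda", torch.bfloat16)
    quantize_(
        model,
        Float8StaticActivationFloat8WeightConfig(
            scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
            granularity=PerTensor(),
            round_scales_to_power_of_2=True,
        ),
    )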
