
Commit 9ed84e7

Add round_scales_to_power_of_2 option for float quantization
This adds support for rounding scaling factors down to the nearest power of 2 for float quantization, following the pattern established in Float8LinearConfig.

Key changes:
- Add a round_scales_to_power_of_2 parameter to all float quantization configs
- Update choose_qparams_affine_floatx and to_scaled_tc_floatx to apply power-of-2 rounding
- Thread the parameter through all relevant function calls in quant_api.py
- Maintain backward compatibility with a default value of False

This helps reduce quantization error by avoiding rounding error when multiplying/dividing by scaling factors, and ensures consistent quantization between the forward and backward passes.

stack-info: PR: #2323, branch: drisspg/stack/67
1 parent 2ec325d commit 9ed84e7
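For context, rounding a scale down to the nearest power of 2 is a single exp2/floor/log2 round trip. A power-of-2 scale only shifts the floating-point exponent, so dividing by it during quantization and multiplying by it during dequantization introduces no mantissa rounding error. A minimal standalone sketch of the idea (not the torchao implementation itself):

import torch

def round_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # floor(log2(scale)) selects the largest exponent e with 2**e <= scale
    return torch.exp2(torch.floor(torch.log2(scale)))

scale = torch.tensor([0.3, 1.0, 5.7], dtype=torch.float32)
print(round_down_to_power_of_2(scale))  # tensor([0.2500, 1.0000, 4.0000])

# Dividing and re-multiplying by a power-of-2 scale is exact
# (an exponent shift only), barring overflow/underflow:
x = torch.randn(16, dtype=torch.float32)
s = torch.tensor(0.25, dtype=torch.float32)
assert torch.equal((x / s) * s, x)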

File tree

6 files changed: +153 -14 lines changed

test/dtypes/test_affine_quantized_float.py

Lines changed: 96 additions & 0 deletions
@@ -28,7 +28,10 @@
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -674,6 +677,99 @@ def test_preprocess_scale_3d_reshape(self):
         result = preprocess_scale(scale_4d, (2, 2, 2, 8))
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
+
+    def _get_weight_scale_and_impl(self, quantized_model, config_type):
+        """Helper to extract weight scale and impl based on config type"""
+        if config_type == "weight_only":
+            weight_impl = quantized_model.weight.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+        else:
+            weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+
+    def _verify_power_of_2_scales(self, scale):
+        """Helper to verify scales are powers of 2"""
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Scales should be powers of 2")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "config_factory,config_type,min_sm,min_error",
+        [
+            (
+                lambda: Float8WeightOnlyConfig(round_scales_to_power_of_2=True),
+                "weight_only",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8StaticActivationFloat8WeightConfig(
+                    scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                    granularity=PerTensor(),
+                    round_scales_to_power_of_2=True,
+                ),
+                "static",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8SemiSparseWeightConfig(
+                    round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                90,
+                10.0,
+            ),
+        ],
+    )
+    def test_power_of_2_scaling_configs(
+        self, config_factory, config_type, min_sm, min_error
+    ):
+        if min_sm == 90 and not is_sm_at_least_90():
+            self.skipTest("Requires GPU with compute capability >= 9.0")
+
+        device = "cuda"
+        dtype = torch.bfloat16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        config = config_factory()
+        if isinstance(
+            config, Float8DynamicActivationFloat8SemiSparseWeightConfig
+        ) and not is_sm_version(9, 0):
+            self.skipTest("Float8SemiSparse requires compute capability == 9.0")
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        scale, _ = self._get_weight_scale_and_impl(quantized_model, config_type)
+        self._verify_power_of_2_scales(scale)
+
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        self.assertEqual(ref_output.shape, quant_output.shape)
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, min_error, f"Quantization SQNR too low: {error}")
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
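For reference, the flag is opt-in on each config. A minimal usage sketch mirroring what the test above exercises (assumes CUDA with compute capability >= 8.9, and that PerRow is importable from torchao.quantization as it is used in this test file):

import copy
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    PerRow,
    quantize_,
)

model = torch.nn.Linear(64, 32, bias=False).to("cuda", torch.bfloat16)

# Weight-only float8 with scales rounded down to powers of 2
m1 = copy.deepcopy(model)
quantize_(m1, Float8WeightOnlyConfig(round_scales_to_power_of_2=True))

# Dynamic activation + weight float8 with per-row granularity
m2 = copy.deepcopy(model)
quantize_(
    m2,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(), round_scales_to_power_of_2=True
    ),
)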

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 9 additions & 2 deletions
@@ -457,13 +457,17 @@ def from_hp_to_floatx(
         target_dtype: torch.dtype,
         _layout: Layout,
         scale_dtype: Optional[torch.dtype] = None,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
             scale = choose_qparams_affine_float8(
-                input_float, float8_dtype=target_dtype, block_size=block_size
+                input_float,
+                float8_dtype=target_dtype,
+                block_size=block_size,
+                round_scales_to_power_of_2=round_scales_to_power_of_2,
             )
             data = quantize_affine_float8(input_float, scale, target_dtype)
             data, scale, zero_point = _layout.post_process(
@@ -530,6 +534,7 @@ def from_hp_to_fpx(
         cls,
         input_float: torch.Tensor,
         _layout: Layout,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
         from torchao.dtypes.floatx import FloatxTensorCoreLayout
@@ -545,7 +550,9 @@ def from_hp_to_fpx(
 
         ebits, mbits = _layout.ebits, _layout.mbits
         # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
+        scale = choose_qparams_affine_floatx(
+            input_float, ebits, mbits, round_scales_to_power_of_2
+        )
         floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
         floatx_packed, scale, _ = _layout.post_process(
             floatx_unpacked, scale, None, block_size

torchao/dtypes/floatx/floatx_tensor_core_layout.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
     AQTTensorImpl,
     Layout,
 )
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -214,7 +215,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
 
 
 def to_scaled_tc_floatx(
-    tensor: Tensor, ebits: int, mbits: int
+    tensor: Tensor, ebits: int, mbits: int, round_scales_to_power_of_2: bool = False
 ) -> Tuple[Tensor, Tensor]:
     # _n_ones() is not compatible with torch.compile() due to << operator
     # https://github.com/pytorch/pytorch/issues/119152
@@ -230,6 +231,8 @@ def to_scaled_tc_floatx(
     dtype = tensor.dtype
     tensor = tensor.float()
     scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+    if round_scales_to_power_of_2:
+        scale = _round_scale_down_to_power_of_2(scale.float())
     tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
     tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits)
     return tensor_tc_floatx, scale.to(dtype)
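Based on the new signature above, a hedged call sketch; the fp6 (e3m2) choice is illustrative, and it assumes the packing path runs on the current device:

import torch
from torchao.dtypes.floatx.floatx_tensor_core_layout import to_scaled_tc_floatx

w = torch.randn(32, 64)
packed, scale = to_scaled_tc_floatx(w, ebits=3, mbits=2, round_scales_to_power_of_2=True)

# With the flag enabled, every returned per-row scale should already be a power of 2
assert torch.equal(
    scale, torch.exp2(torch.floor(torch.log2(scale.float()))).to(scale.dtype)
)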

torchao/float8/float8_utils.py

Lines changed: 2 additions & 1 deletion
@@ -236,6 +236,7 @@ def pad_tensor_for_matmul(
     return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
 
 
-def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+def _round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
+    """Rounds the scale down to the nearest power of 2."""
     assert scale.dtype == torch.float32, "scale must be float32 tensor"
     return torch.exp2(torch.floor(torch.log2(scale)))
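A quick numeric check of the helper's behavior. The bitwise comparison at the end is an observation about float32 representation, not something the implementation relies on: for positive normal values, rounding down to a power of 2 is equivalent to zeroing the mantissa bits.

import torch
from torchao.float8.float8_utils import _round_scale_down_to_power_of_2

s = torch.tensor([0.3, 1.0, 6.0], dtype=torch.float32)
print(_round_scale_down_to_power_of_2(s))  # tensor([0.2500, 1.0000, 4.0000])

# Equivalent view for positive normal floats: keep the exponent, drop the mantissa
bitwise = (s.view(torch.int32) & 0x7F800000).view(torch.float32)
assert torch.equal(_round_scale_down_to_power_of_2(s), bitwise)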

torchao/quantization/quant_api.py

Lines changed: 26 additions & 2 deletions
@@ -1274,26 +1274,30 @@ def _int4_symm_cutlass_quant(x: torch.Tensor) -> torch.Tensor:
 def _float8_cutlass_quant(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
 
 
 def _float8_cutlass_quant_sparse(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> (torch.Tensor, torch.Tensor):
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=CutlassSemiSparseLayout(),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
 
 
@@ -1403,13 +1407,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
 
     Note:
         The actual matmul will be computed in original precision of the weight tensor.
     """
 
     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
 
 # for BC
@@ -1426,6 +1432,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         target_dtype=config.weight_dtype,
         scale_dtype=None,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
     return new_weight
 
@@ -1454,6 +1461,7 @@ def _input_activation_quant_func_fp8(
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
+    round_scales_to_power_of_2: bool = False,
 ):
     """This function is used to quantize the input activation tensor for an aqt_float variant. If scale
     is not provided, it will dynamically calculate the scale; otherwise it will use the provided scale.
@@ -1474,6 +1482,7 @@ def _input_activation_quant_func_fp8(
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
     else:
         assert isinstance(activation_granularity, PerTensor), (
@@ -1531,6 +1540,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
 
     """
 
@@ -1539,6 +1549,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1582,12 +1593,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
 
     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }
 
     quantized_weight = to_linear_activation_quantized(
@@ -1627,11 +1640,13 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
         `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment.
         `activation_dtype`: data type for quantized activation tensor.
         `weight_dtype`: data type for quantized weight tensor.
+        `round_scales_to_power_of_2`: If True, round scaling factors down to the nearest power of 2.
     """
 
     layout: Layout = CutlassSemiSparseLayout()
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
+    round_scales_to_power_of_2: bool = False
 
 
 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1650,11 +1665,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
         f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
     )
 
-    weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
+    weight = _float8_cutlass_quant_sparse(
+        weight, weight_dtype, config.round_scales_to_power_of_2
+    )
     weight = to_linear_activation_quantized(
         weight,
         _float8_cutlass_quant,
-        quant_kwargs={"target_dtype": activation_dtype},
+        quant_kwargs={
+            "target_dtype": activation_dtype,
+            "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
+        },
     )
 
     module.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -1673,6 +1693,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
     """
 
     scale: torch.Tensor
@@ -1683,6 +1704,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1726,12 +1748,14 @@ def _float8_static_activation_float8_weight_transform(
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
    )
 
     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }
 
     quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
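And a hedged sketch of the static-activation path, where a precomputed activation scale is supplied up front; this mirrors the test parametrization, the scale value is illustrative, and it assumes PerTensor is importable from torchao.quantization:

import torch
from torchao.quantization import (
    Float8StaticActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

model = torch.nn.Linear(64, 32, bias=False).to("cuda", torch.bfloat16)
quantize_(
    model,
    Float8StaticActivationFloat8WeightConfig(
        scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
        granularity=PerTensor(),
        round_scales_to_power_of_2=True,
    ),
)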
