
Commit 988714b

Add round_scales_to_power_of_2 option for float quantization
This adds support for rounding scaling factors down to the nearest power of 2 for float quantization, following the pattern established in Float8LinearConfig.

Key changes:
- Add round_scales_to_power_of_2 parameter to all float quantization configs
- Update choose_qparams_affine_floatx and to_scaled_tc_floatx functions to apply power of 2 rounding
- Thread the parameter through all relevant function calls in quant_api.py
- Maintain backward compatibility with default value of False

This helps reduce quantization error by avoiding rounding errors when multiplying/dividing by scaling factors, and ensures consistent quantization between forward and backward passes.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

stack-info: PR: #2323, branch: drisspg/stack/67
1 parent 9cd5851 commit 988714b
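
For context, rounding a scale down to the nearest power of 2 keeps only the floating-point exponent and discards the mantissa. A minimal sketch of the idea (an illustration of the math, not necessarily the exact torchao implementation):

    import torch

    def round_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
        # 2 ** floor(log2(scale)) for positive scales, i.e. drop the mantissa bits
        return torch.exp2(torch.floor(torch.log2(scale)))

    scales = torch.tensor([0.7, 1.0, 3.5])
    print(round_down_to_power_of_2(scales))  # tensor([0.5000, 1.0000, 2.0000])

Because multiplying or dividing by a power of 2 only shifts the exponent, quantizing and dequantizing by such a scale introduces no extra mantissa rounding, which is the consistency the commit message refers to.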

5 files changed, +279 −6 lines changed


test/dtypes/test_affine_quantized_float.py

Lines changed: 226 additions & 0 deletions
@@ -29,6 +29,8 @@
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -630,6 +632,230 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         error = compute_error(ref_output, quant_output)
         self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_power_of_2_scaling_weight_only(self):
+        """Test that Float8WeightOnlyConfig with round_scales_to_power_of_2=True works correctly"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Test with round_scales_to_power_of_2=True
+        config = Float8WeightOnlyConfig(round_scales_to_power_of_2=True)
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        # Verify the model was quantized
+        self.assertTrue(hasattr(quantized_model.weight, "tensor_impl"))
+        weight_impl = quantized_model.weight.tensor_impl
+        self.assertTrue(hasattr(weight_impl, "scale"))
+
+        # Check that scales are powers of 2
+        scale = weight_impl.scale.float()
+        # For power of 2, log2(scale) should be integer
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Scales should be powers of 2")
+
+        # Test inference works
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        # Verify shapes match
+        self.assertEqual(ref_output.shape, quant_output.shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    def test_power_of_2_scaling_dynamic_activation(self, granularity):
+        """Test that Float8DynamicActivationFloat8WeightConfig with round_scales_to_power_of_2=True works correctly"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model with dimensions that are multiples of 16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Test with round_scales_to_power_of_2=True
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity, round_scales_to_power_of_2=True
+        )
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        # Verify the model was quantized
+        self.assertTrue(hasattr(quantized_model.weight, "original_weight_tensor"))
+        weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+        self.assertTrue(hasattr(weight_impl, "scale"))
+
+        # Check that weight scales are powers of 2
+        scale = weight_impl.scale.float()
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Weight scales should be powers of 2")
+
+        # Test inference works
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        # Verify shapes match
+        self.assertEqual(ref_output.shape, quant_output.shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    def test_power_of_2_scaling_static_activation(self, granularity):
+        """Test that Float8StaticActivationFloat8WeightConfig with round_scales_to_power_of_2=True works correctly"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model with dimensions that are multiples of 16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Create a scale tensor for static quantization
+        scale = torch.tensor(1.0, dtype=torch.float32, device=device)
+
+        # Test with round_scales_to_power_of_2=True
+        config = Float8StaticActivationFloat8WeightConfig(
+            scale=scale, granularity=granularity, round_scales_to_power_of_2=True
+        )
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        # Verify the model was quantized
+        self.assertTrue(hasattr(quantized_model.weight, "original_weight_tensor"))
+        weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+        self.assertTrue(hasattr(weight_impl, "scale"))
+
+        # Check that weight scales are powers of 2
+        weight_scale = weight_impl.scale.float()
+        log2_scale = torch.log2(weight_scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Weight scales should be powers of 2")
+
+        # Test inference works
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        # Verify shapes match
+        self.assertEqual(ref_output.shape, quant_output.shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_power_of_2_scaling_backward_compatibility(self):
+        """Test that default behavior (round_scales_to_power_of_2=False) is unchanged"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Test default behavior (should be False)
+        config_default = Float8WeightOnlyConfig()
+        quantized_model_default = copy.deepcopy(model)
+        quantize_(quantized_model_default, config_default)
+
+        # Test explicit False
+        config_false = Float8WeightOnlyConfig(round_scales_to_power_of_2=False)
+        quantized_model_false = copy.deepcopy(model)
+        quantize_(quantized_model_false, config_false)
+
+        # Get scales from both models
+        scale_default = quantized_model_default.weight.tensor_impl.scale
+        scale_false = quantized_model_false.weight.tensor_impl.scale
+
+        # They should be identical (backward compatibility)
+        self.assertTrue(torch.allclose(scale_default, scale_false))
+
+        # Test that they produce the same results
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            output_default = quantized_model_default(input_tensor)
+            output_false = quantized_model_false(input_tensor)
+
+        # Outputs should be identical
+        self.assertTrue(torch.allclose(output_default, output_false))
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_power_of_2_vs_regular_scaling(self):
+        """Test that power of 2 scaling produces different (but reasonable) results compared to regular scaling"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Test with regular scaling
+        config_regular = Float8WeightOnlyConfig(round_scales_to_power_of_2=False)
+        quantized_model_regular = copy.deepcopy(model)
+        quantize_(quantized_model_regular, config_regular)
+
+        # Test with power of 2 scaling
+        config_power2 = Float8WeightOnlyConfig(round_scales_to_power_of_2=True)
+        quantized_model_power2 = copy.deepcopy(model)
+        quantize_(quantized_model_power2, config_power2)
+
+        # Get scales from both models
+        scale_regular = quantized_model_regular.weight.tensor_impl.scale.float()
+        scale_power2 = quantized_model_power2.weight.tensor_impl.scale.float()
+
+        # Power of 2 scale should be different from regular scale (unless it was already power of 2)
+        # But the power of 2 scale should be <= regular scale (since we round down)
+        self.assertTrue(torch.all(scale_power2 <= scale_regular))
+
+        # Verify power of 2 scale is actually power of 2
+        log2_scale = torch.log2(scale_power2)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Power-of-2 scales should be powers of 2")
+
+        # Test that both produce reasonable results
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            output_regular = quantized_model_regular(input_tensor)
+            output_power2 = quantized_model_power2(input_tensor)
+
+        # Both should have reasonable quantization error
+        error_regular = compute_error(ref_output, output_regular)
+        error_power2 = compute_error(ref_output, output_power2)
+
+        self.assertGreater(
+            error_regular, 15, f"Regular quantization SQNR too low: {error_regular}"
+        )
+        self.assertGreater(
+            error_power2, 15, f"Power-of-2 quantization SQNR too low: {error_power2}"
+        )
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

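The tests above drive the new flag entirely through the public quantize_ API (gated on CUDA with compute capability >= 8.9); condensed into a usage sketch, with the same calls and attribute paths as in the tests:

    import copy
    import torch
    from torchao.quantization import Float8WeightOnlyConfig, quantize_

    model = torch.nn.Linear(64, 32, bias=False).cuda().to(torch.bfloat16)

    # Opt in to power-of-2 scale rounding for float8 weight-only quantization
    quantized = copy.deepcopy(model)
    quantize_(quantized, Float8WeightOnlyConfig(round_scales_to_power_of_2=True))

    # Weight scales are now exact powers of 2
    log2s = torch.log2(quantized.weight.tensor_impl.scale.float())
    assert torch.allclose(log2s, torch.round(log2s))
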
torchao/dtypes/affine_quantized_tensor.py

Lines changed: 5 additions & 1 deletion
@@ -457,6 +457,7 @@ def from_hp_to_floatx(
         target_dtype: torch.dtype,
         _layout: Layout,
         scale_dtype: Optional[torch.dtype] = None,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
@@ -530,6 +531,7 @@ def from_hp_to_fpx(
         cls,
         input_float: torch.Tensor,
         _layout: Layout,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
         from torchao.dtypes.floatx import FloatxTensorCoreLayout
@@ -545,7 +547,9 @@ def from_hp_to_fpx(
 
         ebits, mbits = _layout.ebits, _layout.mbits
         # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
+        scale = choose_qparams_affine_floatx(
+            input_float, ebits, mbits, round_scales_to_power_of_2
+        )
         floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
         floatx_packed, scale, _ = _layout.post_process(
             floatx_unpacked, scale, None, block_size

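Since round_scales_to_power_of_2 defaults to False in both constructors, existing callers of from_hp_to_floatx and from_hp_to_fpx see no behavior change. The payoff when opting in is that division and multiplication by the scale round-trip exactly; a small self-contained illustration (the values below are examples, not from the diff):

    import torch

    x = torch.tensor([3.1415927], dtype=torch.float32)

    pow2_scale = 0.5  # power of 2: scaling only shifts the exponent
    odd_scale = 0.7   # arbitrary scale: each multiply/divide may round the mantissa

    print((x / pow2_scale) * pow2_scale - x)  # tensor([0.]) -- exact round trip
    print((x / odd_scale) * odd_scale - x)    # typically nonzero (on the order of 1 ulp)
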
torchao/dtypes/floatx/floatx_tensor_core_layout.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
     AQTTensorImpl,
     Layout,
 )
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -214,7 +215,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
 
 
 def to_scaled_tc_floatx(
-    tensor: Tensor, ebits: int, mbits: int
+    tensor: Tensor, ebits: int, mbits: int, round_scales_to_power_of_2: bool = False
 ) -> Tuple[Tensor, Tensor]:
     # _n_ones() is not compatible with torch.compile() due to << operator
     # https://github.com/pytorch/pytorch/issues/119152
@@ -230,6 +231,8 @@ def to_scaled_tc_floatx(
     dtype = tensor.dtype
     tensor = tensor.float()
    scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+    if round_scales_to_power_of_2:
+        scale = _round_scale_down_to_power_of_2(scale.float())
     tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
     tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits)
     return tensor_tc_floatx, scale.to(dtype)
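
A quick sketch of calling the updated function with the flag enabled. The fp6 format here is ebits=3, mbits=2, and the input shape is an assumption chosen to satisfy the tensor-core packing constraints:

    import torch
    from torchao.dtypes.floatx.floatx_tensor_core_layout import to_scaled_tc_floatx

    w = torch.randn(256, 64)  # (out_features, in_features)

    # Pack to fp6 (e3m2) with per-row scales rounded down to powers of 2
    packed, scale = to_scaled_tc_floatx(w, 3, 2, round_scales_to_power_of_2=True)

    log2s = torch.log2(scale.float())
    assert torch.allclose(log2s, torch.round(log2s))  # scales are powers of 2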
