@@ -29,6 +29,8 @@
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -630,6 +632,230 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         error = compute_error(ref_output, quant_output)
         self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_power_of_2_scaling_weight_only(self):
+        """Test that Float8WeightOnlyConfig with round_scales_to_power_of_2=True works correctly"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Test with round_scales_to_power_of_2=True
+        config = Float8WeightOnlyConfig(round_scales_to_power_of_2=True)
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        # Verify the model was quantized
+        self.assertTrue(hasattr(quantized_model.weight, "tensor_impl"))
+        weight_impl = quantized_model.weight.tensor_impl
+        self.assertTrue(hasattr(weight_impl, "scale"))
+
+        # Check that scales are powers of 2
+        scale = weight_impl.scale.float()
+        # For a power of 2, log2(scale) should be an integer
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Scales should be powers of 2")
+
+        # Test that inference works
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        # Verify shapes match
+        self.assertEqual(ref_output.shape, quant_output.shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    def test_power_of_2_scaling_dynamic_activation(self, granularity):
+        """Test that Float8DynamicActivationFloat8WeightConfig with round_scales_to_power_of_2=True works correctly"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model with dimensions that are multiples of 16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Test with round_scales_to_power_of_2=True
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity, round_scales_to_power_of_2=True
+        )
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        # Verify the model was quantized
+        self.assertTrue(hasattr(quantized_model.weight, "original_weight_tensor"))
+        weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+        self.assertTrue(hasattr(weight_impl, "scale"))
+
+        # Check that weight scales are powers of 2
+        scale = weight_impl.scale.float()
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Weight scales should be powers of 2")
+
+        # Test that inference works
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        # Verify shapes match
+        self.assertEqual(ref_output.shape, quant_output.shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    def test_power_of_2_scaling_static_activation(self, granularity):
+        """Test that Float8StaticActivationFloat8WeightConfig with round_scales_to_power_of_2=True works correctly"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model with dimensions that are multiples of 16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Create a scale tensor for static activation quantization
+        scale = torch.tensor(1.0, dtype=torch.float32, device=device)
+
+        # Test with round_scales_to_power_of_2=True
+        config = Float8StaticActivationFloat8WeightConfig(
+            scale=scale, granularity=granularity, round_scales_to_power_of_2=True
+        )
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        # Verify the model was quantized
+        self.assertTrue(hasattr(quantized_model.weight, "original_weight_tensor"))
+        weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+        self.assertTrue(hasattr(weight_impl, "scale"))
+
+        # Check that weight scales are powers of 2
+        weight_scale = weight_impl.scale.float()
+        log2_scale = torch.log2(weight_scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Weight scales should be powers of 2")
+
+        # Test that inference works
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        # Verify shapes match
+        self.assertEqual(ref_output.shape, quant_output.shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_power_of_2_scaling_backward_compatibility(self):
+        """Test that the default behavior (round_scales_to_power_of_2=False) is unchanged"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Test default behavior (should be False)
+        config_default = Float8WeightOnlyConfig()
+        quantized_model_default = copy.deepcopy(model)
+        quantize_(quantized_model_default, config_default)
+
+        # Test explicit False
+        config_false = Float8WeightOnlyConfig(round_scales_to_power_of_2=False)
+        quantized_model_false = copy.deepcopy(model)
+        quantize_(quantized_model_false, config_false)
+
+        # Get scales from both models
+        scale_default = quantized_model_default.weight.tensor_impl.scale
+        scale_false = quantized_model_false.weight.tensor_impl.scale
+
+        # They should be identical (backward compatibility)
+        self.assertTrue(torch.allclose(scale_default, scale_false))
+
+        # Test that they produce the same results
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            output_default = quantized_model_default(input_tensor)
+            output_false = quantized_model_false(input_tensor)
+
+        # Outputs should be identical
+        self.assertTrue(torch.allclose(output_default, output_false))
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_power_of_2_vs_regular_scaling(self):
+        """Test that power of 2 scaling produces different (but reasonable) results compared to regular scaling"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        # Test with regular scaling
+        config_regular = Float8WeightOnlyConfig(round_scales_to_power_of_2=False)
+        quantized_model_regular = copy.deepcopy(model)
+        quantize_(quantized_model_regular, config_regular)
+
+        # Test with power of 2 scaling
+        config_power2 = Float8WeightOnlyConfig(round_scales_to_power_of_2=True)
+        quantized_model_power2 = copy.deepcopy(model)
+        quantize_(quantized_model_power2, config_power2)
+
+        # Get scales from both models
+        scale_regular = quantized_model_regular.weight.tensor_impl.scale.float()
+        scale_power2 = quantized_model_power2.weight.tensor_impl.scale.float()
+
+        # The power-of-2 scale may differ from the regular scale (unless it was
+        # already a power of 2), but it must be <= the regular scale, since the
+        # rounding is downward
+        self.assertTrue(torch.all(scale_power2 <= scale_regular))
+
+        # Verify the power-of-2 scale is actually a power of 2
+        log2_scale = torch.log2(scale_power2)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Power-of-2 scales should be powers of 2")
+
+        # Test that both produce reasonable results
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            output_regular = quantized_model_regular(input_tensor)
+            output_power2 = quantized_model_power2(input_tensor)
+
+        # Both should have reasonable quantization error
+        error_regular = compute_error(ref_output, output_regular)
+        error_power2 = compute_error(ref_output, output_power2)
+
+        self.assertGreater(
+            error_regular, 15, f"Regular quantization SQNR too low: {error_regular}"
+        )
+        self.assertGreater(
+            error_power2, 15, f"Power-of-2 quantization SQNR too low: {error_power2}"
+        )
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 
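
Note on the property under test: the tests above check two effects of `round_scales_to_power_of_2` (the resulting scale has an integral `log2`, and is never larger than the unrounded scale), but the diff does not show the rounding itself. The sketch below illustrates a downward rounding consistent with those assertions; the helper name `round_down_to_power_of_2` is hypothetical, and this is not necessarily torchao's exact implementation.

```python
import torch

def round_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Largest power of two <= scale: truncate log2 to an integer exponent.
    # For an fp32 scale this is equivalent to clearing the mantissa bits,
    # keeping only the sign and exponent.
    assert scale.dtype == torch.float32
    return torch.exp2(torch.floor(torch.log2(scale)))

scale = torch.tensor([0.0261, 1.0, 3.7], dtype=torch.float32)
rounded = round_down_to_power_of_2(scale)  # -> [0.015625, 1.0, 2.0]

# Mirrors the assertions in the tests: rounded <= scale elementwise,
# and log2(rounded) is integral.
assert torch.all(rounded <= scale)
log2_rounded = torch.log2(rounded)
assert torch.allclose(log2_rounded, torch.round(log2_rounded))
```

Because the exponent is truncated toward negative infinity, the rounded scale can only shrink, which is why `test_power_of_2_vs_regular_scaling` can assert `scale_power2 <= scale_regular` elementwise.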