Add round_scales_to_power_of_2 option for float quantization #2323

Open · wants to merge 1 commit into base: main

8 changes: 4 additions & 4 deletions .github/workflows/float8_test.yml
@@ -25,14 +25,14 @@ jobs:
include:
- name: SM-89
runs-on: linux.g6.4xlarge.experimental.nvidia.gpu
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu128'
gpu-arch-type: "cuda"
gpu-arch-version: "12.6"
gpu-arch-version: "12.8"
- name: H100
runs-on: linux.aws.h100
torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126'
torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128'
gpu-arch-type: "cuda"
gpu-arch-version: "12.4"
gpu-arch-version: "12.8"
permissions:
id-token: write
contents: read
97 changes: 97 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
@@ -28,7 +28,10 @@
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
from torchao.float8.float8_utils import compute_error
from torchao.quantization import (
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
float8_dynamic_activation_float8_weight,
float8_weight_only,
quantize_,
@@ -675,6 +678,100 @@ def test_preprocess_scale_3d_reshape(self):
expected_shape = (8, 1) # Flattened (2*2*2, 1)
self.assertEqual(result.shape, expected_shape)

def _get_weight_scale_and_impl(self, quantized_model, config_type):
"""Helper to extract weight scale and impl based on config type"""
if config_type == "weight_only":
weight_impl = quantized_model.weight.tensor_impl
return weight_impl.scale.float(), weight_impl
else:
weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
return weight_impl.scale.float(), weight_impl

def _verify_power_of_2_scales(self, scale):
"""Helper to verify scales are powers of 2"""
log2_scale = torch.log2(scale)
is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
self.assertTrue(is_power_of_2, "Scales should be powers of 2")

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize(
"config_factory,config_type,min_sm,min_error",
[
(
lambda: Float8WeightOnlyConfig(round_scales_to_power_of_2=True),
"weight_only",
89,
15.0,
),
(
lambda: Float8DynamicActivationFloat8WeightConfig(
granularity=PerTensor(), round_scales_to_power_of_2=True
),
"dynamic",
89,
12.5,
),
(
lambda: Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(), round_scales_to_power_of_2=True
),
"dynamic",
89,
12.5,
),
(
lambda: Float8StaticActivationFloat8WeightConfig(
scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
granularity=PerTensor(),
round_scales_to_power_of_2=True,
),
"static",
89,
15.0,
),
(
lambda: Float8DynamicActivationFloat8SemiSparseWeightConfig(
round_scales_to_power_of_2=True
),
"dynamic",
90,
10.0,
),
],
)
def test_power_of_2_scaling_configs(
self, config_factory, config_type, min_sm, min_error
):
if min_sm == 90 and not is_sm_at_least_90():
self.skipTest("Requires GPU with compute capability >= 9.0")

device = "cuda"
dtype = torch.bfloat16
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)

config = config_factory()
if isinstance(
config, Float8DynamicActivationFloat8SemiSparseWeightConfig
) and not is_sm_version(9, 0):
self.skipTest("Float8SemiSparse requires compute capability == 9.0")

Review comment (Contributor): Shouldn't the min_sm check above already handle this? Or is this because the cutlass kernel is only built on sm90 / 90a? The Float8SemiSparse test case is failing on H100 because the CUDA backend does not support the operator (kernel not built?), so just wondering.

quantized_model = copy.deepcopy(model)
quantize_(quantized_model, config)

scale, _ = self._get_weight_scale_and_impl(quantized_model, config_type)
self._verify_power_of_2_scales(scale)

input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
with torch.no_grad():
ref_output = model(input_tensor)
quant_output = quantized_model(input_tensor)

self.assertEqual(ref_output.shape, quant_output.shape)
error = compute_error(ref_output, quant_output)
self.assertGreater(error, min_error, f"Quantization SQNR too low: {error}")


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

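For readers outside the test harness, the power-of-2 check performed by _verify_power_of_2_scales can be reproduced in a few lines. The snippet below is a minimal standalone sketch (the example scale values are made up); it uses the same log2/round comparison as the test above:

import torch

def is_power_of_2(scale: torch.Tensor, atol: float = 1e-6) -> bool:
    # A positive scale s is a power of 2 exactly when log2(s) is an integer.
    log2_scale = torch.log2(scale.float())
    return torch.allclose(log2_scale, torch.round(log2_scale), atol=atol)

print(is_power_of_2(torch.tensor([0.25, 1.0, 64.0])))  # True
print(is_power_of_2(torch.tensor([0.3, 2.0])))         # False: 0.3 is not a power of 2
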
11 changes: 9 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -457,13 +457,17 @@ def from_hp_to_floatx(
target_dtype: torch.dtype,
_layout: Layout,
scale_dtype: Optional[torch.dtype] = None,
round_scales_to_power_of_2: bool = False,
):
"""Convert a high precision tensor to a float8 quantized tensor."""
if target_dtype in FP8_TYPES:
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
scale = _choose_qparams_affine_float8(
input_float, float8_dtype=target_dtype, block_size=block_size
input_float,
float8_dtype=target_dtype,
block_size=block_size,
round_scales_to_power_of_2=round_scales_to_power_of_2,
)
data = _quantize_affine_float8(input_float, scale, target_dtype)
data, scale, zero_point = _layout.post_process(
@@ -530,6 +534,7 @@ def from_hp_to_fpx(
cls,
input_float: torch.Tensor,
_layout: Layout,
round_scales_to_power_of_2: bool = False,
):
"""Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
from torchao.dtypes.floatx import FloatxTensorCoreLayout
@@ -545,7 +550,9 @@

ebits, mbits = _layout.ebits, _layout.mbits
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
scale = _choose_qparams_affine_floatx(input_float, ebits, mbits)
scale = _choose_qparams_affine_floatx(
input_float, ebits, mbits, round_scales_to_power_of_2
)
floatx_unpacked = _quantize_affine_floatx(input_float, scale, ebits, mbits)
floatx_packed, scale, _ = _layout.post_process(
floatx_unpacked, scale, None, block_size
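For intuition, the new round_scales_to_power_of_2 argument only adds one step to scale selection: after the usual amax-based scale is computed, it is snapped down to a power of 2. The sketch below is an illustrative approximation, not the actual _choose_qparams_affine_float8 implementation; the per-tensor amax convention and the clamp value are assumptions made for the example:

import torch

def choose_float8_scale_sketch(
    x: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
) -> torch.Tensor:
    # Assumed amax-based scale: the largest magnitude maps onto the float8 max value.
    amax = x.abs().amax().to(torch.float32).clamp(min=1e-12)
    scale = amax / torch.finfo(float8_dtype).max
    if round_scales_to_power_of_2:
        # Drop the mantissa so the scale is exactly representable as a power of 2.
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    return scale
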
5 changes: 4 additions & 1 deletion torchao/dtypes/floatx/floatx_tensor_core_layout.py
@@ -22,6 +22,7 @@
AQTTensorImpl,
Layout,
)
from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
from torchao.prototype.custom_fp_utils import (
_f32_to_floatx_unpacked,
_floatx_unpacked_to_f32,
@@ -214,7 +215,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:


def to_scaled_tc_floatx(
tensor: Tensor, ebits: int, mbits: int
tensor: Tensor, ebits: int, mbits: int, round_scales_to_power_of_2: bool = False
) -> Tuple[Tensor, Tensor]:
# _n_ones() is not compatible with torch.compile() due to << operator
# https://github.com/pytorch/pytorch/issues/119152
@@ -230,6 +231,8 @@
dtype = tensor.dtype
tensor = tensor.float()
scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
if round_scales_to_power_of_2:
scale = _round_scale_down_to_power_of_2(scale.float())
tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits)
return tensor_tc_floatx, scale.to(dtype)
3 changes: 2 additions & 1 deletion torchao/float8/float8_utils.py
@@ -236,6 +236,7 @@ def pad_tensor_for_matmul(
return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))


def _round_scale_down_to_power_of_2(scale: torch.Tensor):
def _round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
"""Rounds the scale down to the nearest power of 2."""
assert scale.dtype == torch.float32, "scale must be float32 tensor"
return torch.exp2(torch.floor(torch.log2(scale)))
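To make this helper's behavior concrete, each scale is mapped to the largest power of 2 that does not exceed it. A small worked example (the values are chosen arbitrarily):

import torch

scales = torch.tensor([0.3, 1.0, 1.5, 4.0, 100.0], dtype=torch.float32)
rounded = torch.exp2(torch.floor(torch.log2(scales)))  # same formula as the helper above
print(rounded)  # tensor([ 0.2500,  1.0000,  1.0000,  4.0000, 64.0000])
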
28 changes: 26 additions & 2 deletions torchao/quantization/quant_api.py
@@ -1281,26 +1281,30 @@ def _int4_symm_cutlass_quant(x: torch.Tensor) -> torch.Tensor:
def _float8_cutlass_quant(
x: torch.Tensor,
target_dtype: torch.dtype,
round_scales_to_power_of_2: bool = False,
) -> torch.Tensor:
return to_affine_quantized_floatx(
x,
block_size=_get_per_token_block_size(x),
scale_dtype=torch.float32,
target_dtype=target_dtype,
_layout=Float8Layout(mm_config=None),
round_scales_to_power_of_2=round_scales_to_power_of_2,
)


def _float8_cutlass_quant_sparse(
x: torch.Tensor,
target_dtype: torch.dtype,
round_scales_to_power_of_2: bool = False,
) -> (torch.Tensor, torch.Tensor):
return to_affine_quantized_floatx(
x,
block_size=_get_per_token_block_size(x),
scale_dtype=torch.float32,
target_dtype=target_dtype,
_layout=CutlassSemiSparseLayout(),
round_scales_to_power_of_2=round_scales_to_power_of_2,
)


@@ -1410,13 +1414,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
Args:
weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

Note:
The actual matmul will be computed in the original precision of the weight tensor.
"""

weight_dtype: torch.dtype = e4m3_dtype
set_inductor_config: bool = True
round_scales_to_power_of_2: bool = False


# for BC
@@ -1433,6 +1439,7 @@ def _float8_weight_only_quant_tensor(weight, config):
target_dtype=config.weight_dtype,
scale_dtype=None,
_layout=Float8Layout(mm_config=None),
round_scales_to_power_of_2=config.round_scales_to_power_of_2,
)
return new_weight

@@ -1461,6 +1468,7 @@ def _input_activation_quant_func_fp8(
activation_dtype: torch.dtype,
scale: Optional[torch.Tensor] = None,
zero_point: Optional[torch.Tensor] = None,
round_scales_to_power_of_2: bool = False,
):
"""This function is used to quantize the input activation tensor for an aqt_float variant. If scale
is not provided it will be dynamically calculate the scales otherwise it will use the provided scale.
@@ -1481,6 +1489,7 @@
target_dtype=activation_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=None), # Config is stored on weight
round_scales_to_power_of_2=round_scales_to_power_of_2,
)
else:
assert isinstance(activation_granularity, PerTensor), (
@@ -1538,6 +1547,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
only PerTensor and PerRow are supported.
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

"""

@@ -1546,6 +1556,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
mm_config: Optional[Float8MMConfig] = None
set_inductor_config: bool = True
round_scales_to_power_of_2: bool = False

def __post_init__(self):
if self.mm_config is None:
@@ -1589,12 +1600,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
round_scales_to_power_of_2=config.round_scales_to_power_of_2,
)

input_quant_func = _input_activation_quant_func_fp8
input_quant_kwargs = {
"activation_granularity": activation_granularity,
"activation_dtype": activation_dtype,
"round_scales_to_power_of_2": config.round_scales_to_power_of_2,
}

quantized_weight = to_linear_activation_quantized(
@@ -1634,11 +1647,13 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
`layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment.
`activation_dtype`: data type for quantized activation tensor.
`weight_dtype`: data type for quantized weight tensor.
`round_scales_to_power_of_2`: If True, round scaling factors down to the nearest power of 2.
"""

layout: Layout = CutlassSemiSparseLayout()
activation_dtype: torch.dtype = e5m2_dtype
weight_dtype: torch.dtype = e4m3_dtype
round_scales_to_power_of_2: bool = False


@register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1657,11 +1672,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
)

weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
weight = _float8_cutlass_quant_sparse(
weight, weight_dtype, config.round_scales_to_power_of_2
)
weight = to_linear_activation_quantized(
weight,
_float8_cutlass_quant,
quant_kwargs={"target_dtype": activation_dtype},
quant_kwargs={
"target_dtype": activation_dtype,
"round_scales_to_power_of_2": config.round_scales_to_power_of_2,
},
)

module.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -1680,6 +1700,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
"""

scale: torch.Tensor
@@ -1690,6 +1711,7 @@
] = None
mm_config: Optional[Float8MMConfig] = None
set_inductor_config: bool = True
round_scales_to_power_of_2: bool = False

def __post_init__(self):
if self.mm_config is None:
@@ -1733,12 +1755,14 @@ def _float8_static_activation_float8_weight_transform(
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
round_scales_to_power_of_2=config.round_scales_to_power_of_2,
)

input_quant_func = _input_activation_quant_func_fp8
input_quant_kwargs = {
"activation_granularity": activation_granularity,
"activation_dtype": activation_dtype,
"round_scales_to_power_of_2": config.round_scales_to_power_of_2,
}

quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
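Taken together, these config changes expose the flag through the public quantize_ API. A minimal usage sketch, assuming a CUDA GPU with compute capability >= 8.9 (as required by the tests above) and the config's default granularity; the layer shape and inputs are arbitrary:

import copy

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

model = torch.nn.Linear(64, 32, bias=False).to("cuda").to(torch.bfloat16)
quantized = copy.deepcopy(model)

# Quantize weights and activations to float8 with scales rounded down to powers of 2.
quantize_(quantized, Float8DynamicActivationFloat8WeightConfig(round_scales_to_power_of_2=True))

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    out = quantized(x)
print(out.shape)  # torch.Size([8, 32])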