@@ -1274,26 +1274,30 @@ def _int4_symm_cutlass_quant(x: torch.Tensor) -> torch.Tensor:
 def _float8_cutlass_quant(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
 
 
 def _float8_cutlass_quant_sparse(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> (torch.Tensor, torch.Tensor):
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=CutlassSemiSparseLayout(),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
 
 
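Both helpers take the flag with a default of False, so existing call sites are unaffected. For intuition, rounding a float32 scale down to the nearest power of 2 amounts to keeping its exponent and dropping its mantissa bits; a minimal sketch of the idea (my illustration, not code from this PR):

import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # 2 ** floor(log2(scale)): keep the exponent, drop the mantissa.
    # Power-of-2 scales introduce no rounding error of their own, which
    # some float8 kernels and numerics recipes prefer.
    assert scale.dtype == torch.float32
    return torch.exp2(torch.floor(torch.log2(scale)))

# Example: 3.7 -> 2.0, 0.3 -> 0.25
print(round_scale_down_to_power_of_2(torch.tensor([3.7, 0.3])))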
@@ -1403,13 +1407,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
 
     Note:
         The actual matmul will be computed in original precision of the weight tensor.
     """
 
     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
 
 # for BC
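With the new field on the dataclass, opting in is a one-line change at the call site. A usage sketch, assuming Float8WeightOnlyConfig and quantize_ are exported from torchao.quantization as in recent releases:

import torch
from torchao.quantization import Float8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).cuda()
# Weight-only float8 with weight scales snapped down to powers of 2.
quantize_(model, Float8WeightOnlyConfig(round_scales_to_power_of_2=True))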
@@ -1426,6 +1432,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         target_dtype=config.weight_dtype,
         scale_dtype=None,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
     return new_weight
 
@@ -1454,6 +1461,7 @@ def _input_activation_quant_func_fp8(
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
+    round_scales_to_power_of_2: bool = False,
 ):
     """This function is used to quantize the input activation tensor for an aqt_float variant. If scale
     is not provided it will be dynamically calculate the scales otherwise it will use the provided scale.
@@ -1474,6 +1482,7 @@ def _input_activation_quant_func_fp8(
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
     else:
         assert isinstance(activation_granularity, PerTensor), (
@@ -1531,6 +1540,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
 
     """
 
@@ -1539,6 +1549,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1582,12 +1593,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
 
     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }
 
     quantized_weight = to_linear_activation_quantized(
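The same config value is threaded into both the weight quantization call and the activation-side input_quant_kwargs, so weight and activation scales are rounded consistently. A usage sketch, again assuming the public torchao.quantization exports:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
# Per-row dynamic float8; one flag covers both activation and weight scales.
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(), round_scales_to_power_of_2=True
    ),
)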
@@ -1627,11 +1640,13 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
         `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment.
         `activation_dtype`: data type for quantized activation tensor.
         `weight_dtype`: data type for quantized weight tensor.
+        `round_scales_to_power_of_2`: If True, round scaling factors down to the nearest power of 2.
     """
 
     layout: Layout = CutlassSemiSparseLayout()
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
+    round_scales_to_power_of_2: bool = False
 
 
 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1650,11 +1665,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
         f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
     )
 
-    weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
+    weight = _float8_cutlass_quant_sparse(
+        weight, weight_dtype, config.round_scales_to_power_of_2
+    )
     weight = to_linear_activation_quantized(
         weight,
         _float8_cutlass_quant,
-        quant_kwargs={"target_dtype": activation_dtype},
+        quant_kwargs={
+            "target_dtype": activation_dtype,
+            "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
+        },
     )
 
     module.weight = torch.nn.Parameter(weight, requires_grad=False)
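Because this transform quantizes the weight eagerly and passes the activation quantizer as a bare function, the flag must be forwarded explicitly in both places, which is what the expanded calls above do. A hedged usage sketch (imports from torchao.quantization.quant_api, where this file lives; a CUTLASS-capable GPU and a weight already pruned to 2:4 sparsity are assumed):

import torch
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(2048, 2048)).to(torch.bfloat16).cuda()
# Weight is expected to carry 2:4 semi-structured sparsity.
quantize_(
    model,
    Float8DynamicActivationFloat8SemiSparseWeightConfig(
        round_scales_to_power_of_2=True
    ),
)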
@@ -1673,6 +1693,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
     """
 
     scale: torch.Tensor
@@ -1683,6 +1704,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1726,12 +1748,14 @@ def _float8_static_activation_float8_weight_transform(
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
 
     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }
 
     quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
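The static path mirrors the dynamic one: the flag applies to the weight quantization and is forwarded through input_quant_kwargs for the statically scaled activation. A usage sketch, assuming the config is exported like its siblings and a per-tensor activation scale from a calibration pass:

import torch
from torchao.quantization import Float8StaticActivationFloat8WeightConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(512, 512)).to(torch.bfloat16).cuda()
act_scale = torch.tensor(0.02, device="cuda")  # hypothetical calibration result
quantize_(
    model,
    Float8StaticActivationFloat8WeightConfig(
        scale=act_scale, round_scales_to_power_of_2=True
    ),
)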