@@ -1281,26 +1281,30 @@ def _int4_symm_cutlass_quant(x: torch.Tensor) -> torch.Tensor:
 def _float8_cutlass_quant(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )


 def _float8_cutlass_quant_sparse(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> (torch.Tensor, torch.Tensor):
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=CutlassSemiSparseLayout(),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )

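For context, "round scaling factors down to the nearest power of 2" means numerically what the sketch below shows. This is a minimal illustration, not the patch's actual implementation, and the helper name is hypothetical:

```python
import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # floor(log2(s)) keeps only the exponent; exp2 turns it back into a
    # power of two, so every positive scale is rounded down (0.75 -> 0.5).
    return torch.exp2(torch.floor(torch.log2(scale)))

scales = torch.tensor([0.75, 1.0, 3.2], dtype=torch.float32)
print(round_scale_down_to_power_of_2(scales))  # tensor([0.5000, 1.0000, 2.0000])
```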
@@ -1410,13 +1414,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

     Note:
         The actual matmul will be computed in original precision of the weight tensor.
     """

     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False


 # for BC
@@ -1433,6 +1439,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         target_dtype=config.weight_dtype,
         scale_dtype=None,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
     return new_weight

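With the config field and the quant path both wired up, usage would presumably look like the sketch below. It assumes `Float8WeightOnlyConfig` and `quantize_` are exported from `torchao.quantization` (the usual entry point) and fp8-capable hardware:

```python
import torch
from torchao.quantization import Float8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda()
# Opt in to power-of-2 rounding of the fp8 weight scales.
quantize_(model, Float8WeightOnlyConfig(round_scales_to_power_of_2=True))
```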
@@ -1461,6 +1468,7 @@ def _input_activation_quant_func_fp8(
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
+    round_scales_to_power_of_2: bool = False,
 ):
     """This function is used to quantize the input activation tensor for an aqt_float variant. If scale
     is not provided, the scales will be calculated dynamically; otherwise, the provided scale will be used.
@@ -1481,6 +1489,7 @@ def _input_activation_quant_func_fp8(
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
     else:
         assert isinstance(activation_granularity, PerTensor), (
@@ -1538,6 +1547,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

     """

@@ -1546,6 +1556,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False

     def __post_init__(self):
         if self.mm_config is None:
@@ -1589,12 +1600,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )

     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }

     quantized_weight = to_linear_activation_quantized(
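Note why the flag is threaded through `input_quant_kwargs` rather than applied here: the activation side is quantized lazily at linear-forward time by `_input_activation_quant_func_fp8` (defined above), so the option has to travel with the stored kwargs. Roughly, and as a sketch only — the tensor is a placeholder and the positional-first signature is an assumption:

```python
import torch
from torchao.quantization.granularity import PerRow  # granularity types used above

act = torch.randn(16, 128, device="cuda")  # placeholder activation
# At forward time the stored kwargs are replayed onto the activation:
act_fp8 = _input_activation_quant_func_fp8(
    act,
    activation_granularity=PerRow(),
    activation_dtype=torch.float8_e4m3fn,
    round_scales_to_power_of_2=True,  # forwarded from the config by this patch
)
```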
@@ -1634,11 +1647,13 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
         `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment.
         `activation_dtype`: data type for quantized activation tensor.
         `weight_dtype`: data type for quantized weight tensor.
+        `round_scales_to_power_of_2`: If True, round scaling factors down to the nearest power of 2.
     """

     layout: Layout = CutlassSemiSparseLayout()
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
+    round_scales_to_power_of_2: bool = False


 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1657,11 +1672,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
         f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
     )

-    weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
+    weight = _float8_cutlass_quant_sparse(
+        weight, weight_dtype, config.round_scales_to_power_of_2
+    )
     weight = to_linear_activation_quantized(
         weight,
         _float8_cutlass_quant,
-        quant_kwargs={"target_dtype": activation_dtype},
+        quant_kwargs={
+            "target_dtype": activation_dtype,
+            "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
+        },
     )

     module.weight = torch.nn.Parameter(weight, requires_grad=False)
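On the semi-sparse path the same flag now reaches both sides: the 2:4-sparse weight via `_float8_cutlass_quant_sparse` and the dynamic activation via `_float8_cutlass_quant`, so the scales stay consistently rounded. A hedged usage sketch, assuming the config class is exported like its dense counterpart and the GPU supports CUTLASS 2:4 sparsity:

```python
from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    quantize_,
)

config = Float8DynamicActivationFloat8SemiSparseWeightConfig(
    round_scales_to_power_of_2=True,  # applied to weight and activation scales alike
)
quantize_(model, config)  # `model` as prepared in the earlier sketch
```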
@@ -1680,6 +1700,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
     """

     scale: torch.Tensor
@@ -1690,6 +1711,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False

     def __post_init__(self):
         if self.mm_config is None:
@@ -1733,12 +1755,14 @@ def _float8_static_activation_float8_weight_transform(
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )

     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }

     quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
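One practical motivation for the flag, stated as a general floating-point fact rather than anything claimed by the patch: a power-of-2 scale only shifts the exponent, so applying and unapplying it is bit-exact, which matches fp8 recipes that already round scales this way:

```python
import torch

x = torch.randn(8)
pow2_scale = 0.5        # power of two: multiplication only adjusts the exponent
assert torch.equal((x * pow2_scale) / pow2_scale, x)  # bit-exact round trip

non_pow2 = 0.3          # generic scale: mantissa rounding can creep in
print(torch.equal((x * non_pow2) / non_pow2, x))      # often False
```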