@@ -1270,26 +1270,30 @@ def _int4_symm_cutlass_quant(x: torch.Tensor) -> torch.Tensor:
 def _float8_cutlass_quant(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )


 def _float8_cutlass_quant_sparse(
     x: torch.Tensor,
     target_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
 ) -> (torch.Tensor, torch.Tensor):
     return to_affine_quantized_floatx(
         x,
         block_size=_get_per_token_block_size(x),
         scale_dtype=torch.float32,
         target_dtype=target_dtype,
         _layout=CutlassSemiSparseLayout(),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )

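Both quantization helpers now accept a `round_scales_to_power_of_2` flag and thread it through to `to_affine_quantized_floatx`. The rounding itself is not shown in this diff; a minimal sketch of the semantics the docstrings describe (rounding each scale down to the nearest power of 2) could look like the following, where `_round_scale_down_to_power_of_2` is a hypothetical name, not the helper torchao actually uses:

import torch

def _round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # 2 ** floor(log2(scale)) keeps only the exponent of each scale, so
    # multiplying or dividing by the result is exact in floating point.
    return torch.exp2(torch.floor(torch.log2(scale)))

print(_round_scale_down_to_power_of_2(torch.tensor([0.3, 1.0, 5.7])))
# tensor([0.2500, 1.0000, 4.0000])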
@@ -1399,13 +1403,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

     Note:
         The actual matmul will be computed in original precision of the weight tensor.
     """

     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False


 # for BC
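For context, a minimal usage sketch of the new config field, assuming `Float8WeightOnlyConfig` and `quantize_` are importable from `torchao.quantization` as elsewhere in the library, and that a float8-capable GPU is available:

import torch
from torchao.quantization import Float8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).cuda()
# Weight-only float8 quantization with power-of-2 scale rounding enabled.
quantize_(model, Float8WeightOnlyConfig(round_scales_to_power_of_2=True))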
@@ -1422,6 +1428,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         target_dtype=config.weight_dtype,
         scale_dtype=None,
         _layout=Float8Layout(mm_config=None),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )
     return new_weight

@@ -1450,6 +1457,7 @@ def _input_activation_quant_func_fp8(
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
+    round_scales_to_power_of_2: bool = False,
 ):
     """This function is used to quantize the input activation tensor for an aqt_float variant. If scale
     is not provided it will be dynamically calculate the scales otherwise it will use the provided scale.
@@ -1470,6 +1478,7 @@ def _input_activation_quant_func_fp8(
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
     else:
         assert isinstance(activation_granularity, PerTensor), (
@@ -1527,6 +1536,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.

     """

@@ -1535,6 +1545,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False

     def __post_init__(self):
         if self.mm_config is None:
@@ -1578,12 +1589,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )

     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }

     quantized_weight = to_linear_activation_quantized(
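The flag is applied in two places here: once when the weight is quantized eagerly, and once via `input_quant_kwargs`, which the linear-activation wrapper replays on every forward pass, so weight and activation scales are rounded consistently. A hedged end-to-end sketch, assuming `PerRow` lives in `torchao.quantization.granularity` and a float8-capable GPU:

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        round_scales_to_power_of_2=True,
    ),
)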
@@ -1623,11 +1636,13 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
     `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment.
     `activation_dtype`: data type for quantized activation tensor.
     `weight_dtype`: data type for quantized weight tensor.
+    `round_scales_to_power_of_2`: If True, round scaling factors down to the nearest power of 2.
     """

     layout: Layout = CutlassSemiSparseLayout()
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
+    round_scales_to_power_of_2: bool = False


 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1646,11 +1661,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
         f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
     )

-    weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
+    weight = _float8_cutlass_quant_sparse(
+        weight, weight_dtype, config.round_scales_to_power_of_2
+    )
     weight = to_linear_activation_quantized(
         weight,
         _float8_cutlass_quant,
-        quant_kwargs={"target_dtype": activation_dtype},
+        quant_kwargs={
+            "target_dtype": activation_dtype,
+            "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
+        },
     )

     module.weight = torch.nn.Parameter(weight, requires_grad=False)
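Note the two call sites: the sparse weight path passes the flag positionally to `_float8_cutlass_quant_sparse`, while the activation path defers it through `quant_kwargs`. Roughly, the wrapper later replays those kwargs at dispatch time; the sketch below is a hand-simplification of that idea, not torchao's actual `LinearActivationQuantizedTensor` code:

# Simplified model of how quant_kwargs captured at transform time are
# forwarded on every call, so activation scales get the same power-of-2
# rounding as the weight. (Illustrative only.)
def make_activation_quantizer(input_quant_func, quant_kwargs):
    def quantize_activation(x):
        return input_quant_func(x, **quant_kwargs)
    return quantize_activation

# e.g. make_activation_quantizer(
#     _float8_cutlass_quant,
#     {"target_dtype": e4m3_dtype, "round_scales_to_power_of_2": True},
# )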
@@ -1669,6 +1689,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2.
     """

     scale: torch.Tensor
@@ -1679,6 +1700,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    round_scales_to_power_of_2: bool = False

     def __post_init__(self):
         if self.mm_config is None:
@@ -1722,12 +1744,14 @@ def _float8_static_activation_float8_weight_transform(
         target_dtype=weight_dtype,
         scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=mm_config),
+        round_scales_to_power_of_2=config.round_scales_to_power_of_2,
     )

     input_quant_func = _input_activation_quant_func_fp8
     input_quant_kwargs = {
         "activation_granularity": activation_granularity,
         "activation_dtype": activation_dtype,
+        "round_scales_to_power_of_2": config.round_scales_to_power_of_2,
     }

     quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
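Across all of these configs the flag has the same effect: scales become exact powers of 2, so applying and removing a scale touches only the floating-point exponent and round-trips without mantissa error, which can also help keep scales consistent across kernels. A standalone illustration, independent of torchao:

import torch

x = torch.randn(4096)
pow2_scale, odd_scale = 0.25, 0.3
print(torch.equal((x * pow2_scale) / pow2_scale, x))  # True: exponent shift only
print(torch.equal((x * odd_scale) / odd_scale, x))    # typically False: mantissa rounding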