40
40
# onnxruntime version threshold: above 1.16.1 the MatMulNBits packing layout
# is used (blob holds only packed weight bytes; see get_blob_size).
ONNXRT1161_VERSION = Version("1.16.1")
41
41
42
42
43
def get_blob_size(group_size, num_bits, has_zp):  # pragma: no cover
    """Compute the byte size of one packed block of quantized weight.

    Args:
        group_size (int): how many elements share one scale/zp
        num_bits (int): bit width of each quantized element
        has_zp (bool): whether a zero_point is packed into the block
            (callers pass ``zero_point is not None``)

    Returns:
        int: number of bytes occupied by one packed block
    """
    # Raw packed payload: group_size elements at num_bits bits each.
    payload = group_size * num_bits // 8
    if Version(ort.__version__) > ONNXRT1161_VERSION:
        # Newer onnxruntime (> 1.16.1) keeps scale/zero_point in separate
        # initializers, so the blob is just the packed weight bytes.
        blob_size = payload
    elif has_zp:
        # Legacy layout appends a 4-byte scale (presumably fp32) plus a
        # 1-byte zero point per block — TODO confirm against MatMulFpQ4 spec.
        blob_size = payload + 4 + 1
    else:
        # Legacy layout without zero point: payload plus the 4-byte scale.
        blob_size = payload + 4
    return blob_size
57
57
58
58
@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
86
86
matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
87
87
new_inits: initializers of the new node
88
88
"""
89
- blob_size = get_blob_size (group_size , zero_point is not None )
89
+ blob_size = get_blob_size (group_size , num_bits , zero_point is not None )
90
90
packed = np .zeros ((q_weight .shape [0 ], blob_size ), dtype = "uint8" )
91
91
q_weight_name = node .input [1 ] + "_Q{}G{}" .format (str (num_bits ), str (group_size ))
92
92
input_names = [node .input [0 ], q_weight_name ]
@@ -97,8 +97,16 @@ def make_matmul_weight_only_node(
97
97
op_type = "MatMulNBits"
98
98
99
99
# pack quantized weight
100
- q_weight_pairs = q_weight [:, ::2 ] | q_weight [:, 1 ::2 ] << 4
101
- packed [:, :] = q_weight_pairs [:, :blob_size ]
100
+ if num_bits == 4 :
101
+ q_weight_pairs = q_weight [:, ::2 ] | q_weight [:, 1 ::2 ] << 4
102
+ packed [:, :] = q_weight_pairs [:, :blob_size ]
103
+ elif num_bits == 8 :
104
+ packed = q_weight
105
+ else :
106
+ logger .error (
107
+ "MatMulNBits does not have kernel support for num_bits = {}." .format (num_bits )
108
+ )
109
+
102
110
packed = np .reshape (packed , (- 1 , k_blocks , blob_size ))
103
111
104
112
# build scale tensor
@@ -527,7 +535,8 @@ def rtn_quantize(
527
535
528
536
weight = pad_tensor (weight , group_size , k_blocks )
529
537
530
- satisfy_MatMulNBits_condition = Version (ort .__version__ ) > ONNXRT1161_VERSION and num_bits == 4
538
+ enable_MatMulNBits_8bits = True
539
+ satisfy_MatMulNBits_condition = (Version (ort .__version__ ) > ONNXRT1161_VERSION and num_bits == 4 ) or (enable_MatMulNBits_8bits and num_bits == 8 )
531
540
satisfy_MatMulFpQ4_condition = (
532
541
Version (ort .__version__ ) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
533
542
)
0 commit comments