
Commit 559e1d9

Merge branch 'int8_new' into k_quant

2 parents: 99f10df + 0a1a0d4

File tree: 1 file changed (+17 -8 lines)

neural_compressor/adaptor/ox_utils/weight_only.py (+17 -8)
@@ -40,19 +40,19 @@
 ONNXRT1161_VERSION = Version("1.16.1")
 
 
-def get_blob_size(group_size, has_zp):  # pragma: no cover
+def get_blob_size(group_size, num_bits, has_zp):  # pragma: no cover
     """Get blob_size.
 
     Args:
         group_size (int): how many elements share one scale/zp
         has_zp (bool): whether zero_point is None
     """
     if Version(ort.__version__) > ONNXRT1161_VERSION:
-        blob_size = group_size // 2
+        blob_size = group_size * num_bits // 8
     elif has_zp:
-        blob_size = group_size // 2 + 4 + 1
+        blob_size = group_size * num_bits // 8 + 4 + 1
     else:
-        blob_size = group_size // 2 + 4
+        blob_size = group_size * num_bits // 8 + 4
     return blob_size
 
 
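A quick sanity check of the new formula (worked numbers, not part of the diff): for ONNX Runtime newer than 1.16.1 with group_size = 32, 4-bit weights give blob_size = 32 * 4 // 8 = 16 bytes per block, identical to the old hard-coded group_size // 2, while 8-bit weights give 32 * 8 // 8 = 32 bytes. In the older pre-1.16.1 branches, the extra 4 bytes correspond to the per-block fp32 scale and the extra 1 byte to the zero point stored inline in the blob, which is why only the packed-weight term changes.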

@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
         matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
         new_inits: initializers of the new node
     """
-    blob_size = get_blob_size(group_size, zero_point is not None)
+    blob_size = get_blob_size(group_size, num_bits, zero_point is not None)
     packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
     q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
     input_names = [node.input[0], q_weight_name]
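As a concrete illustration (values invented for this note, not from the commit): quantizing an initializer named weights with num_bits = 8 and group_size = 32 yields a packed initializer named weights_Q8G32, and packed is allocated as one blob_size-byte row for each row of q_weight before being reshaped further down.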
@@ -97,8 +97,16 @@ def make_matmul_weight_only_node(
         op_type = "MatMulNBits"
 
         # pack quantized weight
-        q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
-        packed[:, :] = q_weight_pairs[:, :blob_size]
+        if num_bits == 4:
+            q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
+            packed[:, :] = q_weight_pairs[:, :blob_size]
+        elif num_bits == 8:
+            packed = q_weight
+        else:
+            logger.error(
+                "MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits)
+            )
+
         packed = np.reshape(packed, (-1, k_blocks, blob_size))
 
         # build scale tensor
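To make the 4-bit branch concrete, here is a minimal standalone sketch (toy values, not part of the commit) of the same nibble-packing expression: even-indexed columns land in the low nibble of each byte and odd-indexed columns in the high nibble, so two 4-bit values share one uint8.

import numpy as np

# Four 4-bit values per row, each in [0, 15].
q_weight = np.array([[1, 2, 3, 4]], dtype="uint8")

# Same expression as the num_bits == 4 branch above.
packed = q_weight[:, ::2] | q_weight[:, 1::2] << 4
print([hex(b) for b in packed[0]])  # ['0x21', '0x43']

# Unpacking recovers the original columns.
print(packed & 0x0F, packed >> 4)   # [[1 3]] [[2 4]]

The 8-bit branch needs no such pairing: each quantized value already fills a byte, so q_weight is used as the packed buffer directly.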
@@ -527,7 +535,8 @@ def rtn_quantize(
 
         weight = pad_tensor(weight, group_size, k_blocks)
 
-        satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4
+        enable_MatMulNBits_8bits = True
+        satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (enable_MatMulNBits_8bits and num_bits == 8)
         satisfy_MatMulFpQ4_condition = (
             Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
         )
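For readers tracing the control flow, a minimal standalone sketch of the updated operator selection (the function name, fallback label, and the assumed value of ONNXRT116_VERSION are illustrative, not from the source):

from packaging.version import Version

ONNXRT116_VERSION = Version("1.16.0")    # assumed, mirroring the module constants
ONNXRT1161_VERSION = Version("1.16.1")

def select_matmul_op(ort_version, num_bits, group_size, enable_8bits=True):
    """Illustrative mirror of the gating in rtn_quantize."""
    if (Version(ort_version) > ONNXRT1161_VERSION and num_bits == 4) or (
        enable_8bits and num_bits == 8
    ):
        return "MatMulNBits"
    if Version(ort_version) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:
        return "MatMulFpQ4"
    return "unfused fallback"

print(select_matmul_op("1.17.0", 8, 128))  # MatMulNBits
print(select_matmul_op("1.16.0", 4, 32))   # MatMulFpQ4

Note that in this diff the 8-bit path is gated only by the hard-coded enable_MatMulNBits_8bits flag, not by the ONNX Runtime version, unlike the 4-bit MatMulNBits path.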
