
Commit 0ae992f

Merge branch 'int8_new' into k_quant
2 parents: 99f10df + 0a1a0d4

File tree

1 file changed (+22 -11)

neural_compressor/adaptor/ox_utils/weight_only.py

Lines changed: 22 additions & 11 deletions
```diff
@@ -40,19 +40,19 @@
 ONNXRT1161_VERSION = Version("1.16.1")
 
 
-def get_blob_size(group_size, has_zp):  # pragma: no cover
+def get_blob_size(group_size, num_bits, has_zp):  # pragma: no cover
     """Get blob_size.
 
     Args:
         group_size (int): how many elements share one scale/zp
         has_zp (bool): whether zero_point is None
     """
     if Version(ort.__version__) > ONNXRT1161_VERSION:
-        blob_size = group_size // 2
+        blob_size = group_size * num_bits // 8
     elif has_zp:
-        blob_size = group_size // 2 + 4 + 1
+        blob_size = group_size * num_bits // 8 + 4 + 1
     else:
-        blob_size = group_size // 2 + 4
+        blob_size = group_size * num_bits // 8 + 4
     return blob_size
 
 
```
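The generalized formula is easy to sanity-check with concrete numbers. Below is a minimal standalone sketch (not the library function, which also version-gates and adds scale/zero-point bytes), assuming the common group_size of 32:

```python
# Payload arithmetic behind the change: a block of group_size quantized
# values at num_bits each occupies group_size * num_bits // 8 bytes.
def payload_bytes(group_size: int, num_bits: int) -> int:
    return group_size * num_bits // 8

assert payload_bytes(32, 4) == 16  # reduces to the old group_size // 2
assert payload_bytes(32, 8) == 32  # 8-bit: one byte per quantized value
```

For num_bits == 4 the new expression equals the old hard-coded group_size // 2, so existing 4-bit behavior is unchanged.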
```diff
@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
         matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
         new_inits: initializers of the new node
     """
-    blob_size = get_blob_size(group_size, zero_point is not None)
+    blob_size = get_blob_size(group_size, num_bits, zero_point is not None)
     packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
     q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
     input_names = [node.input[0], q_weight_name]
```
```diff
@@ -97,8 +97,16 @@ def make_matmul_weight_only_node(
         op_type = "MatMulNBits"
 
         # pack quantized weight
-        q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
-        packed[:, :] = q_weight_pairs[:, :blob_size]
+        if num_bits == 4:
+            q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
+            packed[:, :] = q_weight_pairs[:, :blob_size]
+        elif num_bits == 8:
+            packed = q_weight
+        else:
+            logger.error(
+                "MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits)
+            )
+
         packed = np.reshape(packed, (-1, k_blocks, blob_size))
 
         # build scale tensor
```
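Only the 4-bit path needs explicit bit packing. A toy NumPy example (hypothetical values, not from the commit) of the nibble interleave used above; note that << binds tighter than |, so each odd-indexed value is shifted into the high nibble before the OR:

```python
import numpy as np

# One row of eight 4-bit values (each in 0..15).
q_weight = np.array([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=np.uint8)

# Even-indexed values land in the low nibble, odd-indexed values in the
# high nibble: two quantized values per output byte.
packed = q_weight[:, ::2] | q_weight[:, 1::2] << 4
print([hex(v) for v in packed[0]])  # ['0x21', '0x43', '0x65', '0x87']
```

At num_bits == 8 every value already fills a whole byte, which is why the new branch can reuse q_weight directly.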
```diff
@@ -115,8 +123,10 @@ def make_matmul_weight_only_node(
 
         # build zero_point tensor
         if zero_point is not None:
-            if num_bits > 4:
-                packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8")
+            if num_bits == 8:
+                packed_zp = zero_point.astype("uint8")
+            elif num_bits > 4:
+                packed_zp = np.reshape(zero_point, (scale.shape[0], -1)).astype("uint8")
             else:
                 packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
                 # create an index array
```
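The constant 136 in the 4-bit fallback is itself a packed pair of default zero points; a one-line check (standalone, using the same nibble layout as the weight packing):

```python
# 136 == 0x88: two packed 4-bit zero points of 8, the midpoint of the
# unsigned 4-bit range.
default_zp = 8
assert default_zp | default_zp << 4 == 136
```

The new 8-bit branch needs no packing at all: each per-block zero point is stored as a single uint8.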
```diff
@@ -463,7 +473,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
-    algorithm="rtn",
+    algorithm="k_quant",
 ):
     """Quant the model with round to nearst method.
```
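Since only the default changed, callers can pin the old behavior explicitly. A hypothetical call sketch (the leading model argument and the remaining defaults are assumed, as they are not shown in this hunk):

```python
# Hypothetical usage; only the algorithm default moved to "k_quant".
q_model = rtn_quantize(model, algorithm="rtn")  # keep round-to-nearest
q_model = rtn_quantize(model)                   # now defaults to k_quant
```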
```diff
@@ -527,7 +537,8 @@ def rtn_quantize(
 
         weight = pad_tensor(weight, group_size, k_blocks)
 
-        satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4
+        enable_MatMulNBits_8bits = True
+        satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (enable_MatMulNBits_8bits and num_bits == 8)
         satisfy_MatMulFpQ4_condition = (
             Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
         )
```
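Condensed into a standalone predicate, the op selection after this commit looks roughly as follows. This is a sketch, not library code: the value of ONNXRT116_VERSION ("1.16.0") and the "fallback" label are assumptions, and note that the 8-bit branch is gated only by the hard-coded enable flag, not by the onnxruntime version:

```python
from packaging.version import Version

ONNXRT116_VERSION = Version("1.16.0")   # assumed; defined earlier in the file
ONNXRT1161_VERSION = Version("1.16.1")

def pick_matmul_kernel(ort_version: str, num_bits: int, group_size: int) -> str:
    """Sketch of the kernel dispatch after this commit."""
    v = Version(ort_version)
    enable_matmulnbits_8bits = True  # mirrors the hard-coded flag above
    if (v > ONNXRT1161_VERSION and num_bits == 4) or (enable_matmulnbits_8bits and num_bits == 8):
        return "MatMulNBits"
    if v >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:
        return "MatMulFpQ4"
    return "fallback"

assert pick_matmul_kernel("1.17.0", 8, 32) == "MatMulNBits"
assert pick_matmul_kernel("1.16.0", 4, 32) == "MatMulFpQ4"
```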
