@@ -260,7 +260,7 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
260
260
scale: scale
261
261
zero_point: zero point
262
262
"""
263
- data = np .reshape (data , (- 1 , group_size )).astype (np .float32 ) # (nb, group_size)
263
+ data = np .reshape (data , (- 1 , group_size )).astype (np .float32 ) # nb = data.shape[0], (nb, group_size)
264
264
maxq = 2 ** num_bits - 1
265
265
minq = 0
266
266
sum_x2 = np .sum (data ** 2 , axis = 1 , keepdims = True ) # (nb, 1)
@@ -535,9 +535,7 @@ def rtn_quantize(
535
535
# MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
536
536
# MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
537
537
if algorithm == "k_quant" :
538
- q_weight , scale , zp = quant_tensor_k_quant_cuda (
539
- weight .T , num_bits , group_size
540
- )
538
+ q_weight , scale , zp = quant_tensor_k_quant_cuda (weight .T , num_bits , group_size )
541
539
else :
542
540
q_weight , scale , zp = quant_tensor (
543
541
weight .T , num_bits , group_size , scheme , "uint" , ratios .get (node .input [1 ], 1 )
0 commit comments