Commit 5aee8bb

k quant
1 parent e9bd2e7 commit 5aee8bb

File tree: 1 file changed, +172 -3 lines changed

neural_compressor/adaptor/ox_utils/weight_only.py (+172 -3)
@@ -246,6 +246,168 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
 
     return q_weight, scale, zero_point
 
+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+        mad_1 = np.array(mad)
+        best_mad_1 = np.array(best_mad)
+        idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
+
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp
+        import torch
+
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            data = data.reshape((-1, group_size)).astype(np.float32)  # (nb, group_size)
+            nb = data.shape[0]
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+            rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+                mad_1 = cp.array(mad)
+                best_mad_1 = cp.array(best_mad)
+                idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+
+            zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(np.float64)
+            q_weight = np.empty_like(data, dtype=scale.dtype)
+            np.divide(data, scale, out=q_weight)
+            np.add(q_weight, zero_point, out=q_weight)
+            np.round(q_weight, out=q_weight)
+            np.clip(q_weight, minq, maxq, out=q_weight)
+
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning(
+                "Tried to use k-quant quantization on CUDA, but CUDA is not available. "
+                "Falling back to k-quant quantization on CPU."
+            )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+    except ImportError:
+        logger.info(
+            "Using k-quant quantization on CPU, which is time consuming. "
+            "Please consider installing cupy to speed up on CUDA (see https://cupy.dev/). "
+            "Please also install torch to check CUDA availability."
+        )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+
 
 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
     """Quant dequant tensor per group.
@@ -299,6 +461,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
+    algorithm="rtn",
 ):
     """Quant the model with round to nearest method.
 
@@ -372,9 +535,15 @@ def rtn_quantize(
             ):  # pragma: no cover
                 # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
                 # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
-                q_weight, scale, zp = quant_tensor(
-                    weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
-                )
+                if algorithm == "k_quant":
+                    q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
+                else:
+                    q_weight, scale, zp = quant_tensor(
+                        weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                    )
 
                 q_matmul_node, new_inits = make_matmul_weight_only_node(
                     node=node,
                     weight_shape=org_w_shape,