@@ -246,6 +246,168 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
    return q_weight, scale, zero_point
+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data: input weight
+        num_bits (int, optional): number of quantization bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    # Grid search over candidate inverse scales; each step refits
+    # (scale, min) to the frozen codes by weighted least squares.
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+        # Keep the new parameters only for groups where the weighted error improved.
+        idx_to_replace = np.where(mad < best_mad)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data: input weight
+        num_bits (int, optional): number of quantization bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp
+        import torch
+
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            # NumPy calls below dispatch to CuPy on device arrays.
+            data = data.reshape((-1, group_size)).astype(np.float32)  # (nb, group_size)
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+            rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+                # Keep the new parameters only for groups where the weighted error improved.
+                idx_to_replace = np.where(mad < best_mad)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+
+            zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(np.float64)
+            q_weight = np.empty_like(data, dtype=scale.dtype)
+            np.divide(data, scale, out=q_weight)
+            np.add(q_weight, zero_point, out=q_weight)
+            np.round(q_weight, out=q_weight)
+            np.clip(q_weight, minq, maxq, out=q_weight)
+
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning(
+                "Tried to use k-quant quantization on CUDA, but CUDA is not available. "
+                "Falling back to k-quant quantization on CPU."
+            )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+    except ImportError:
+        logger.info(
+            "Using k-quant quantization on CPU, which is time consuming. "
+            "Consider installing cupy (https://cupy.dev/) to speed it up on CUDA, "
+            "and torch to check CUDA availability."
+        )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+
def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
    """Quant dequant tensor per group.
@@ -299,6 +461,7 @@ def rtn_quantize(
    ratios={},
    accuracy_level=0,
    providers=["CPUExecutionProvider"],
+    algorithm="rtn",
):
    """Quant the model with round to nearest method.
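With the new keyword, callers opt into the k-quant scale/zero-point search while keeping the rest of the RTN flow. A hypothetical call (the `model` argument and remaining defaults are assumed from context, not shown in this hunk):

```python
# Hypothetical: enable k-quant parameter search instead of plain round-to-nearest.
q_model = rtn_quantize(model, algorithm="k_quant")
```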
@@ -372,9 +535,15 @@ def rtn_quantize(
        ):  # pragma: no cover
            # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
            # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP and CUDA EP
-            q_weight, scale, zp = quant_tensor(
-                weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
-            )
+            if algorithm == "k_quant":
+                q_weight, scale, zp = quant_tensor_k_quant_cuda(
+                    weight.T, num_bits, group_size
+                )
+            else:
+                q_weight, scale, zp = quant_tensor(
+                    weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                )
+
            q_matmul_node, new_inits = make_matmul_weight_only_node(
                node=node,
                weight_shape=org_w_shape,
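The refit inside the grid-search loop is the closed-form weighted least-squares solution for `(this_scale, this_min)` given the frozen codes. A standalone check of that identity (illustrative, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal(32)  # one group of weights
w = np.abs(x) + 1.0          # positive importance weights
q = rng.integers(0, 16, 32)  # frozen 4-bit codes

# Running sums, as in the diff.
sum_w, sum_l = w.sum(), (w * q).sum()
sum_l2, sum_x, sum_xl = (w * q * q).sum(), (w * x).sum(), (w * q * x).sum()
D = sum_w * sum_l2 - sum_l**2
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D

# Reference: minimize sum(w * (s*q + m - x)**2) directly with lstsq.
A = np.stack([q, np.ones_like(q)], axis=1) * np.sqrt(w)[:, None]
s_ref, m_ref = np.linalg.lstsq(A, x * np.sqrt(w), rcond=None)[0]
assert np.allclose([this_scale, this_min], [s_ref, m_ref])
```

This is why each grid step needs only five running sums per group rather than a per-element solve, which keeps the search cheap on both the CPU and CUDA paths.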