Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ save*
.log
*.pid
*.ipynb*
.venv/
*.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
base:
seed: &seed 42
model:
type: model_type
path: model path
torch_dtype: auto
calib:
name: pileval
download: False
path: calib data path
n_samples: 128
bs: 1
seq_len: 2048
preproc: txt_general_preproc
seed: *seed
eval:
eval_pos: [transformed, fake_quant, fake_quant_wo_kv] # long_ppl eval does not support the pretrain eval pos
name: wikitext2
type: decode_ppl
download: False
path: eval_data_path
bs: 1
inference_per_block: False
num_samples: 10
# num_eval_tokens: 3
quant:
method: RTN
weight:
bit: 8
symmetric: True
granularity: per_channel
group_size: -1
act:
bit: 8
symmetric: True
granularity: per_tensor
static: True
kvcache:
method: Naive
bit: 8
symmetric: True
granularity: per_head
head_num: kv head num
save:
save_lightllm_kv_cache_calib: True
lightllm_kv_cache_calib_name: kv_cache_calib.json
save_fake: False
save_path: /path/to/save/
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,7 @@ quant:
symmetric: True
granularity: per_tensor
save:
save_lightllm_kv_cache_calib: True
lightllm_kv_cache_calib_name: kv_cache_calib.json
save_fake: False
save_path: /path/to/save/
save_path: /path/to/save/
25 changes: 25 additions & 0 deletions llmc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from llmc.models import *
from llmc.utils import (check_config, deploy_all_modality, get_modality,
mkdirs, print_important_package_version, seed_all,
collect_lightllm_kv_calib_json,
update_autoawq_quant_config,
update_lightx2v_quant_config, update_vllm_quant_config)
from llmc.utils.registry_factory import ALGO_REGISTRY, MODEL_REGISTRY
Expand Down Expand Up @@ -72,6 +73,21 @@ def main(config):

eval_model(model, blockwise_opts, eval_list, eval_pos='transformed')
if int(os.environ['RANK']) == 0:
if 'save' in config and config.save.get('save_lightllm_kv_cache_calib', False):
calib_json_list = [
collect_lightllm_kv_calib_json(blockwise_opt)
for blockwise_opt in blockwise_opts
if hasattr(blockwise_opt, 'quant_kvcache')
]
calib_json_payload = (
calib_json_list[0] if len(calib_json_list) == 1 else calib_json_list
)
with open(save_lightllm_kv_cache_calib_path, 'w') as file:
json.dump(calib_json_payload, file, ensure_ascii=False, indent=4)
logger.info(
f'save lightllm kv cache calib done -- {save_lightllm_kv_cache_calib_path}'
)

if 'save' in config and config.save.get('save_trans', False):
blockwise_opt.save_model(save_trans_path)

Expand Down Expand Up @@ -209,6 +225,14 @@ def main(config):
# Ensure only the main process creates directories
if int(os.environ['RANK']) == 0:
if 'save' in config:
if config.save.get('save_lightllm_kv_cache_calib', False):
mkdirs(config.save.save_path)
save_lightllm_kv_cache_calib_path = os.path.join(
config.save.save_path,
config.save.get(
'lightllm_kv_cache_calib_name', 'kv_cache_calib.json'
),
)
if config.save.get('save_trans', False):
save_trans_path = os.path.join(
config.save.save_path, 'transformed_model'
Expand Down Expand Up @@ -266,3 +290,4 @@ def main(config):
llmc_duration_time = llmc_end_time - llmc_start_time
logger.info(f'llmc_duration_time: {llmc_duration_time} s')
logger.info('--- llmc finished ---')

16 changes: 11 additions & 5 deletions llmc/compression/quantization/base_blockwise_quantization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
import functools
import gc
import json
import os
import re
from collections import defaultdict
Expand Down Expand Up @@ -175,13 +174,18 @@ def set_quant_config(self):
self.act_quant_module = IntegerQuantizer
elif quant_type == 'float-quant':
self.act_quant_module = FloatQuantizer
self.quant_config['act']['tp'] = self.tp
self.aquantizer = self.act_quant_module(**self.quant_config['act'])
self.act_static = self.quant_config['act'].get('static', False)
if self.act_static:
assert (
self.quant_config['act']['granularity'] == 'per_tensor'
), 'Only support per_tensor static quant'
# Static activation quantization uses the batched calibration
# path, so normalize the default minmax setting to
# static_minmax to match the downstream calibration logic.
if self.quant_config['act'].get('calib_algo', 'minmax') == 'minmax':
self.quant_config['act']['calib_algo'] = 'static_minmax'
self.quant_config['act']['tp'] = self.tp
self.aquantizer = self.act_quant_module(**self.quant_config['act'])
self.quant_attn = self.quant_config['act'].get('quant_attn', False)
if self.quant_attn:
assert self.config['model']['type'] in ['Vit', 'DeepseekV2']
Expand All @@ -203,8 +207,10 @@ def set_quant_config(self):
kv_special_cfg = self.quant_config['kvcache'].get('special', {})
act_static_cfg = {}
if self.act_static:
act_static_cfg.update(self.config.calib.n_sample)
act_static_cfg.update(self.config.calib.bs)
# The KV cache constructor expects num_samples / bsz, so map
# the calibration config fields to the parameter names it uses.
act_static_cfg['num_samples'] = self.config.calib.n_samples
act_static_cfg['bsz'] = self.config.calib.bs
kv_quant_type = self.quant_config['kvcache'].get('quant_type', 'int-quant')
self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
kv_quant_type, self.quant_config['kvcache'],
Expand Down
11 changes: 10 additions & 1 deletion llmc/compression/quantization/kvquant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import torch
from loguru import logger
from transformers import DynamicCache
Expand All @@ -12,12 +13,20 @@ class NaiveQuantKVCache(DynamicCache):
def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples=128, bsz=1):
super().__init__()

assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
# Copy the config to avoid mutating the original quantization config in static KV calibration.
kvquant_cfg = copy.deepcopy(kvquant_cfg)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方为啥要deep copy一份?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just avoid mutating the original quantization config.

assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group', 'per_head']
self.num_hidden_layers, self.num_samples, self.bsz = (
num_hidden_layers,
num_samples,
bsz,
)
if kvquant_cfg.get('static', False) and kvquant_cfg.get(
'calib_algo', 'minmax'
) == 'minmax':
# Static KV calibration uses the batched tensor statistics path, so convert the default
# minmax setting to static_minmax here to avoid a later calibration algo name mismatch.
kvquant_cfg['calib_algo'] = 'static_minmax'
if quant_type == 'int-quant':
self.kvquantizer = IntegerQuantizer(**kvquant_cfg)
elif quant_type == 'float-quant':
Expand Down
19 changes: 8 additions & 11 deletions llmc/compression/quantization/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,27 +224,24 @@ def get_minmax_stats(self, act_tensors):
for tensor in tensors:
tensor = self.reshape_tensor(tensor)
tensor_range = self.get_minmax_range(tensor)
min_val, max_val = tensor_range[0], tensor_range[1]
min_val = tensor_range[0].detach().cpu().to(torch.float32)
max_val = tensor_range[1].detach().cpu().to(torch.float32)

if input_idx not in stats_min_max:
stats_min_max[input_idx] = {}
stats_min_max[input_idx]['min'] = torch.tensor(
[min_val], dtype=torch.float32
)
stats_min_max[input_idx]['max'] = torch.tensor(
[max_val], dtype=torch.float32
)
stats_min_max[input_idx]['min'] = min_val.unsqueeze(0)
stats_min_max[input_idx]['max'] = max_val.unsqueeze(0)
else:
stats_min_max[input_idx]['min'] = torch.cat(
[
stats_min_max[input_idx]['min'],
torch.tensor([min_val], dtype=torch.float32),
min_val.unsqueeze(0),
]
)
stats_min_max[input_idx]['max'] = torch.cat(
[
stats_min_max[input_idx]['max'],
torch.tensor([max_val], dtype=torch.float32),
max_val.unsqueeze(0),
]
)

Expand All @@ -255,8 +252,8 @@ def get_static_minmax_range(self, act_tensors):
stats_min_max = self.get_minmax_stats(act_tensors)
min_vals, max_vals = [], []
for input_idx, tensor_range in stats_min_max.items():
min_val = tensor_range['min'].mean()
max_val = tensor_range['max'].mean()
min_val = tensor_range['min'].mean(dim=0)
max_val = tensor_range['max'].mean(dim=0)
min_vals.append(min_val)
max_vals.append(max_val)

Expand Down
1 change: 1 addition & 0 deletions llmc/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .export_autoawq import update_autoawq_quant_config
from .export_calib import collect_lightllm_kv_calib_json
from .export_lightx2v import update_lightx2v_quant_config
from .export_vllm import update_vllm_quant_config
from .utils import (check_config, copy_files, deploy_all_modality,
Expand Down
98 changes: 98 additions & 0 deletions llmc/utils/export_calib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import torch


def _to_jsonable(value):
if isinstance(value, torch.Tensor):
return value.detach().cpu().tolist()
return value


def _to_tensor(value, dtype=torch.float32):
if isinstance(value, torch.Tensor):
return value.detach().cpu().to(dtype)
return torch.as_tensor(value, dtype=dtype)


def _collect_lightllm_kv_scale(scales, zeros, qmin, qmax):
    """Derive per-entry fp8 (e4m3) scales from integer-quant calibration params.

    Reconstructs the real-valued range represented by the integer grid
    (``real = (q - zero) * scale``), takes the per-entry absolute maximum,
    and divides by the fp8 e4m3 representable maximum so LightLLM can map
    KV values onto the fp8 range.

    Returns None when *scales* holds no calibration data.
    """
    # Convert first so the empty-buffer check also covers plain
    # lists/tuples, not just torch.Tensor inputs (the original guard
    # silently let an empty list through and produced an empty result).
    scales_tensor = _to_tensor(scales)
    if scales_tensor.numel() == 0:
        return None

    zeros_tensor = _to_tensor(zeros, dtype=scales_tensor.dtype)
    qmin_tensor = _to_tensor(qmin, dtype=scales_tensor.dtype)
    qmax_tensor = _to_tensor(qmax, dtype=scales_tensor.dtype)
    # Real-valued min/max of the quantization grid.
    min_tensor = (qmin_tensor - zeros_tensor) * scales_tensor
    max_tensor = (qmax_tensor - zeros_tensor) * scales_tensor
    absmax_tensor = torch.maximum(min_tensor.abs(), max_tensor.abs())
    # fp8 e4m3 max is 448.0; scale maps observed absmax onto that range.
    fp8_qmax = torch.tensor(
        torch.finfo(torch.float8_e4m3fn).max, dtype=absmax_tensor.dtype
    )
    return absmax_tensor / fp8_qmax


def collect_lightllm_kv_calib_json(blockwise_opt):
    """Build the LightLLM KV-cache calibration payload as a JSON-able dict.

    Reads the per-layer key/value quantization buffers from
    ``blockwise_opt.kv_module``, converts them into fp8 export scales, and
    validates that the flattened scale width matches the configured
    granularity (2 per layer for per_tensor, 2 * num_head for per_head).

    Raises ValueError when KV-cache quantization is disabled, the
    granularity is unsupported, or any layer's calibration buffer is empty.
    """
    if not getattr(blockwise_opt, 'quant_kvcache', False):
        raise ValueError(
            'save_lightllm_kv_cache_calib requires kvcache quantization.'
        )

    kv_cfg = blockwise_opt.quant_config['kvcache']
    granularity = kv_cfg.get('granularity')
    if granularity not in ['per_tensor', 'per_head']:
        raise ValueError(
            f'LightLLM calib export only supports per_tensor/per_head, got {granularity}'
        )

    kv_module = blockwise_opt.kv_module
    model_config = blockwise_opt.model.model_config
    num_layers = model_config.num_hidden_layers
    # Fall back to the attention-head count for models without GQA/MQA.
    num_head = int(
        getattr(
            model_config,
            'num_key_value_heads',
            blockwise_opt.model.get_num_attention_heads(),
        )
    )

    def _layer_scale(scales_buf, zeros_buf, qmin_buf, qmax_buf, idx):
        # One fused absmax-derived fp8 scale per buffer entry for layer idx.
        return _collect_lightllm_kv_scale(
            scales_buf[idx], zeros_buf[idx], qmin_buf[idx], qmax_buf[idx]
        )

    scales = []
    for layer_idx in range(num_layers):
        key_scale = _layer_scale(
            kv_module.k_scales_buffer,
            kv_module.k_zeros_buffer,
            kv_module.k_qmin_buffer,
            kv_module.k_qmax_buffer,
            layer_idx,
        )
        value_scale = _layer_scale(
            kv_module.v_scales_buffer,
            kv_module.v_zeros_buffer,
            kv_module.v_qmin_buffer,
            kv_module.v_qmax_buffer,
            layer_idx,
        )
        if key_scale is None or value_scale is None:
            raise ValueError(f'Calibration scale for layer {layer_idx} is empty.')

        # Row layout: all key scales first, then all value scales.
        row = torch.cat([key_scale.reshape(-1), value_scale.reshape(-1)])
        scales.append(row.tolist())

    scale_width = len(scales[0]) if scales else 0
    if granularity == 'per_tensor' and scale_width != 2:
        raise ValueError(f'per_tensor export expects 2 scales per layer, got {scale_width}')
    if granularity == 'per_head' and scale_width != num_head * 2:
        raise ValueError(
            f'per_head export expects {num_head * 2} scales per layer, got {scale_width}'
        )

    architectures = getattr(model_config, 'architectures', None)
    if isinstance(architectures, list) and architectures:
        architectures = architectures[0]
    elif architectures is None:
        architectures = blockwise_opt.config.model.type

    return {
        'version': '1.0',
        'architectures': architectures,
        'quant_type': granularity,
        'qmin': float(torch.finfo(torch.float8_e4m3fn).min),
        'qmax': float(torch.finfo(torch.float8_e4m3fn).max),
        'num_layers': num_layers,
        'num_head': num_head,
        'scales_shape': [num_layers, scale_width],
        'scales': _to_jsonable(scales),
    }
Loading