FEAT: support fp4 and int8 quantization for pytorch model (xorbitsai#238)
Showing 8 changed files with 318 additions and 38 deletions.
@@ -0,0 +1,261 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import gc
import glob
import os

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from huggingface_hub import snapshot_download
from torch import Tensor
from torch.nn import functional as F
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from ....constants import XINFERENCE_CACHE_DIR


@dataclasses.dataclass
class CompressionConfig:
    """Group-wise quantization."""

    num_bits: int
    group_size: int
    group_dim: int
    symmetric: bool
    enabled: bool = True


default_compression_config = CompressionConfig(
    num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
)
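
# Note on the defaults above: an nn.Linear weight has shape
# (out_features, in_features), so group_dim=1 splits the input-feature
# dimension into groups of 256 values, each quantized symmetrically to int8
# with its own scale. Relative to fp16 this roughly halves weight memory,
# at the cost of a small per-group scale overhead and the extra compute
# needed to dequantize on every forward pass.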


class CLinear(nn.Module):
    """Compressed Linear Layer."""

    def __init__(self, weight=None, bias=None, device=None):
        super().__init__()
        if weight is None:
            self.weight = None
        elif isinstance(weight, Tensor):
            self.weight = compress(weight.data.to(device), default_compression_config)
        else:
            self.weight = weight
        self.bias = bias

    def forward(self, input: Tensor) -> Tensor:
        weight = decompress(self.weight, default_compression_config)
        if self.bias is None:
            return F.linear(input.to(weight.dtype), weight)
        return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype))
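
# Illustrative use of CLinear (hypothetical shapes and device, kept as a
# comment sketch so the module stays import-safe): an existing nn.Linear can
# be swapped for a compressed equivalent that packs the weight once on
# construction and dequantizes it on every call.
#
#     layer = nn.Linear(4096, 4096)
#     clayer = CLinear(layer.weight, layer.bias, device="cuda:0")
#     y = clayer(torch.randn(1, 4096, device="cuda:0"))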


def get_compressed_list(module, prefix=""):
    """Recursively collect the state-dict names of all ``nn.Linear`` weights."""
    compressed_list = []
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            full_name = (
                f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
            )
            compressed_list.append(full_name)
    for name, child in module.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name
        for each in get_compressed_list(child, child_prefix):
            compressed_list.append(each)
    return compressed_list


def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""):
    """Recursively replace every ``nn.Linear`` with a ``CLinear`` that holds
    its compressed weight from ``compressed_state_dict``."""
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            full_name = (
                f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
            )
            setattr(
                module,
                attr_str,
                CLinear(
                    compressed_state_dict[full_name], target_attr.bias, target_device
                ),
            )
    for name, child in module.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name
        apply_compressed_weight(
            child, compressed_state_dict, target_device, child_prefix
        )


def load_compress_model(
    model_path: str,
    device: str,
    torch_dtype: torch.dtype,
    use_fast: bool,
    revision: str = "main",
):
    """Load a causal LM checkpoint file by file, compressing every
    ``nn.Linear`` weight as it is read so the full-precision weights never
    all reside in memory at once."""
    # partially load model
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=use_fast,
        trust_remote_code=True,
        revision=revision,
        cache_dir=XINFERENCE_CACHE_DIR,
    )

    with init_empty_weights():
        config = AutoConfig.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            revision=revision,
            cache_dir=XINFERENCE_CACHE_DIR,
        )
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        linear_weights = get_compressed_list(model)

    if os.path.exists(model_path):
        # `model_path` is a local folder
        base_pattern = os.path.join(model_path, "pytorch_model*.bin")
    else:
        # `model_path` is a cached Hugging Face repo
        model_path = snapshot_download(
            model_path, revision=revision, cache_dir=XINFERENCE_CACHE_DIR
        )
        base_pattern = os.path.join(model_path, "pytorch_model*.bin")

    files = glob.glob(base_pattern)

    compressed_state_dict = {}

    for filename in tqdm(files):
        tmp_state_dict = torch.load(filename, map_location=torch.device(device))
        for name in tmp_state_dict:
            if name in linear_weights:
                tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)
                compressed_state_dict[name] = compress(
                    tensor, default_compression_config
                )
            else:
                compressed_state_dict[name] = tmp_state_dict[name].to(device)
            tmp_state_dict[name] = None
            tensor = None
            gc.collect()
            torch.cuda.empty_cache()

    for name in model.state_dict():
        if name not in linear_weights:
            set_module_tensor_to_device(
                model, name, device, value=compressed_state_dict[name]
            )
    apply_compressed_weight(model, compressed_state_dict, device)

    model.to(device)

    return model, tokenizer
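
# Hypothetical call (model id, device and dtype are placeholders, shown as a
# comment so the module itself stays import-safe):
#
#     model, tokenizer = load_compress_model(
#         "facebook/opt-125m",
#         device="cuda:0",
#         torch_dtype=torch.float16,
#         use_fast=True,
#     )
#     ids = tokenizer("Hello", return_tensors="pt").input_ids.to("cuda:0")
#     print(tokenizer.decode(model.generate(ids, max_new_tokens=8)[0]))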


def compress(tensor, config):
    """Simulate group-wise quantization."""
    if not config.enabled:
        return tensor

    group_size, num_bits, group_dim, symmetric = (
        config.group_size,
        config.num_bits,
        config.group_dim,
        config.symmetric,
    )
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    new_shape = (
        original_shape[:group_dim]
        + (num_groups, group_size)
        + original_shape[group_dim + 1 :]
    )

    # Pad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = (
            original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :]
        )
        tensor = torch.cat(
            [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
            dim=group_dim,
        )
    data = tensor.view(new_shape)

    # Quantize
    if symmetric:
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = data * scale
        data = data.clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape
    else:
        B = 2**num_bits - 1
        mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
        mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

        scale = B / (mx - mn)
        data = data - mn
        data.mul_(scale)

        data = data.clamp_(0, B).round_().to(torch.uint8)
        return data, mn, scale, original_shape


def decompress(packed_data, config):
    """Simulate group-wise dequantization."""
    if not config.enabled:
        return packed_data

    group_size, _, group_dim, symmetric = (
        config.group_size,
        config.num_bits,
        config.group_dim,
        config.symmetric,
    )

    # Dequantize
    if symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Unpad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len:
        padded_original_shape = (
            original_shape[:group_dim]
            + (original_shape[group_dim] + pad_len,)
            + original_shape[group_dim + 1 :]
        )
        data = data.reshape(padded_original_shape)
        indices = [slice(0, x) for x in original_shape]
        return data[indices].contiguous()
    else:
        return data.view(original_shape)
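
# Round-trip sketch (illustrative values; compress/decompress only
# approximate the original tensor):
#
#     w = torch.randn(4, 300)
#     packed = compress(w, default_compression_config)   # (int8 data, scales, shape)
#     w_hat = decompress(packed, default_compression_config)
#     assert w_hat.shape == w.shape
#
# With symmetric=False the packed tuple also carries the per-group minimum:
# (uint8 data, mn, scale, original_shape).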