Commit e2de15f

Add 8-bit quantization support and release 7B model (lm-sys#252)
1 parent d245e37 commit e2de15f

4 files changed: 140 additions, 12 deletions

README.md (15 additions, 10 deletions)

@@ -52,7 +52,7 @@ We release [Vicuna](https://vicuna.lmsys.org/) weights as delta weights to compl
 You can add our delta to the original LLaMA weights to obtain the Vicuna weights. Instructions:
 
 1. Get the original LLaMA weights in the huggingface format by following the instructions [here](https://huggingface.co/docs/transformers/main/model_doc/llama).
-2. Use the following scripts to get Vicuna weights by applying our delta. It will automatically download delta weights from our Hugging Face account.
+2. Use the following scripts to get Vicuna weights by applying our delta. They will automatically download delta weights from our Hugging Face account.
 
 **NOTE**:
 Our released weights are only compatible with the latest main branch of huggingface/transformers.
@@ -68,12 +68,18 @@ python3 -m fastchat.model.apply_delta \
 ```
 
 ### Vicuna-7B
-Coming soon.
+This conversion command needs around 30 GB of CPU RAM.
+```bash
+python3 -m fastchat.model.apply_delta \
+    --base /path/to/llama-7b \
+    --target /output/path/to/vicuna-7b \
+    --delta lmsys/vicuna-7b-delta-v0
+```
 
 ## Inference with Command Line Interface
 
 ### Single GPU
-The command below requires around 28GB of GPU memory for Vicuna-13B.
+The command below requires around 28GB of GPU memory for Vicuna-13B and 14GB of GPU memory for Vicuna-7B.
 ```
 python3 -m fastchat.serve.cli --model-name /path/to/vicuna/weights
 ```
@@ -85,22 +91,21 @@ python3 -m fastchat.serve.cli --model-name /path/to/vicuna/weights --num-gpus 2
 ```
 
 ### CPU Only
-This runs on the CPU only and does not require GPU. It requires around 60GB of CPU memory for Vicuna-13B.
+This runs on the CPU only and does not require GPU. It requires around 60GB of CPU memory for Vicuna-13B and around 30GB of CPU memory for Vicuna-7B.
 ```
 python3 -m fastchat.serve.cli --model-name /path/to/vicuna/weights --device cpu
 ```
 
 ### Metal Backend (Mac computers with Apple silicon or AMD GPUs)
+Use `--device mps` to enable GPU acceleration on Mac computers and use `--load-8bit` to turn on 8-bit compression.
 ```
-python3 -m fastchat.serve.cli --model-name /path/to/vicuna/weights --device mps
+python3 -m fastchat.serve.cli --model-name /path/to/vicuna/weights --device mps --load-8bit
 ```
 
 ### Others (Quantization, Low-end Devices, and More Platforms)
-
-You can load in 8-bit mode to reduce GPU memory usage with slightly degraded model quality.
-It is tested on a single 4090 and requires around 18GB of GPU memory for Vicuna-13B.
-Note that this mode only works on a single GPU.
-You are also required to install `bitsandbytes` according to the printed messages.
+If you do not have enough memory, you can enable 8-bit compression by adding `--load-8bit` to commands above.
+It works with CPU, GPU, and Metal.
+This can reduce the memory usage by around half with slightly degraded model quality.
 
 ```
 python3 -m fastchat.serve.cli --model-name /path/to/vicuna/weights --load-8bit
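As a rough sanity check on the memory figures quoted above, here is an editorial back-of-envelope sketch (not part of the commit; `weight_gb` is a made-up helper): fp16 stores 2 bytes per parameter and 8-bit compression roughly 1, so the weights alone account for most of the quoted usage, with activations and the KV cache adding the rest.

```python
# Back-of-envelope estimate of weight memory for the numbers in the README.
# Weights only; activations and the KV cache add a few extra GB on top.
def weight_gb(params_billion: float, bytes_per_param: float) -> float:
    return params_billion * 1e9 * bytes_per_param / 1024**3

for name, params in [("Vicuna-13B", 13), ("Vicuna-7B", 7)]:
    print(f"{name}: fp16 ~{weight_gb(params, 2):.0f} GB, "
          f"8-bit ~{weight_gb(params, 1):.0f} GB")
# Vicuna-13B: fp16 ~24 GB, 8-bit ~12 GB
# Vicuna-7B: fp16 ~13 GB, 8-bit ~7 GB
```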

fastchat/serve/cli.py (8 additions, 1 deletion)

@@ -9,6 +9,7 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
 
 from fastchat.conversation import conv_templates, SeparatorStyle
+from fastchat.serve.compression import compress_module
 from fastchat.serve.monkey_patch_non_inplace import replace_llama_attn_with_non_inplace_operations
 
 
@@ -32,8 +33,8 @@ def load_model(model_name, device, num_gpus, load_8bit=False):
                     "max_memory": {i: "13GiB" for i in range(num_gpus)},
                 })
     elif device == "mps":
-        # Avoid bugs in mps backend by not using in-place operations.
         kwargs = {"torch_dtype": torch.float16}
+        # Avoid bugs in mps backend by not using in-place operations.
         replace_llama_attn_with_non_inplace_operations()
     else:
         raise ValueError(f"Invalid device: {device}")
@@ -48,6 +49,12 @@ def load_model(model_name, device, num_gpus, load_8bit=False):
     elif device == "mps":
         model.to("mps")
 
+    if (device == "mps" or device == "cpu") and load_8bit:
+        compress_module(model)
+
+    if args.debug:
+        print(model)
+
     return model, tokenizer
 
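For context, here is a small standalone sketch of what the new branch in `load_model` does when `--load-8bit` is combined with `--device cpu` or `--device mps`. It is illustrative only and not part of the commit; `TinyMLP` is a made-up module, while `compress_module` and `CLinear` come from the new `fastchat/serve/compression.py` shown below. `compress_module` walks the module tree and swaps `nn.Linear` layers for the quantized `CLinear` wrapper, which stores 8-bit group-quantized weights and dequantizes them on the fly in `forward`.

```python
# Illustrative sketch, not part of the commit: apply compress_module to a toy
# module the same way load_model() now does for cpu/mps with load_8bit=True.
import torch
import torch.nn as nn

from fastchat.serve.compression import compress_module


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 256)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyMLP()
compress_module(model)                    # fc1 and fc2 are replaced in place by CLinear
print(type(model.fc1).__name__)           # CLinear
print(model(torch.randn(1, 512)).shape)   # torch.Size([1, 256]); weights are dequantized per call
```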

fastchat/serve/compression.py (new file, 116 additions)

```python
import dataclasses

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import functional as F


@dataclasses.dataclass
class CompressionConfig:
    """Group-wise quantization."""
    num_bits: int
    group_size: int
    group_dim: int
    symmetric: bool
    enabled: bool = True


default_compression_config = CompressionConfig(
    num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)


class CLinear(nn.Module):
    def __init__(self, weight, bias):
        super().__init__()

        self.weight = compress(weight.data, default_compression_config)
        self.bias = bias

    def forward(self, input: Tensor) -> Tensor:
        weight = decompress(self.weight, default_compression_config)
        return F.linear(input, weight, self.bias)


def compress_module(module):
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            setattr(module, attr_str, CLinear(target_attr.weight, target_attr.bias))
    for name, child in module.named_children():
        compress_module(child)


def compress(tensor, config):
    """Simulate group-wise quantization."""
    if not config.enabled:
        return tensor

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
                 original_shape[group_dim+1:])

    # Pad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
        tensor = torch.cat([
            tensor,
            torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
            dim=group_dim)
    data = tensor.view(new_shape)

    # Quantize
    if symmetric:
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = data * scale
        data = data.clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape
    else:
        B = 2 ** num_bits - 1
        mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
        mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

        scale = B / (mx - mn)
        data = data - mn
        data.mul_(scale)

        data = data.clamp_(0, B).round_().to(torch.uint8)
        return data, mn, scale, original_shape


def decompress(packed_data, config):
    """Simulate group-wise dequantization."""
    if not config.enabled:
        return packed_data

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)

    # Dequantize
    if symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Unpad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len:
        padded_original_shape = (
            original_shape[:group_dim] +
            (original_shape[group_dim] + pad_len,) +
            original_shape[group_dim+1:])
        data = data.reshape(padded_original_shape)
        indices = [slice(0, x) for x in original_shape]
        return data[indices].contiguous()
    else:
        return data.view(original_shape)
```
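A quick roundtrip of `compress` and `decompress` with the default config (an illustrative check, not part of the commit): 8 bits, symmetric, groups of 256 along dim 1, so a 1024 x 4096 matrix is split into 16 groups per row and reconstructed to within about half a quantization step of each group.

```python
# Illustrative roundtrip, not part of the commit: quantize a random matrix with
# the default config and measure the reconstruction error.
import torch

from fastchat.serve.compression import (
    compress, decompress, default_compression_config)

w = torch.randn(1024, 4096)                       # dim 1 splits into 16 groups of 256
packed = compress(w, default_compression_config)  # (int8 data, per-group scale, original shape)
w_hat = decompress(packed, default_compression_config)

print(packed[0].dtype, packed[0].shape)           # torch.int8 torch.Size([1024, 16, 256])
print((w - w_hat).abs().max())                    # about (largest magnitude in a group) / 254
```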

pyproject.toml (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "fschat"
-version = "0.1.6"
+version = "0.1.7"
 description = "An open platform for training, serving, and evaluating large language model based chatbots."
 readme = "README.md"
 requires-python = ">=3.8"
