Commit 9855b99

[Feature][kernel] tensor parallelism with bitsandbytes quantization (vllm-project#8434)
1 parent 1009e93 commit 9855b99

File tree: 4 files changed (+80 -17 lines)

tests/quantization/test_bitsandbytes.py  (+21 -5)

@@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name)
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason='Test requires at least 2 GPUs.')
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@fork_new_process_for_each_test
+def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                                model_name, description) -> None:
+
+    hf_model_kwargs = {"load_in_4bit": True}
+    validate_generated_texts(hf_runner,
+                             vllm_runner,
+                             example_prompts[:1],
+                             model_name,
+                             hf_model_kwargs,
+                             vllm_tp_size=2)
+
+
 def log_generated_texts(prompts, outputs, runner_name):
     logged_texts = []
     for i, (_, generated_text) in enumerate(outputs):
@@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner,
                              vllm_runner,
                              prompts,
                              model_name,
-                             hf_model_kwargs=None):
+                             hf_model_kwargs=None,
+                             vllm_tp_size=1):
 
     # NOTE: run vLLM first, as it requires a clean process
     # when using distributed inference
-
-    #Run with vLLM runner
     with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
+                     tensor_parallel_size=vllm_tp_size,
                      enforce_eager=True,
                      gpu_memory_utilization=0.8) as llm:
         vllm_outputs = llm.generate_greedy(prompts, 8)
         vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
 
     # Clean up the GPU memory for the next test
-    torch.cuda.synchronize()
     gc.collect()
     torch.cuda.empty_cache()
 
@@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner,
     hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
 
     # Clean up the GPU memory for the next test
-    torch.cuda.synchronize()
     gc.collect()
     torch.cuda.empty_cache()
 
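
For context: the new test exercises the same path an end user takes when combining bitsandbytes quantization with tensor parallelism, which this commit enables. A minimal usage sketch under the assumption of two visible GPUs; the checkpoint name is only an illustration, not something the commit prescribes:

# Usage sketch (not part of the commit): 4-bit bitsandbytes with TP=2.
from vllm import LLM, SamplingParams

llm = LLM(model="huggyllama/llama-7b",      # example checkpoint
          quantization="bitsandbytes",
          load_format="bitsandbytes",
          tensor_parallel_size=2,           # newly allowed by this change
          enforce_eager=True,               # still required for bnb (see vllm/config.py)
          gpu_memory_utilization=0.8)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)

On a machine with at least 2 GPUs, the new test itself runs with: pytest tests/quantization/test_bitsandbytes.py -k test_load_tp_4bit_bnb_model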

vllm/config.py  (-6)

@@ -393,12 +393,6 @@ def verify_with_parallel_config(
                 "Pipeline parallelism is only supported for the following "
                 f" architectures: {_PP_SUPPORTED_MODELS}.")
 
-        if self.quantization == "bitsandbytes" and (
-                parallel_config.tensor_parallel_size > 1
-                or parallel_config.pipeline_parallel_size > 1):
-            raise ValueError(
-                "BitAndBytes quantization with TP or PP is not supported yet.")
-
         # Remove the constraint after the bitsandbytes issue is fixed:
         # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
         if self.quantization == "bitsandbytes" and self.enforce_eager is False:

vllm/model_executor/layers/linear.py  (+16 -5)

@@ -530,8 +530,11 @@ def weight_loader(self,
                 param_data = param_data.narrow(output_dim, shard_offset,
                                                shard_size)
             start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -899,8 +902,13 @@ def weight_loader(self,
             else:
                 shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+
         # Special case for for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1000,6 +1008,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1015,7 +1024,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
-        if input_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[input_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
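
Why the narrow is skipped on the bitsandbytes path: with ordinary TP loading every rank receives the full checkpoint tensor and carves out its shard with narrow, whereas the bitsandbytes 4-bit loader (see loader.py below) already hands weight_loader only the local shard, so narrowing again would slice a second time and run past the tensor's end. A toy sketch of the difference, with invented shapes and no vLLM dependencies:

# Toy illustration only; shapes and rank values are invented.
import torch

tp_size, tp_rank = 2, 1
full_weight = torch.randn(4096, 4096)          # full checkpoint tensor
shard_size = full_weight.shape[0] // tp_size   # 2048 rows per rank

# Ordinary TP path: every rank sees the full tensor and narrows it.
plain_shard = full_weight.narrow(0, tp_rank * shard_size, shard_size)
assert plain_shard.shape == (shard_size, 4096)

# bitsandbytes 4-bit path: the loader already sliced out this rank's rows,
# so weight_loader must not narrow a second time.
bnb_shard = full_weight[tp_rank * shard_size:(tp_rank + 1) * shard_size, :]
assert bnb_shard.shape == (shard_size, 4096)
# bnb_shard.narrow(0, tp_rank * shard_size, shard_size)  # would raise:
# only shard_size rows exist here, so that offset points past the end.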

vllm/model_executor/model_loader/loader.py  (+43 -1)

@@ -22,6 +22,8 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
                          ParallelConfig, SchedulerConfig)
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
@@ -689,6 +691,8 @@ def save_model(
 class BitsAndBytesModelLoader(BaseModelLoader):
     """Model loader to load model weights with BitAndBytes quantization."""
 
+    # TODO: these module names are for Llama only,
+    # change so that it works with other models as well
     default_target_modules = [
         "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
         "o_proj"
@@ -911,13 +915,44 @@ def _parse_quant_state(param_name: str,
     def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
             if any(target_module in weight_name
                    for target_module in self.target_modules):
                 weight_name = weight_name.replace(".weight", ".qweight")
+
+                # weight partitions of different modules occur at
+                # different dimensions
+                # TODO: these module names are for Llama only,
+                # change so that it works with other models as well
+                if 'down_proj' in weight_name or 'o_proj' in weight_name:
+                    total_size = weight_tensor.size(-1)
+                    start_index = total_size // tp_size * tp_rank
+                    end_index = total_size // tp_size * (tp_rank + 1)
+                    weight_sub_tensor = weight_tensor[...,
+                                                      start_index:end_index]
+
+                else:
+                    total_size = weight_tensor.size(0)
+                    start_index = total_size // tp_size * tp_rank
+                    end_index = total_size // tp_size * (tp_rank + 1)
+                    weight_sub_tensor = weight_tensor[start_index:end_index,
+                                                      ...]
+
                 # bitsandbytes requires data in GPU
-                loaded_weight = weight_tensor.cuda().data
+                if weight_sub_tensor.is_cuda:
+                    loaded_weight = weight_sub_tensor
+                else:
+                    loaded_weight = weight_sub_tensor.cuda()
+
+                # remove the following after the issue is fixed:
+                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
+                if loaded_weight.is_contiguous() is False:
+                    loaded_weight = loaded_weight.contiguous()
+
                 with set_default_torch_dtype(torch.float32):
                     processed_weight, quant_state = quantize_4bit(
                         loaded_weight,
@@ -958,6 +993,13 @@ def _load_weights(self, model_config: ModelConfig,
                 f"BitsAndBytes loader does not support {quant_method} "
                 "quantization")
 
+        # The quant_states in pre_quantized models cannot work with a split
+        # weight tensor. So TP does not work with pre_quantized bnb models.
+        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+            raise ValueError(
+                "Prequant BitsAndBytes models with TP is not supported."
+                "Please try with PP.")
+
         load_8bit = False
         if pre_quant:
             load_8bit = quant_config.get('load_in_8bit', False)
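
The heart of the loader change is the per-rank slice computed before quantize_4bit runs: row-parallel modules (down_proj, o_proj) are split along the last (input) dimension, everything else along dim 0 (output), each rank taking a slice of width total_size // tp_size. A standalone sketch of that arithmetic; the helper name and shapes are illustrative, not part of the commit:

# Hypothetical helper mirroring _unquantized_generator's slicing logic.
import torch

def shard_for_rank(weight_name: str, weight_tensor: torch.Tensor,
                   tp_size: int, tp_rank: int) -> torch.Tensor:
    """Return the slice of weight_tensor owned by tp_rank."""
    if 'down_proj' in weight_name or 'o_proj' in weight_name:
        # Row-parallel layers: split along the input (last) dimension.
        total_size = weight_tensor.size(-1)
        start = total_size // tp_size * tp_rank
        end = total_size // tp_size * (tp_rank + 1)
        return weight_tensor[..., start:end]
    # Column-parallel layers (q/k/v/gate/up projections): split along dim 0.
    total_size = weight_tensor.size(0)
    start = total_size // tp_size * tp_rank
    end = total_size // tp_size * (tp_rank + 1)
    return weight_tensor[start:end, ...]

w = torch.randn(11008, 4096)                   # e.g. a gate_proj weight
shard = shard_for_rank("model.layers.0.mlp.gate_proj.weight", w, 2, 0)
assert shard.shape == (5504, 4096)

Slicing along the last dimension produces a non-contiguous view, which is why the diff also makes loaded_weight contiguous before calling quantize_4bit, as a workaround for bitsandbytes issue 1342. Pre-quantized bitsandbytes checkpoints are excluded from TP entirely (the new ValueError above), since their stored quant_states describe the whole tensor and cannot be applied to a slice.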
