Commit 460c188

[Bugfix] Support cpu offloading with fp8 quantization (vllm-project#6960)
1 parent bd70013 · commit 460c188
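
This change makes --cpu-offload-gb usable together with fp8 quantization (and with other quant methods that repack or re-quantize weights after loading). As a hedged usage sketch of the combination, mirroring the flags exercised by the new tests below; the Python keyword names are assumed to correspond to the CLI flags --quantization and --cpu-offload-gb:

# Hedged sketch: fp8 quantization combined with CPU offloading of weights.
# Assumes the LLM constructor accepts `quantization` and `cpu_offload_gb`
# keywords matching the CLI flags used by the tests in this commit.
from vllm import LLM

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
          quantization="fp8",   # quantize an unquantized checkpoint at load time
          cpu_offload_gb=2)     # keep ~2 GiB of weights on the CPU
print(llm.generate("Hello, my name is")[0].outputs[0].text)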

File tree: 3 files changed, +116 −33 lines
Lines changed: 37 additions & 6 deletions
@@ -1,13 +1,44 @@
-from vllm.utils import is_hip
+import pytest
+
+from tests.quantization.utils import is_quant_method_supported
 
 from ..utils import compare_two_settings
 
 
 def test_cpu_offload():
     compare_two_settings("meta-llama/Llama-2-7b-hf", [],
                          ["--cpu-offload-gb", "4"])
-    if not is_hip():
-        # compressed-tensors quantization is currently not supported in ROCm.
-        compare_two_settings(
-            "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
-            ["--cpu-offload-gb", "1"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="fp8 is not supported on this GPU type.")
+def test_cpu_offload_fp8():
+    # Test quantization of an unquantized checkpoint
+    compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct",
+                         ["--quantization", "fp8"],
+                         ["--quantization", "fp8", "--cpu-offload-gb", "2"])
+    # Test loading a quantized checkpoint
+    compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [],
+                         ["--cpu-offload-gb", "2"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("awq"),
+                    reason="awq is not supported on this GPU type.")
+def test_cpu_offload_awq():
+    compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [],
+                         ["--cpu-offload-gb", "2"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
+                    reason="gptq_marlin is not supported on this GPU type.")
+def test_cpu_offload_compressed_tensors():
+    # Test wNa16
+    compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [],
+                         ["--cpu-offload-gb", "1"])
+    # Test w4a16_marlin24
+    compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
+                         [], ["--cpu-offload-gb", "1"])
+    # Test w8a8
+    compare_two_settings(
+        "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [],
+        ["--cpu-offload-gb", "1"])

vllm/model_executor/model_loader/loader.py

Lines changed: 53 additions & 3 deletions
@@ -7,6 +7,7 @@
 import math
 import os
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from typing import Any, Dict, Generator, List, Optional, Tuple, Type
 
 import huggingface_hub
@@ -37,7 +38,49 @@
                          supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_tpu
+from vllm.utils import is_pin_memory_available, is_tpu
+
+
+@contextmanager
+def device_loading_context(module: torch.nn.Module,
+                           target_device: torch.device):
+    if target_device.type == "cpu":
+        # If target is CPU, no need to move anything
+        yield module
+        return
+
+    original_device_states: Dict[str, torch.device] = {}
+
+    # Store original device states and move parameters to GPU if they're on CPU
+    for name, p in module.named_parameters():
+        if p.device.type == "cpu":
+            original_device_states[name] = p.device
+            p.data = p.data.to(target_device)
+        # Parameters already on target device are not touched
+
+    try:
+        yield module
+
+    finally:
+        # Restore parameters to their original devices, ignoring new parameters
+        pin_memory = is_pin_memory_available()
+        for name, p in module.named_parameters():
+            if name in original_device_states:
+                original_device: torch.device = original_device_states[name]
+                if original_device.type == "cpu":
+                    # `torch.empty_like` does not support `pin_memory` argument
+                    cpu_data = torch.empty_strided(size=p.data.size(),
                                                   stride=p.data.stride(),
                                                   dtype=p.data.dtype,
                                                   layout=p.data.layout,
                                                   device="cpu",
                                                   pin_memory=pin_memory)
+                    cpu_data.copy_(p.data)
+                    p.data = cpu_data
+                else:
+                    p.data = p.data.to(original_device)
+            # New parameters or parameters already on target device are untouched
+
 
 logger = init_logger(__name__)
 
@@ -275,8 +318,9 @@ def load_model(self, *, model_config: ModelConfig,
                    parallel_config: ParallelConfig,
                    scheduler_config: SchedulerConfig,
                    cache_config: CacheConfig) -> nn.Module:
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, multimodal_config,
                                           cache_config, scheduler_config)
@@ -291,7 +335,13 @@ def load_model(self, *, model_config: ModelConfig,
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
                 if quant_method is not None:
-                    quant_method.process_weights_after_loading(module)
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
         return model.eval()
 
 
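
To illustrate the new context manager, here is a small, hedged usage sketch (assuming a CUDA device is available; the Linear layer is a stand-in for a CPU-offloaded module): parameters that live on the CPU are moved onto the target device only for the duration of the block, then copied back into (optionally pinned) CPU buffers.

# Hedged usage sketch of device_loading_context from the diff above.
import torch

from vllm.model_executor.model_loader.loader import device_loading_context

layer = torch.nn.Linear(16, 16, device="cpu")   # toy stand-in for an offloaded module
target = torch.device("cuda:0")

with device_loading_context(layer, target):
    # Inside the context the weights are on the GPU, so quantization
    # post-processing (e.g. fp8 repacking) can run on device tensors.
    assert all(p.device == target for p in layer.parameters())

# Afterwards the weights are back in CPU memory.
assert all(p.device.type == "cpu" for p in layer.parameters())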

vllm/model_executor/models/utils.py

Lines changed: 26 additions & 24 deletions
@@ -87,42 +87,44 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
 
     # offload parameters to CPU
     # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
     for p in module.parameters():
         if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
             # we use per-parameter offloading
             # one module might have some parameters offloaded and some not
             break
 
         # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty(size=p.data.size(),
-                               dtype=p.data.dtype,
-                               layout=p.data.layout,
-                               device='cpu',
-                               pin_memory=pin_memory)
+        cpu_data = torch.empty_strided(size=p.data.size(),
+                                       stride=p.data.stride(),
+                                       dtype=p.data.dtype,
+                                       layout=p.data.layout,
+                                       device='cpu',
+                                       pin_memory=pin_memory)
         cpu_data.copy_(p.data)
         p.data = cpu_data
         _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module,
+                                     device_state,
+                                     args=args,
+                                     kwargs=kwargs)
+            module.forward = forward
+            return output
 
-    state_dict: Dict[str, torch.Tensor] = module.state_dict()
-
-    original_forward = module.forward
-
-    def forward(*args, **kwargs):
-        module.forward = original_forward
-        device_state = {
-            # here we blindly call `to(device)`
-            # if the parameter is already on the device, it will be a no-op
-            k: v.to(device, non_blocking=True)
-            for k, v in state_dict.items()
-        }
-        output = functional_call(module,
-                                 device_state,
-                                 args=args,
-                                 kwargs=kwargs)
         module.forward = forward
-        return output
-
-    module.forward = forward
 
     return module
 
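
The rewritten forward wrapper above is the heart of the offloading scheme: master copies of the weights stay on the CPU, and each call materializes a temporary on-device copy through torch.func.functional_call. Below is a condensed, hedged standalone sketch of that pattern; it omits the global offload-budget bookkeeping and pinned-memory allocation of the real maybe_offload_to_cpu, and offload_to_cpu is a hypothetical stand-in name.

# Hedged, simplified sketch of the per-call offloading pattern used in
# maybe_offload_to_cpu: parameters are kept on the CPU, and functional_call
# runs the module with a temporary on-device copy of its state_dict.
import torch
from torch.func import functional_call

def offload_to_cpu(module: torch.nn.Module,
                   device: torch.device) -> torch.nn.Module:
    # Keep the master copy of every parameter on the CPU.
    for p in module.parameters():
        p.data = p.data.to("cpu")

    original_forward = module.forward

    def forward(*args, **kwargs):
        module.forward = original_forward       # avoid recursing into the wrapper
        device_state = {
            k: v.to(device, non_blocking=True)  # no-op for tensors already on device
            for k, v in module.state_dict().items()
        }
        output = functional_call(module, device_state, args=args, kwargs=kwargs)
        module.forward = forward                # re-install the wrapper
        return output

    module.forward = forward
    return module

# Usage (assumes a CUDA device):
layer = offload_to_cpu(torch.nn.Linear(16, 16), torch.device("cuda:0"))
out = layer(torch.randn(4, 16, device="cuda:0"))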
