Skip to content

Commit bd70013

Browse files
comaniacyoukaichao
andauthored
[MISC] Introduce pipeline parallelism partition strategies (vllm-project#6920)
Co-authored-by: youkaichao <[email protected]>
1 parent 2ee8d3b commit bd70013

File tree

3 files changed

+66
-5
lines changed

3 files changed

+66
-5
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
2+
3+
import pytest
4+
5+
from vllm.distributed.utils import get_pp_indices
6+
7+
8+
def test_custom_layer_partition():
9+
10+
def _verify(partition_str, num_layers, pp_size, goldens):
11+
bak = os.environ.get("VLLM_PP_LAYER_PARTITION", None)
12+
os.environ["VLLM_PP_LAYER_PARTITION"] = partition_str
13+
for pp_rank, golden in enumerate(goldens):
14+
assert get_pp_indices(num_layers, pp_rank, pp_size) == golden
15+
if bak is not None:
16+
os.environ["VLLM_PP_LAYER_PARTITION"] = bak
17+
18+
# Even partition
19+
_verify("5,5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
20+
# Balanced partition
21+
_verify("4,6,6,4", 20, 4, [(0, 4), (4, 10), (10, 16), (16, 20)])
22+
# Put reminder somewhere
23+
_verify("5,6,5,6", 22, 4, [(0, 5), (5, 11), (11, 16), (16, 22)])
24+
# Invalid partition strings
25+
with pytest.raises(ValueError):
26+
_verify("5,5,5,5,", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
27+
with pytest.raises(ValueError):
28+
_verify("5,5,5,a", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
29+
# Wrong number of partitions
30+
with pytest.raises(ValueError):
31+
_verify("5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
32+
# Wrong number of layers
33+
with pytest.raises(ValueError):
34+
_verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])

vllm/distributed/utils.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66

77
import torch
88

9+
import vllm.envs as envs
10+
from vllm.logger import init_logger
11+
12+
logger = init_logger(__name__)
13+
914

1015
def ensure_divisibility(numerator, denominator):
1116
"""Ensure that numerator is divisible by the denominator."""
@@ -54,11 +59,28 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
5459
If the number of layers is not divisible by the number of partitions,
5560
the last partition will have the remaining layers.
5661
"""
57-
layers_per_partition = num_hidden_layers // pp_size
58-
start_layer = pp_rank * layers_per_partition
59-
end_layer = start_layer + layers_per_partition
62+
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
63+
if partition_list_str is not None:
64+
try:
65+
partitions = [
66+
int(layer) for layer in partition_list_str.split(",")
67+
]
68+
except ValueError as err:
69+
raise ValueError("Invalid partition string: {}".format(
70+
partition_list_str)) from err
71+
if len(partitions) != pp_size:
72+
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
73+
if sum(partitions) != num_hidden_layers:
74+
raise ValueError(
75+
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
76+
start_layer = sum(partitions[:pp_rank])
77+
end_layer = start_layer + partitions[pp_rank]
78+
else:
79+
layers_per_partition = num_hidden_layers // pp_size
80+
start_layer = pp_rank * layers_per_partition
81+
end_layer = start_layer + layers_per_partition
6082

61-
if pp_rank == pp_size - 1:
62-
end_layer = num_hidden_layers
83+
if pp_rank == pp_size - 1:
84+
end_layer = num_hidden_layers
6385

6486
return (start_layer, end_layer)

vllm/envs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
2929
VLLM_TRACE_FUNCTION: int = 0
3030
VLLM_ATTENTION_BACKEND: Optional[str] = None
31+
VLLM_PP_LAYER_PARTITION: Optional[str] = None
3132
VLLM_CPU_KVCACHE_SPACE: int = 0
3233
VLLM_CPU_OMP_THREADS_BIND: str = ""
3334
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
@@ -242,6 +243,10 @@ def get_default_config_root():
242243
"VLLM_ATTENTION_BACKEND":
243244
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
244245

246+
# Pipeline stage partition strategy
247+
"VLLM_PP_LAYER_PARTITION":
248+
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
249+
245250
# (CPU backend only) CPU key-value cache space.
246251
# default is 4GB
247252
"VLLM_CPU_KVCACHE_SPACE":

0 commit comments

Comments
 (0)