
Commit 99aa4ed

[torch.compile] register allreduce operations as custom ops (vllm-project#8526)
1 parent ee2bcea commit 99aa4ed

File tree: 9 files changed (+137, -50 lines)


.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 7 deletions
@@ -163,13 +163,6 @@ steps:
   - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
   - python3 offline_inference_encoder_decoder.py
 
-- label: torch compile integration test
-  source_file_dependencies:
-  - vllm/
-  commands:
-  - pytest -v -s ./compile/test_full_graph.py
-  - pytest -v -s ./compile/test_wrapper.py
-
 - label: Prefix Caching Test # 7min
   #mirror_hardwares: [amd]
   source_file_dependencies:
@@ -348,7 +341,10 @@ steps:
   - vllm/executor/
   - vllm/model_executor/models/
   - tests/distributed/
+  - vllm/compilation
   commands:
+  - pytest -v -s ./compile/test_full_graph.py
+  - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
   - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
   # Avoid importing model tests that cause CUDA reinitialization error

csrc/custom_all_reduce.cu

Lines changed: 0 additions & 12 deletions
@@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
           t.numel() * t.element_size());
 }
 
-bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
-                      bool full_nvlink) {
-  auto inp_size = inp.numel() * inp.element_size();
-  // custom allreduce requires input byte size to be multiples of 16
-  if (inp_size % 16 != 0) return false;
-  if (!_is_weak_contiguous(inp)) return false;
-  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
-  // for 4 or more non NVLink-capable GPUs, custom allreduce provides little
-  // performance improvement over NCCL.
-  return false;
-}
-
 void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                  cudaStream_t stream) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);

csrc/ops.h

Lines changed: 0 additions & 2 deletions
@@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                       const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, int64_t rank,
                       bool full_nvlink);
-bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
-                      bool full_nvlink);
 void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                       torch::Tensor& out);

csrc/torch_bindings.cpp

Lines changed: 0 additions & 5 deletions
@@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
       "bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
 
-  custom_ar.def(
-      "should_custom_ar(Tensor inp, int max_size, int world_size, "
-      "bool full_nvlink) -> bool");
-  custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
-
   custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
   custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
 

tests/compile/__init__.py

Whitespace-only changes.

tests/compile/test_full_graph.py

Lines changed: 13 additions & 2 deletions
@@ -2,9 +2,20 @@
 
 import pytest
 
+from vllm.utils import cuda_device_count_stateless
+
+from ..utils import fork_new_process_for_each_test
+
 
 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_full_graph(model):
+@pytest.mark.parametrize("tp_size", [1, 2])
+@fork_new_process_for_each_test
+def test_full_graph(model, tp_size):
+
+    # Skip the test if there are not enough CUDA devices.
+    if cuda_device_count_stateless() < tp_size:
+        pytest.skip("Not enough CUDA devices for the test.")
+
     # make sure these models can be captured in full graph mode
     if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
         os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
@@ -17,7 +28,7 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model=model, enforce_eager=True)
+    llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size)
 
     outputs = llm.generate(prompts, sampling_params)
 
vllm/_custom_ops.py

Lines changed: 0 additions & 6 deletions
@@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                                                  offsets, rank, full_nvlink)
 
 
-def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
-                     full_nvlink: bool) -> bool:
-    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
-                                                   full_nvlink)
-
-
 def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
     torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
 

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 19 additions & 2 deletions
@@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
     return True
 
 
+def is_weak_contiguous(inp: torch.Tensor):
+    return inp.is_contiguous() or (inp.storage().nbytes() -
+                                   inp.storage_offset() * inp.element_size()
+                                   == inp.numel() * inp.element_size())
+
+
 class CustomAllreduce:
 
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
@@ -224,8 +230,19 @@ def register_graph_buffers(self):
         ops.register_graph_buffers(self._ptr, handles, offsets)
 
     def should_custom_ar(self, inp: torch.Tensor):
-        return ops.should_custom_ar(inp, self.max_size, self.world_size,
-                                    self.full_nvlink)
+        if self.disabled:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom allreduce requires input byte size to be multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
+        # little performance improvement over NCCL.
+        if self.world_size == 2 or self.full_nvlink:
+            return inp_size < self.max_size
+        return False
 
     # all reduce, assuming inp tensor is IPC registered with register_buffer,
     # or, in the context of cuda graphs, register_graph_buffers
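A note on the relaxed contiguity test: `is_weak_contiguous` accepts a tensor that is either contiguous or that densely occupies the tail of its storage with no trailing gap, which is what the custom allreduce kernel needs. A small illustration (not part of the commit, assuming a recent PyTorch) — a transposed view fails `is_contiguous()` but still satisfies the weaker check:

```python
import torch


def is_weak_contiguous(inp: torch.Tensor) -> bool:
    # Same criterion as the helper added above: either contiguous, or the
    # bytes from the storage offset to the end of the storage exactly match
    # the tensor's own byte size (no trailing gap in the allocation).
    return inp.is_contiguous() or (inp.storage().nbytes() -
                                   inp.storage_offset() * inp.element_size()
                                   == inp.numel() * inp.element_size())


x = torch.randn(4, 8)
print(x.t().is_contiguous())      # False: the transpose is just a strided view
print(is_weak_contiguous(x.t()))  # True: it still spans its storage densely
```

The rest of the Python-side `should_custom_ar` mirrors the deleted C++ logic: only 16-byte-multiple input sizes qualify, and NCCL remains the default for 4+ GPUs without full NVLink.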

vllm/distributed/parallel_state.py

Lines changed: 102 additions & 14 deletions
@@ -21,11 +21,12 @@
 """
 import contextlib
 import pickle
+import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
 
 import torch
@@ -69,6 +70,58 @@ def _split_tensor_dict(
     return metadata_list, tensor_list
 
 
+_group_name_counter: Dict[str, int] = {}
+
+
+def _get_unique_name(name: str) -> str:
+    """Get a unique name for the group.
+    Example:
+    _get_unique_name("tp") -> "tp:0"
+    _get_unique_name("tp") -> "tp:1"
+    """
+    if name not in _group_name_counter:
+        _group_name_counter[name] = 0
+    newname = f"{name}:{_group_name_counter[name]}"
+    _group_name_counter[name] += 1
+    return newname
+
+
+_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}
+
+
+def _register_group(group: "GroupCoordinator") -> None:
+    # looks like Python 3.8 does not understand `ReferenceType`
+    _groups[group.unique_name] = weakref.ref(group)  # type: ignore
+
+
+@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
+def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    group._all_reduce(tensor)
+
+
+@inplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> None:
+    return
+
+
+@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce(tensor)
+
+
+@outplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
@@ -111,7 +164,11 @@ def __init__(
         use_custom_allreduce: bool,
         use_tpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
     ):
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)
 
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
@@ -149,28 +206,24 @@ def __init__(
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
 
-        self.pynccl_comm: Optional[PyNcclCommunicator]
+        self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
             )
-        else:
-            self.pynccl_comm = None
 
-        self.ca_comm: Optional[CustomAllreduce]
+        self.ca_comm: Optional[CustomAllreduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
                 group=self.cpu_group,
                 device=self.device,
            )
-        else:
-            self.ca_comm = None
 
         from vllm.distributed.device_communicators.tpu_communicator import (
             TpuCommunicator)
-        self.tpu_communicator: Optional[TpuCommunicator]
+        self.tpu_communicator: Optional[TpuCommunicator] = None
        if use_tpu_communicator and self.world_size > 1:
             self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
 
@@ -264,16 +317,46 @@ def graph_capture(
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
+        User-facing all-reduce function before we actually call the
+        all-reduce operation.
+
+        We need this because Dynamo does not support passing an arbitrary
+        object (`self` in this case) to a custom op. We need to pass the
+        group name as a string, and then look up the group coordinator from
+        the group name, dispatch the all-reduce operation to the group
+        coordinator.
+
+        In addition, PyTorch custom ops do not support mutation or returning
+        a new tensor in the same op. So we need to figure out if the op is
+        in-place or out-of-place ahead of time.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+
+        if self.tpu_communicator is not None and \
+            not self.tpu_communicator.disabled:
+            # TPU handles Dynamo with its own logic.
+            return self._all_reduce(input_)
+
+        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+            return torch.ops.vllm.outplace_all_reduce(
+                input_, group_name=self.unique_name)
+        else:
+            torch.ops.vllm.inplace_all_reduce(input_,
+                                              group_name=self.unique_name)
+            return input_
+
+    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        The actual all-reduce implementation.
+
         NOTE: This operation will be applied in-place or out-of-place.
         Always assume this function modifies its input, but use the return
         value as the output.
         """
         ca_comm = self.ca_comm
 
-        # Bypass the function if we are using only 1 GPU.
-        if self.world_size == 1:
-            return input_
-
         # For TPUs, use TPU communicator.
         tpu_comm = self.tpu_communicator
         if tpu_comm is not None and not tpu_comm.disabled:
@@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         use_pynccl=False,
         use_custom_allreduce=False,
         use_tpu_communicator=False,
+        group_name="world",
     )
 
 
@@ -767,6 +851,7 @@ def init_model_parallel_group(
     backend: str,
     use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
+    group_name: Optional[str] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -778,6 +863,7 @@
         use_custom_allreduce=use_custom_allreduce,
         use_tpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
+        group_name=group_name,
     )
 
 
@@ -931,7 +1017,8 @@ def initialize_model_parallel(
     _TP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_message_queue_broadcaster=True)
+                                    use_message_queue_broadcaster=True,
+                                    group_name="tp")
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
@@ -947,7 +1034,8 @@
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_custom_allreduce=False)
+                                    use_custom_allreduce=False,
+                                    group_name="pp")
 
 
 def ensure_model_parallel_initialized(
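Why this shape: Dynamo cannot pass an arbitrary Python object such as `self` or a process group into a custom op, so the commit keys each `GroupCoordinator` by a unique string name and resolves it through a weakref registry inside the op, and the `register_fake` implementations only describe output metadata so `torch.compile` can trace the surrounding graph while treating the collective as a single opaque node. The sketch below reproduces that structure outside vLLM; it is illustrative only, assumes PyTorch 2.4+ for `torch.library.custom_op`, and the `demo::outplace_scale` op, `_STATE` registry, and `fn` wrapper are made-up names, not part of the commit.

```python
import torch

# Stands in for the weakref-based _groups registry: opaque state that the
# compiled graph never sees directly, only looked up by name at runtime.
_STATE = {"demo": 2.0}


@torch.library.custom_op("demo::outplace_scale", mutates_args=[])
def outplace_scale(tensor: torch.Tensor, key: str) -> torch.Tensor:
    # Real implementation: resolve the opaque state by its string key,
    # since Dynamo cannot pass arbitrary Python objects into a custom op.
    return tensor * _STATE[key]


@outplace_scale.register_fake
def _(tensor: torch.Tensor, key: str) -> torch.Tensor:
    # Fake (meta) implementation: only reports the output's shape/dtype so
    # the compiler can trace through the op without executing it.
    return torch.empty_like(tensor)


@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    # The custom op appears as one opaque node in the captured graph, so no
    # graph break occurs even though its body is untraceable Python.
    return torch.ops.demo.outplace_scale(x, "demo") + 1


print(fn(torch.ones(4)))  # tensor([3., 3., 3., 3.])
```

vLLM's real ops follow the same split: `vllm::inplace_all_reduce` mutates its input and returns nothing, while `vllm::outplace_all_reduce` returns a fresh tensor, which is why `GroupCoordinator.all_reduce` decides between them before dispatching.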
