"""
import contextlib
import pickle
+import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
@@ -69,6 +70,58 @@ def _split_tensor_dict(
    return metadata_list, tensor_list


+_group_name_counter: Dict[str, int] = {}
+
+
+def _get_unique_name(name: str) -> str:
+    """Get a unique name for the group.
+    Example:
+    _get_unique_name("tp") -> "tp:0"
+    _get_unique_name("tp") -> "tp:1"
+    """
+    if name not in _group_name_counter:
+        _group_name_counter[name] = 0
+    newname = f"{name}:{_group_name_counter[name]}"
+    _group_name_counter[name] += 1
+    return newname
+
+
+_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}
+
+
+def _register_group(group: "GroupCoordinator") -> None:
+    # looks like Python 3.8 does not understand `ReferenceType`
+    _groups[group.unique_name] = weakref.ref(group)  # type: ignore
+
+
+@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
+def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    group._all_reduce(tensor)
+
+
+@inplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> None:
+    return
+
+
+@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce(tensor)
+
+
+@outplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
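
The `register_fake` registrations in this hunk are what let `torch.compile` trace the collectives: the real implementation runs eagerly, while the fake (meta) implementation only describes the op's outputs and mutation behavior. Below is a minimal, self-contained sketch of the same pattern on a toy op; it is not part of this change, assumes PyTorch 2.4+ where `torch.library.custom_op` is available, and the `demo::scale_inplace` name is made up for illustration.

```python
import torch


# Toy in-place custom op, mirroring vllm::inplace_all_reduce: it mutates its
# argument and returns nothing, so it is declared with mutates_args.
@torch.library.custom_op("demo::scale_inplace", mutates_args=["tensor"])
def scale_inplace(tensor: torch.Tensor, factor: float) -> None:
    tensor.mul_(factor)


# Fake (meta) implementation: does no work, only tells the compiler that the
# op returns nothing and mutates its input.
@scale_inplace.register_fake
def _(tensor: torch.Tensor, factor: float) -> None:
    return


@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    # The op is traced through its fake implementation, then executed eagerly.
    torch.ops.demo.scale_inplace(x, 2.0)
    return x + 1


print(fn(torch.ones(4)))  # tensor([3., 3., 3., 3.])
```
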
@@ -111,7 +164,11 @@ def __init__(
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
    ):
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
@@ -149,28 +206,24 @@ def __init__(
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

-        self.pynccl_comm: Optional[PyNcclCommunicator]
+        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )
-        else:
-            self.pynccl_comm = None

-        self.ca_comm: Optional[CustomAllreduce]
+        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )
-        else:
-            self.ca_comm = None

        from vllm.distributed.device_communicators.tpu_communicator import (
            TpuCommunicator)
-        self.tpu_communicator: Optional[TpuCommunicator]
+        self.tpu_communicator: Optional[TpuCommunicator] = None
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
@@ -264,16 +317,46 @@ def graph_capture(

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
+        User-facing all-reduce function before we actually call the
+        all-reduce operation.
+
+        We need this because Dynamo does not support passing an arbitrary
+        object (`self` in this case) to a custom op. We need to pass the
+        group name as a string, then look up the group coordinator by that
+        name and dispatch the all-reduce operation to it.
+
+        In addition, PyTorch custom ops do not support mutation or returning
+        a new tensor in the same op. So we need to figure out if the op is
+        in-place or out-of-place ahead of time.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+
+        if self.tpu_communicator is not None and \
+            not self.tpu_communicator.disabled:
+            # TPU handles Dynamo with its own logic.
+            return self._all_reduce(input_)
+
+        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+            return torch.ops.vllm.outplace_all_reduce(
+                input_, group_name=self.unique_name)
+        else:
+            torch.ops.vllm.inplace_all_reduce(input_,
+                                              group_name=self.unique_name)
+            return input_
+
+    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        The actual all-reduce implementation.
+
        NOTE: This operation will be applied in-place or out-of-place.
        Always assume this function modifies its input, but use the return
        value as the output.
        """
        ca_comm = self.ca_comm

-        # Bypass the function if we are using only 1 GPU.
-        if self.world_size == 1:
-            return input_
-
        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
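
As the new `all_reduce` docstring explains, Dynamo cannot pass `self` into a custom op, so only the group's string name crosses the op boundary and the coordinator is recovered from the module-level `_groups` registry of weak references. The standalone sketch below (hypothetical `Coordinator` and `dispatch_all_reduce` names, not vLLM APIs) shows that lookup-and-dispatch pattern, including why a garbage-collected group surfaces as a "destroyed" error instead of being kept alive by the registry.

```python
import weakref
from typing import Callable, Dict

# String-keyed registry of weak references, analogous to _groups above.
_registry: Dict[str, Callable[[], "Coordinator"]] = {}


class Coordinator:
    def __init__(self, unique_name: str) -> None:
        self.unique_name = unique_name
        # A weak reference keeps the registry from extending the
        # coordinator's lifetime.
        _registry[unique_name] = weakref.ref(self)

    def all_reduce_impl(self, x: list) -> list:
        return x  # stand-in for the real collective


def dispatch_all_reduce(x: list, group_name: str) -> list:
    # Same shape as inplace/outplace_all_reduce: resolve the name and fail
    # loudly if the group has already been garbage-collected.
    assert group_name in _registry, f"Group {group_name} is not found."
    group = _registry[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group.all_reduce_impl(x)


coord = Coordinator("tp:0")
print(dispatch_all_reduce([1, 2, 3], "tp:0"))  # [1, 2, 3]
```
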
@@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
+        group_name="world",
    )
@@ -767,6 +851,7 @@ def init_model_parallel_group(
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
+    group_name: Optional[str] = None,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -778,6 +863,7 @@ def init_model_parallel_group(
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
+        group_name=group_name,
    )
@@ -931,7 +1017,8 @@ def initialize_model_parallel(
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
-                                    use_message_queue_broadcaster=True)
+                                    use_message_queue_broadcaster=True,
+                                    group_name="tp")

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (world_size //
@@ -947,7 +1034,8 @@ def initialize_model_parallel(
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
-                                    use_custom_allreduce=False)
+                                    use_custom_allreduce=False,
+                                    group_name="pp")


def ensure_model_parallel_initialized(
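
With `group_name` threaded through `init_world_group` and `init_model_parallel_group`, each coordinator registers under a counter-suffixed unique name produced by `_get_unique_name`, which is the string the custom ops receive. An illustrative (not exhaustive) picture of the resulting names:

```python
# Hypothetical illustration of the unique names the groups register under:
_get_unique_name("world")  # -> "world:0"
_get_unique_name("tp")     # -> "tp:0"
_get_unique_name("pp")     # -> "pp:0"
_get_unique_name("tp")     # -> "tp:1"  (a second TP group would get a new suffix)
```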