 
 
 from functools import partial
-from typing import Optional, Tuple
+from typing import Callable
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed._functional_collectives import all_to_all_single_autograd
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -27,8 +29,8 @@ class TensorParallel(ParallelStyle):
     def __init__(
         self,
         *,
-        input_layouts: Optional[Tuple[Optional[Placement]]] = None,
-        output_layout: Optional[Placement] = None,
+        input_layouts: tuple[Placement | None] | None = None,
+        output_layout: Placement | None = None,
         use_local_output: bool = True,
     ):
         super().__init__()
@@ -99,8 +101,8 @@ class NoParallel(ParallelStyle):
     def __init__(
         self,
         *,
-        input_layout: Optional[Placement] = None,
-        output_layout: Optional[Placement] = None,
+        input_layout: Placement | None = None,
+        output_layout: Placement | None = None,
        use_local_output: bool = True,
     ):
         super().__init__()
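
The two hunks above only modernize the annotations (PEP 604 unions and builtin `tuple` in place of `typing.Optional`/`Tuple`); behavior is unchanged. A minimal construction sketch, with placements chosen purely for illustration and assuming both styles are in scope from this module:

```python
from torch.distributed.tensor import Replicate, Shard

# illustrative layouts only; the right choices depend on how the module is used
tp_style = TensorParallel(
    input_layouts=(Shard(1),),   # tuple[Placement | None] | None
    output_layout=Replicate(),   # Placement | None
)
no_parallel = NoParallel(input_layout=Replicate(), output_layout=Replicate())
```
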
@@ -141,3 +143,143 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             ),
             partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
         )
+
+
+class ExpertParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        input_layouts: Placement | None = None,
+        output_layouts: Placement | None = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Shard(0),)
+        self.output_layouts = (output_layouts or Shard(0),)
+        self.use_local_output = use_local_output
+        self.input_splits = None
+        self.output_splits = None
+
+    # perform all-to-all dispatch on the input
+    def _prepare_input_fn(self, mod, inputs, device_mesh):
+        # annotate module input placements/sharding with input_layouts
+        routed_input, num_tokens_per_expert = inputs
+
+        # generate the input splits and output splits for all-to-all
+        with torch.no_grad():
+            num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
+                num_tokens_per_expert.shape[0]
+            )
+            dist.all_to_all_single(
+                num_tokens_per_expert_group,
+                num_tokens_per_expert,
+                group=device_mesh.get_group(),
+            )
+            # NOTE: this incurs a device-to-host sync
+            self.input_splits = (
+                num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist()
+            )
+            self.output_splits = (
+                num_tokens_per_expert_group.view(device_mesh.shape[0], -1)
+                .sum(dim=1)
+                .tolist()
+            )
+
+        # perform all-to-all
+        routed_input = all_to_all_single_autograd(
+            routed_input,
+            self.output_splits,
+            self.input_splits,
+            device_mesh.get_group(),
+        )
+
+        # NOTE: After this all-to-all, the routed input is placed on the proper EP rank.
+        # However, num_tokens_per_expert_group is not yet in the final target format
+        # [#tokens for local expert 0, #tokens for local expert 1, ...]
+        # Rather, it is in the format
+        # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
+        #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
+        # We need to perform another shuffle to get the correct format -- this is done via the function
+        # generate_permute_indices in moe.py, which also pads so that the number of tokens
+        # each expert gets locally is a multiple of ALIGN_SIZE_M.
+
+        return routed_input, num_tokens_per_expert_group
+
+    def _partition_fn(self, name, module, device_mesh):
+        # shard on the expert dimension
+        for name, param in module.named_parameters(recurse=False):
+            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
+            module.register_parameter(name, dist_param)
+
+    # perform all-to-all combine on the output
+    def _prepare_output_fn(self, mod, routed_output, device_mesh):
+        routed_output = all_to_all_single_autograd(
+            routed_output,
+            self.input_splits,
+            self.output_splits,
+            device_mesh.get_group(),
+        )
+        return routed_output
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            self._partition_fn,
+            self._prepare_input_fn,
+            self._prepare_output_fn,
+        )
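
Since `ExpertParallel` subclasses `ParallelStyle`, it can be handed to the standard `parallelize_module` entry point. A minimal usage sketch; the mesh size, the `ep` dimension name, and the `model.layers[0].moe.experts` module path are assumptions for illustration, and the wrapped module's forward is expected to accept `(routed_input, num_tokens_per_expert)` as prepared by `_prepare_input_fn`:

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# assumed: 8 ranks dedicated to expert parallelism
ep_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("ep",))

# assumed model structure: `model` is a hypothetical MoE transformer whose
# stacked per-expert weights live under moe.experts, with the expert dimension
# first so that _partition_fn's Shard(0) splits experts across EP ranks
parallelize_module(
    module=model.layers[0].moe.experts,
    device_mesh=ep_mesh,
    parallelize_plan=ExpertParallel(),
)
```
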
+
+
+def expert_parallel(func: Callable) -> Callable:
+    def wrapper(
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        w3: torch.Tensor,
+        x: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if isinstance(w1, DTensor):
+            w1 = w1.to_local()
+            w2 = w2.to_local()
+            w3 = w3.to_local()
+
+        if num_tokens_per_expert is not None:
+            # NOTE: In order to use torch._grouped_mm, we need to make sure
+            # the number of tokens each expert gets is a multiple of 16.
+            # The following kernel helps achieve this via padding, without
+            # incurring synchronization between device and host.
+            from torchtitan.experiments.kernels.moe.indices import (
+                generate_permute_indices,
+            )
+
+            experts_per_ep_rank = w1.shape[0]
+            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
+
+            ALIGN_SIZE_M = 16
+            with torch.no_grad():
+                (
+                    permuted_indices,
+                    num_tokens_per_expert,
+                    _,  # offsets
+                ) = generate_permute_indices(
+                    num_tokens_per_expert,
+                    experts_per_ep_rank,
+                    num_ep_ranks,
+                    ALIGN_SIZE_M,
+                )
+
+            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+            input_shape = x.shape
+            x = x[permuted_indices, :]
+
+        out = func(w1, w2, w3, x, num_tokens_per_expert)
+
+        if num_tokens_per_expert is not None:
+            out_unpermuted = out.new_empty(input_shape)
+            out_unpermuted[permuted_indices, :] = out
+            out = out_unpermuted[:-1]
+
+        return out
+
+    return wrapper
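
The `expert_parallel` decorator handles DTensor unwrapping, alignment padding, and the permute/un-permute around whatever experts kernel it wraps. Below is a minimal sketch of a wrapped function, using a plain per-expert loop as a stand-in for the grouped-mm kernel; the weight layouts, the SwiGLU-style MLP, and the availability of `expert_parallel` in scope are assumptions, not taken from this diff:

```python
import torch
import torch.nn.functional as F
# expert_parallel is assumed importable from wherever this file lives

@expert_parallel
def naive_experts_forward(
    w1: torch.Tensor,   # assumed [num_local_experts, dim, hidden_dim]
    w2: torch.Tensor,   # assumed [num_local_experts, hidden_dim, dim]
    w3: torch.Tensor,   # assumed [num_local_experts, dim, hidden_dim]
    x: torch.Tensor,    # [padded_tokens, dim], already permuted per local expert
    num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
    # one contiguous (padded) chunk of rows per local expert; the .tolist()
    # sync is acceptable here because this is only a reference implementation
    chunks = torch.split(x, num_tokens_per_expert.tolist(), dim=0)
    outs = []
    for e, chunk in enumerate(chunks):
        h = F.silu(chunk @ w1[e]) * (chunk @ w3[e])   # SwiGLU-style expert MLP
        outs.append(h @ w2[e])
    # padded rows are zeros; the wrapper scatters rows back and drops the padding
    return torch.cat(outs, dim=0)
```

The decorator's only contract with `func` is that it returns one output row per (padded, permuted) input row; in the real path `func` would be a `torch._grouped_mm`-based kernel rather than this loop.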