
Commit 76f1931

[WIP] expert parallel dp2ep
1 parent f4048f8 commit 76f1931

File tree

11 files changed: +483 −158 lines

run_train.sh

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@ set -ex
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
-CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
+export LOG_RANK=${LOG_RANK:-0,1}
+CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/llama4/train_configs/debug_model.toml"}

 overrides=""
 if [ $# -ne 0 ]; then

torchtitan/components/checkpoint.py

Lines changed: 13 additions & 12 deletions
@@ -41,6 +41,10 @@
 LR_SCHEDULER = "lr_scheduler"
 DATALOADER = "dataloader"
 TRAIN_STATE = "train_state"
+# For now, we will manually pop the freqs_cis buffer, as we made this permanent
+# temporarily and we don't want to include it in the exported state_dict.
+# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+excluded_parameters_for_model_only = {"freqs_cis"}


 class AsyncMode(str, enum.Enum):
@@ -53,7 +57,10 @@ class ModelWrapper(Stateful):
     def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
         self.cache_state_dict = {
-            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
+            k: v
+            for sd in map(get_model_state_dict, self.model)
+            for k, v in sd.items()
+            if k not in excluded_parameters_for_model_only
         }

     def state_dict(self) -> dict[str, Any]:
@@ -69,7 +76,10 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         # `set_model_state_dict()` does change the keys of the input state_dict,
         # we will need to reinitialize the cache_state_dict.
         self.cache_state_dict = {
-            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
+            k: v
+            for sd in map(get_model_state_dict, self.model)
+            for k, v in sd.items()
+            if k not in excluded_parameters_for_model_only
         }


@@ -81,12 +91,6 @@ class SaveDone:
     pass


-# For now, we will manually pop the freqs_cis buffer, as we made this permanent
-# temporarily and we don't want to include it in the exported state_dict.
-# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-excluded_parameters_for_model_only = {"freqs_cis"}
-
-
 @torch.no_grad()
 def save_with_gc(state, checkpoint_id):
     dcp.save(state, checkpoint_id=checkpoint_id)
@@ -568,10 +572,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
         if model_only:
-            sd = self.states[MODEL].state_dict()
-            for k in excluded_parameters_for_model_only:
-                sd.pop(k, None)
-            return sd
+            return {MODEL: self.states[MODEL]}

         for exclude_key in self.exclude_from_loading:
             if exclude_key not in self.states:
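Note on the change above: excluded_parameters_for_model_only now filters freqs_cis out of the cached state dict at construction time instead of popping it later in _states_to_load. A minimal standalone sketch (toy module of my own, not torchtitan code) of why a persistent buffer appears in a state dict and how the comprehension drops it:

    import torch
    import torch.nn as nn

    excluded_parameters_for_model_only = {"freqs_cis"}

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)
            # a persistent buffer is saved in state_dict() just like a parameter
            self.register_buffer("freqs_cis", torch.randn(4), persistent=True)

    model = ToyModel()
    cached = {
        k: v
        for k, v in model.state_dict().items()
        if k not in excluded_parameters_for_model_only
    }
    print(sorted(cached))  # ['proj.bias', 'proj.weight'] -- no 'freqs_cis'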

torchtitan/config_manager.py

Lines changed: 7 additions & 0 deletions
@@ -363,6 +363,13 @@ class Parallelism:
     The default value is 'allgather'.
     """

+    expert_parallel_degree: int = 1
+    """
+    Expert parallelism degree. 1 means disabled.
+    Currently, only "dp2ep" is supported.
+    EP degree has to be k * context_parallel_degree, where k >= 1 and data_parallel_shard_degree % k == 0.
+    """
+

 @dataclass
 class Checkpoint:
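To make the constraint on expert_parallel_degree concrete, here is a small sketch of the same divisibility check that ParallelDims._validate performs, with example degrees of my own choosing (not from this commit):

    def ep_degree_is_valid(dp_shard: int, cp: int, ep: int) -> bool:
        # ep must equal k * cp with k >= 1, and dp_shard must be divisible by k,
        # which is equivalent to: ep % cp == 0 and (dp_shard * cp) % ep == 0
        return ep % cp == 0 and (dp_shard * cp) % ep == 0

    # with data_parallel_shard_degree=4 and context_parallel_degree=2:
    assert ep_degree_is_valid(4, 2, 4)      # k = 2, and 4 % 2 == 0
    assert not ep_degree_is_valid(4, 2, 6)  # k = 3, but 4 % 3 != 0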

torchtitan/distributed/parallel_dims.py

Lines changed: 79 additions & 2 deletions
@@ -23,21 +23,23 @@ class ParallelDims:
     cp: int
     tp: int
     pp: int
+    ep: int
     world_size: int
     enable_loss_parallel: bool

     def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
+            self.ep,
         )
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -50,7 +52,78 @@ def _validate(self):
                 f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
             )

+        if ep > 1:
+            # EP would borrow all cp and some dp_shard degree
+            assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+
+    def _build_mesh_with_ep(self, device_type):
+        # With ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
+        # ep = dp_shard_in_ep * cp
+        dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
+        dp_shard_in_ep = self.ep // self.cp
+
+        dims = []
+        names = []
+        for d, name in zip(
+            [
+                self.pp,
+                self.dp_replicate,
+                dp_shard_mod_ep,
+                dp_shard_in_ep,
+                self.cp,
+                self.tp,
+            ],
+            ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
+        ):
+            # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
+            # helps the MoE layers do mixed precision training
+            if d > 1 or name == "dp_shard_mod_ep":
+                dims.append(d)
+                names.append(name)
+
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+        # Create all the submesh here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading (no communication on this mesh)
+        dp_mesh_dim_names = []
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_names = []
+        # Mesh for loss all-reduce
+        dp_cp_mesh_dim_names = []
+        # Mesh for ep
+        ep_mesh_dim_names = []
+
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+            dp_cp_mesh_dim_names.append("dp_replicate")
+        # dp_shard_mod_ep is always needed, even if it's 1
+        dp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        if "dp_shard_in_ep" in names:
+            dp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_cp_mesh_dim_names.append("dp_shard_in_ep")
+            ep_mesh_dim_names.append("dp_shard_in_ep")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_names.append("cp")
+            dp_cp_mesh_dim_names.append("cp")
+            ep_mesh_dim_names.append("cp")
+
+        mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+        mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
+        mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
+        mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
+
+        return mesh
+
     def build_mesh(self, device_type: str) -> DeviceMesh:
+        if self.ep > 1:
+            return self._build_mesh_with_ep(device_type)
+
         dims = []
         names = []
         for d, name in zip(
@@ -143,3 +216,7 @@ def loss_parallel_enabled(self):
     @cached_property
     def non_data_parallel_size(self):
         return self.cp * self.tp * self.pp
+
+    @property
+    def ep_enabled(self):
+        return self.ep > 1
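A worked example of the submesh arithmetic above, with degrees of my own choosing (not from this commit): on 8 GPUs with dp_shard=4, cp=2, ep=4, we get dp_shard_mod_ep = 4 * 2 // 4 = 2 and dp_shard_in_ep = 4 // 2 = 2, so the mesh is built over ["dp_shard_mod_ep", "dp_shard_in_ep", "cp"] with shape [2, 2, 2] (pp, dp_replicate, tp are 1 and dropped), and the flattened "ep" mesh spans dp_shard_in_ep * cp = 4 ranks. A sketch of just the dimension-selection logic:

    def derived_mesh_dims(pp, dp_replicate, dp_shard, cp, tp, ep):
        # mirrors _build_mesh_with_ep:
        #   dp_shard = dp_shard_mod_ep * dp_shard_in_ep, ep = dp_shard_in_ep * cp
        dp_shard_mod_ep = dp_shard * cp // ep
        dp_shard_in_ep = ep // cp
        dims, names = [], []
        for d, name in zip(
            [pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp],
            ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
        ):
            if d > 1 or name == "dp_shard_mod_ep":
                dims.append(d)
                names.append(name)
        return dims, names

    print(derived_mesh_dims(pp=1, dp_replicate=1, dp_shard=4, cp=2, tp=1, ep=4))
    # ([2, 2, 2], ['dp_shard_mod_ep', 'dp_shard_in_ep', 'cp'])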

torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 147 additions & 5 deletions
@@ -6,10 +6,12 @@


 from functools import partial
-from typing import Optional, Tuple
+from typing import Callable

 import torch
+import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed._functional_collectives import all_to_all_single_autograd
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -27,8 +29,8 @@ class TensorParallel(ParallelStyle):
     def __init__(
         self,
         *,
-        input_layouts: Optional[Tuple[Optional[Placement]]] = None,
-        output_layout: Optional[Placement] = None,
+        input_layouts: tuple[Placement | None] | None = None,
+        output_layout: Placement | None = None,
         use_local_output: bool = True,
     ):
         super().__init__()
@@ -99,8 +101,8 @@ class NoParallel(ParallelStyle):
     def __init__(
         self,
         *,
-        input_layout: Optional[Placement] = None,
-        output_layout: Optional[Placement] = None,
+        input_layout: Placement | None = None,
+        output_layout: Placement | None = None,
         use_local_output: bool = True,
     ):
         super().__init__()
@@ -141,3 +143,143 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             ),
             partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
         )
+
+
+class ExpertParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        input_layouts: Placement | None = None,
+        output_layouts: Placement | None = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Shard(0),)
+        self.output_layouts = (output_layouts or Shard(0),)
+        self.use_local_output = use_local_output
+        self.input_splits = None
+        self.output_splits = None
+
+    # performing all-to-all dispatch on the input
+    def _prepare_input_fn(self, mod, inputs, device_mesh):
+        # annotate module input placements/sharding with input_layouts
+        routed_input, num_tokens_per_expert = inputs
+
+        # generate the input splits and output splits for all-to-all
+        with torch.no_grad():
+            num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
+                num_tokens_per_expert.shape[0]
+            )
+            dist.all_to_all_single(
+                num_tokens_per_expert_group,
+                num_tokens_per_expert,
+                group=device_mesh.get_group(),
+            )
+            # NOTE: this would incur a device-to-host sync
+            self.input_splits = (
+                num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist()
+            )
+            self.output_splits = (
+                num_tokens_per_expert_group.view(device_mesh.shape[0], -1)
+                .sum(dim=1)
+                .tolist()
+            )
+
+        # perform all-to-all
+        routed_input = all_to_all_single_autograd(
+            routed_input,
+            self.output_splits,
+            self.input_splits,
+            device_mesh.get_group(),
+        )
+
+        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
+        # However, the num_tokens_per_expert_group is not of the final target format
+        # [#tokens for local expert 0, #tokens for local expert 1, ...]
+        # Rather, it is of the format
+        # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
+        #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
+        # We need to perform another shuffle to get the correct format -- this is done via the function
+        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
+        # each expert gets locally is a multiple of ALIGN_SIZE_M.
+
+        return routed_input, num_tokens_per_expert_group
+
+    def _partition_fn(self, name, module, device_mesh):
+        # shard on the expert dimension
+        for name, param in module.named_parameters(recurse=False):
+            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
+            module.register_parameter(name, dist_param)

+    # performing all-to-all combine on the output
+    def _prepare_output_fn(self, mod, routed_output, device_mesh):
+        routed_output = all_to_all_single_autograd(
+            routed_output,
+            self.input_splits,
+            self.output_splits,
+            device_mesh.get_group(),
+        )
+        return routed_output
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            self._partition_fn,
+            self._prepare_input_fn,
+            self._prepare_output_fn,
+        )
+
+
+def expert_parallel(func: Callable) -> Callable:
+    def wrapper(
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        w3: torch.Tensor,
+        x: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if isinstance(w1, DTensor):
+            w1 = w1.to_local()
+            w2 = w2.to_local()
+            w3 = w3.to_local()
+
+        if num_tokens_per_expert is not None:
+            # NOTE: In order to use torch._grouped_mm, we need to make sure
+            # the number of tokens each expert gets is a multiple of 16.
+            # The following kernel helps achieve this via padding, without
+            # incurring synchronization between device and host.
+            from torchtitan.experiments.kernels.moe.indices import (
+                generate_permute_indices,
+            )
+
+            experts_per_ep_rank = w1.shape[0]
+            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
+
+            ALIGN_SIZE_M = 16
+            with torch.no_grad():
+                (
+                    permuted_indices,
+                    num_tokens_per_expert,
+                    _,  # offsets,
+                ) = generate_permute_indices(
+                    num_tokens_per_expert,
+                    experts_per_ep_rank,
+                    num_ep_ranks,
+                    ALIGN_SIZE_M,
+                )
+
+            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+            input_shape = x.shape
+            x = x[permuted_indices, :]
+
+        out = func(w1, w2, w3, x, num_tokens_per_expert)
+
+        if num_tokens_per_expert is not None:
+            out_unpermuted = out.new_empty(input_shape)
+            out_unpermuted[permuted_indices, :] = out
+            out = out_unpermuted[:-1]
+
+        return out
+
+    return wrapper
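To illustrate how _prepare_input_fn derives the all-to-all splits, here is a tiny numeric sketch with invented token counts (2 EP ranks, 2 local experts per rank; the numbers are mine, not from this commit). input_splits sums, per destination EP rank, the tokens this rank routed to that rank's local experts; output_splits does the same over the exchanged counts, i.e. over what every rank routed to this rank's local experts:

    import torch

    ep_size = 2  # 2 EP ranks, each owning 2 of the 4 global experts

    # on EP rank 0: tokens routed to global experts 0..3
    num_tokens_per_expert = torch.tensor([3, 1, 2, 4])
    input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1).tolist()
    print(input_splits)  # [4, 6] -> 4 tokens stay local, 6 tokens go to EP rank 1

    # what the all_to_all of the counts would leave on rank 0, assuming EP rank 1
    # routed [5, 2] tokens to rank 0's local experts 0 and 1
    num_tokens_per_expert_group = torch.tensor([3, 1, 5, 2])
    output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1).tolist()
    print(output_splits)  # [4, 7] -> rank 0 receives 4 own tokens and 7 from rank 1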
