[llama4] enable expert parallel on the same device mesh as tp (tp2ep) #1269

Open · wants to merge 8 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ build
outputs
dist/*
.vscode
slurm-*.out

# data
data
3 changes: 3 additions & 0 deletions torchtitan/config_manager.py
@@ -353,6 +353,9 @@ class Parallelism:
    - 'alltoall' means to all-to-all shuffle the kv shards.
    The default value is 'allgather'.
    """

    enable_tp2ep: bool = False
    """Whether to use expert parallelism instead of tensor parallelism for the routed experts (on the same device mesh as TP)."""


@dataclass
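Because tp2ep reuses the TP mesh for expert parallelism, the expert count has to divide evenly across the TP ranks. A minimal sketch of that constraint (numbers are illustrative, not taken from this PR):

```python
# Illustrative only: with enable_tp2ep the EP degree equals the TP degree,
# so the routed expert count must be divisible by it.
num_experts = 16
tp_degree = 8                                # ep_size == tp_degree when EP reuses the TP mesh
assert num_experts % tp_degree == 0
experts_per_rank = num_experts // tp_degree  # 2 experts held locally per rank
```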
128 changes: 128 additions & 0 deletions torchtitan/experiments/kernels/moe/token_dispatcher.py
@@ -0,0 +1,128 @@
from typing import Tuple
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single_autograd


class DefaultTokenDispatcher:
    """No-op dispatcher used when expert parallelism is disabled: all routed
    experts live on the local rank, so no communication is needed."""

    def __init__(self, num_experts: int, ep_size: int = 1):
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.experts_per_rank = num_experts // ep_size
        self.ep_group = None

    def token_permutation(
        self,
        routed_input: torch.Tensor,
        top_scores: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor,
        training: bool = True,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None
    ]:
        return routed_input, top_scores, num_local_tokens_per_expert, None, None

    def token_unpermutation(
        self,
        routed_output: torch.Tensor,
        input_splits: torch.Tensor | None = None,
        output_splits: torch.Tensor | None = None,
        training: bool = True,
    ) -> torch.Tensor:
        return routed_output


class TorchAllToAllTokenDispatcher(DefaultTokenDispatcher):
    """Dispatcher that shuffles routed tokens to the EP ranks owning their
    experts via all-to-all over the expert-parallel process group."""

    def __init__(
        self,
        num_experts: int,
        ep_size: int,
        ep_group: torch.distributed.ProcessGroup,
    ):
        super().__init__(num_experts, ep_size)
        self.ep_group = ep_group

    def token_permutation(
        self,
        routed_input: torch.Tensor,
        top_scores: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor,
        training: bool = True,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None
    ]:
        dim = routed_input.shape[-1]
        with torch.no_grad():
            # Exchange per-expert token counts so each rank knows how many
            # tokens it will receive for its local experts.
            tokens_per_expert_group = num_local_tokens_per_expert.new_empty(
                num_local_tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group,
                num_local_tokens_per_expert,
                group=self.ep_group,
            )
            # Number of tokens this rank sends to / receives from each EP rank.
            input_splits = num_local_tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(dim=1)
        if training:
            gathered_tokens = all_to_all_single_autograd(
                routed_input,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )
            gathered_top_scores = all_to_all_single_autograd(
                top_scores,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )
        else:
            # TODO: unify with all_to_all_single_autograd after
            # https://github.com/pytorch/pytorch/issues/154370 is resolved
            gathered_num_tokens = output_splits.sum()
            gathered_tokens = routed_input.new_empty((gathered_num_tokens, dim))
            dist.all_to_all_single(
                gathered_tokens,
                routed_input,
                output_splits.tolist(),
                input_splits.tolist(),
                group=self.ep_group,
            )
            gathered_top_scores = top_scores.new_empty(gathered_num_tokens)
            dist.all_to_all_single(
                gathered_top_scores,
                top_scores,
                output_splits.tolist(),
                input_splits.tolist(),
                group=self.ep_group,
            )
        return (
            gathered_tokens,
            gathered_top_scores,
            tokens_per_expert_group,
            input_splits,
            output_splits,
        )

    def token_unpermutation(
        self,
        routed_output: torch.Tensor,
        input_splits: torch.Tensor | None = None,
        output_splits: torch.Tensor | None = None,
        training: bool = True,
    ) -> torch.Tensor:
        dim = routed_output.shape[-1]
        if training:
            returned_tokens = all_to_all_single_autograd(
                routed_output,
                input_splits.tolist(),
                output_splits.tolist(),
                self.ep_group,
            )
        else:
            # TODO: unify with all_to_all_single_autograd after
            # https://github.com/pytorch/pytorch/issues/154370 is resolved
            returned_tokens = routed_output.new_empty((input_splits.sum(), dim))
            dist.all_to_all_single(
                returned_tokens,
                routed_output,
                input_splits.tolist(),
                output_splits.tolist(),
                group=self.ep_group,
            )
        return returned_tokens
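For orientation, a minimal usage sketch of the dispatcher round trip (not part of this PR): `ep_group`, `routed_input`, `top_scores`, `num_local_tokens_per_expert`, and `run_local_experts` are assumed placeholders standing in for the MoE layer's real tensors and grouped-expert computation.

```python
# Sketch: dispatch tokens to the ranks that own their experts, run the local
# experts, then return the results to the source ranks.
dispatcher = TorchAllToAllTokenDispatcher(num_experts=16, ep_size=8, ep_group=ep_group)

gathered_tokens, gathered_scores, tokens_per_expert_group, in_splits, out_splits = (
    dispatcher.token_permutation(routed_input, top_scores, num_local_tokens_per_expert)
)
# `run_local_experts` is a placeholder for the grouped expert computation.
routed_output = run_local_experts(gathered_tokens, gathered_scores, tokens_per_expert_group)
returned_tokens = dispatcher.token_unpermutation(routed_output, in_splits, out_splits)
```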
41 changes: 40 additions & 1 deletion torchtitan/experiments/llama4/infra/expert_parallel.py
@@ -18,7 +18,10 @@
    Replicate,
    Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.parallel import (
    ParallelStyle,
    PrepareModuleInputOutput,
)
from torch.distributed.tensor.placement_types import Placement


@@ -141,3 +144,39 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )


class ExpertParallel(ParallelStyle):
    """Shards the experts' parameters along the expert dimension and passes
    inputs/outputs through as local (non-DTensor) tensors."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _prepare_input_fn(mod, inputs, device_mesh):
        for inp in inputs:
            if isinstance(inp, torch.Tensor):
                assert not isinstance(
                    inp, DTensor
                ), "ExpertParallel expects local tensor inputs."
        return inputs

    def _partition_fn(self, name, module, device_mesh: DeviceMesh):
        # shard on the expert dimension
        for name, param in module.named_parameters(recurse=False):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
            module.register_parameter(name, dist_param)

    @staticmethod
    def _prepare_output_fn(mod, outputs, device_mesh):
        assert not isinstance(
            outputs, DTensor
        ), "ExpertParallel expects local tensor outputs."
        return outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            self._prepare_input_fn,
            self._prepare_output_fn,
        )
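To make the partition behavior concrete, here is a hedged standalone sketch (not part of this PR) of applying `ExpertParallel` to a toy module whose parameters carry a leading expert dimension. `ToyGroupedExperts` is hypothetical; the snippet assumes 8 GPUs launched via `torchrun --nproc_per_node=8`.

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.experiments.llama4.infra.expert_parallel import ExpertParallel


class ToyGroupedExperts(nn.Module):
    # grouped expert weight: (num_experts, dim, dim)
    def __init__(self, num_experts: int, dim: int):
        super().__init__()
        self.w = nn.Parameter(torch.empty(num_experts, dim, dim, device="cuda"))


mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))  # EP reuses the TP mesh
experts = ToyGroupedExperts(num_experts=16, dim=128)
parallelize_module(experts, mesh["tp"], ExpertParallel())
# each rank now holds a DTensor shard of 16 // 8 = 2 experts along dim 0
```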
78 changes: 60 additions & 18 deletions torchtitan/experiments/llama4/infra/parallelize_llama.py
@@ -64,7 +64,11 @@ def parallelize_llama(
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

        apply_moe_tp(model, world_mesh["tp"])
        apply_moe_tp(
            model,
            world_mesh["tp"],
            enable_tp2ep=job_config.parallelism.enable_tp2ep,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)
@@ -145,32 +149,70 @@ def _sync_tokens_per_expert(module, *_):
def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    enable_tp2ep: bool = False,
):
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel
    from .expert_parallel import (
        NoParallel,
        TensorParallel,
        ExpertParallel,
    )

    for transformer_block in model.layers.values():
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.experts": TensorParallel(output_layout=Partial()),
            "moe.shared_expert": TensorParallel(output_layout=Partial()),
        }
        if enable_tp2ep:
            moe_layer_plan = {
                # input / output sharding on the seqlen dim
                "moe": PrepareModuleInputOutput(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Shard(1),),
Review comment (Contributor):
If I understand correctly, the input to the router is sharded. Then this might break the semantics / correctness of the load-balancing algorithm, given that the update to self.tokens_per_expert is local to each EP rank.
https://github.com/pytorch/torchtitan/pull/1269/files#diff-87cc24d85c768f0b3d1f5c54cca39dc9de52ee20e8f601814c3200722901aee5R293

Reply (Contributor Author):
Thank you for pointing out this issue. We need an all_reduce across all EP groups.
Fixed in b87aa1e
                    use_local_input=True,
                    output_layouts=(Shard(1),),
                    desired_output_layouts=(Shard(1),),
                ),
                # FIXME: the input is reshaped after being sharded along
                # the seqlen dimension. Should we use local tensors
                # instead of Replicate?
                "moe.router.gate": NoParallel(),
                # Since the tokens are not split evenly across ranks,
                # we need to use local tensors for both input and output.
                # After the manual all-to-all dispatch, the result is
                # sharded along the seqlen dim.
                "moe.experts": ExpertParallel(),
                "moe.shared_expert": TensorParallel(
                    input_layouts=(Shard(1), None),
                    output_layout=Shard(1),
                ),
            }
        else:
            moe_layer_plan = {
                # input / output sharding on the seqlen dim
                # all-gather for input, reduce-scatter for output
                "moe": PrepareModuleInputOutput(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                    use_local_input=True,
                    output_layouts=(Partial(),),
                    desired_output_layouts=(Shard(1),),
                ),
                # replicate computation for the router
                "moe.router.gate": NoParallel(),
                # input Replicate, output Partial
                "moe.experts": TensorParallel(output_layout=Partial()),
                "moe.shared_expert": TensorParallel(output_layout=Partial()),
            }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
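Relatedly, the "tokens are not split evenly" comment above is what drives the dispatcher's split computation (`num_local_tokens_per_expert.view(ep_size, -1).sum(dim=1)`). A small numeric sketch of that reduction, with illustrative counts not taken from this PR:

```python
import torch

# Hypothetical numbers: 2 EP ranks, 4 experts (2 local experts per rank).
ep_size = 2
# Tokens this rank routed to each of the 4 global experts:
num_local_tokens_per_expert = torch.tensor([3, 1, 2, 4])
# Tokens to send to each EP rank = sum over that rank's locally owned experts.
input_splits = num_local_tokens_per_expert.view(ep_size, -1).sum(dim=1)
print(input_splits)  # tensor([4, 6]) -> 4 tokens to rank 0, 6 tokens to rank 1
```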