-
Notifications
You must be signed in to change notification settings - Fork 407
[llama4] enable expert parallel on the same device mesh as tp (tp2ep) #1269
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
hann-wang
wants to merge
8
commits into
pytorch:main
Choose a base branch
from
hann-wang:han/pr_llama4_expert_parallel
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
7aabc71
[llama4] enable expert parallel on the same device mesh as tp (tp2ep)
hann-wang f28ffeb
Merge branch 'pytorch:main' into han/pr_llama4_expert_parallel
hann-wang 18c5d17
refactor: move tp2ep communications into TokenDispatcher
hann-wang d727a91
Merge branch 'main' into han/pr_llama4_expert_parallel
hann-wang 864ee31
fix: torch.compile failure of TorchAllToAllTokenDispatcher
hann-wang b87aa1e
fix: expert bias update
hann-wang cc7a45c
chore: in-place add
hann-wang cd1680f
fix: multiply scores in FP32 datatype
hann-wang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ build | |
outputs | ||
dist/* | ||
.vscode | ||
slurm-*.out | ||
|
||
# data | ||
data | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from typing import Tuple | ||
import torch | ||
import torch.distributed as dist | ||
from torch.distributed._functional_collectives import all_to_all_single_autograd | ||
|
||
|
||
class DefaultTokenDispatcher: | ||
|
||
def __init__(self, num_experts: int, ep_size: int = 1): | ||
self.num_experts = num_experts | ||
self.ep_size = ep_size | ||
self.experts_per_rank = num_experts // ep_size | ||
self.ep_group = None | ||
|
||
def token_permutation( | ||
self, | ||
routed_input: torch.Tensor, | ||
top_scores: torch.Tensor, | ||
num_local_tokens_per_expert: torch.Tensor, | ||
training: bool = True, | ||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, | ||
torch.Tensor | None]: | ||
return routed_input, top_scores, num_local_tokens_per_expert, None, None | ||
|
||
def token_unpermutation( | ||
self, | ||
routed_output: torch.Tensor, | ||
input_splits: torch.Tensor | None = None, | ||
output_splits: torch.Tensor | None = None, | ||
training: bool = True, | ||
) -> torch.Tensor: | ||
return routed_output | ||
|
||
|
||
class TorchAllToAllTokenDispatcher(DefaultTokenDispatcher): | ||
|
||
def __init__( | ||
self, | ||
num_experts: int, | ||
ep_size: int, | ||
ep_group: torch.distributed.ProcessGroup, | ||
): | ||
super().__init__(num_experts, ep_size) | ||
self.ep_group = ep_group | ||
|
||
def token_permutation( | ||
self, | ||
routed_input: torch.Tensor, | ||
top_scores: torch.Tensor, | ||
num_local_tokens_per_expert: torch.Tensor, | ||
training: bool = True, | ||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, | ||
torch.Tensor | None]: | ||
dim = routed_input.shape[-1] | ||
with torch.no_grad(): | ||
tokens_per_expert_group = num_local_tokens_per_expert.new_empty( | ||
num_local_tokens_per_expert.shape[0]) | ||
dist.all_to_all_single(tokens_per_expert_group, | ||
num_local_tokens_per_expert, | ||
group=self.ep_group) | ||
input_splits = num_local_tokens_per_expert.view( | ||
self.ep_size, -1).sum(dim=1) | ||
output_splits = tokens_per_expert_group.view( | ||
self.ep_size, -1).sum(dim=1) | ||
if training: | ||
gathered_tokens = all_to_all_single_autograd( | ||
routed_input, | ||
output_splits.tolist(), | ||
input_splits.tolist(), | ||
self.ep_group, | ||
) | ||
gathered_top_scores = all_to_all_single_autograd( | ||
top_scores, | ||
output_splits.tolist(), | ||
input_splits.tolist(), | ||
self.ep_group, | ||
) | ||
else: | ||
# TODO: unify with all_to_all_single_autograd after | ||
# https://github.com/pytorch/pytorch/issues/154370 is resolved | ||
gathered_num_tokens = output_splits.sum() | ||
gathered_tokens = routed_input.new_empty( | ||
(gathered_num_tokens, dim)) | ||
dist.all_to_all_single( | ||
gathered_tokens, | ||
routed_input, | ||
output_splits.tolist(), | ||
input_splits.tolist(), | ||
group=self.ep_group, | ||
) | ||
gathered_top_scores = top_scores.new_empty(gathered_num_tokens, ) | ||
dist.all_to_all_single( | ||
gathered_top_scores, | ||
top_scores, | ||
output_splits.tolist(), | ||
input_splits.tolist(), | ||
group=self.ep_group, | ||
) | ||
return gathered_tokens, gathered_top_scores, tokens_per_expert_group, input_splits, output_splits | ||
|
||
def token_unpermutation( | ||
self, | ||
routed_output: torch.Tensor, | ||
input_splits: torch.Tensor | None = None, | ||
output_splits: torch.Tensor | None = None, | ||
training: bool = True, | ||
) -> torch.Tensor: | ||
dim = routed_output.shape[-1] | ||
if training: | ||
returned_tokens = all_to_all_single_autograd( | ||
routed_output, | ||
input_splits.tolist(), | ||
output_splits.tolist(), | ||
self.ep_group, | ||
) | ||
else: | ||
# TODO: unify with all_to_all_single_autograd after | ||
# https://github.com/pytorch/pytorch/issues/154370 is resolved | ||
returned_tokens = routed_output.new_empty( | ||
(input_splits.sum(), dim)) | ||
dist.all_to_all_single( | ||
returned_tokens, | ||
routed_output, | ||
input_splits.tolist(), | ||
output_splits.tolist(), | ||
group=self.ep_group, | ||
) | ||
return returned_tokens |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I understand correctly, the input to router is sharded. Then this might break the semantics / correctness of the load balancing algorithm, given the update to
self.tokens_per_expert
is local to each EP rank.https://github.com/pytorch/torchtitan/pull/1269/files#diff-87cc24d85c768f0b3d1f5c54cca39dc9de52ee20e8f601814c3200722901aee5R293
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for pointing out this issue. We need an
all_reduce
across all ep groups.Fixed in b87aa1e