add support for simplefsdp+ep #1529
As titled, this PR adds support for SimpleFSDP + EP.
In a SimpleFSDP + EP tlparse, you can see that the token-dispatch all-to-all coexists with the replicated shared experts (`_grouped_mm`), which means we could potentially reorder them to overlap.
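For illustration only, here is a minimal, hedged sketch of that overlap (not the torchtitan/SimpleFSDP implementation; the function name, `input_splits`/`output_splits` as per-rank token counts, `shared_expert`, and `ep_group` are all made-up placeholders):

```python
import torch
import torch.distributed as dist

def dispatch_with_shared_expert_overlap(tokens, input_splits, output_splits,
                                        shared_expert, ep_group):
    """Hypothetical helper: overlap EP token dispatch with the shared experts."""
    # Kick off the token-dispatch all-to-all asynchronously.
    recv = tokens.new_empty((sum(output_splits), tokens.shape[-1]))
    work = dist.all_to_all_single(
        recv, tokens,
        output_split_sizes=output_splits,
        input_split_sizes=input_splits,
        group=ep_group,
        async_op=True,
    )
    # The shared experts operate on the replicated (pre-dispatch) tokens,
    # so they can run while the all-to-all is in flight.
    shared_out = shared_expert(tokens)
    work.wait()
    return recv, shared_out
```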
Profiler Trace & Correctness (eager-mode)
The following results were benchmarked on 8 H100 GPUs.
- Loss: the loss curves almost match between `FSDP2+EP` and `SimpleFSDP+EP`.
- Trace: the first stream is FSDP (on the dp_shard_cp dim), the second stream is all-to-all in token dispatch, and the third stream is FSDP (on the dp_shard_mod_ep dim). They land on different streams probably because their submesh names are different (a short mesh-naming sketch follows the comparisons below), but they should be, or we hope they will eventually be, on the same stream.
- Loss: the loss curves almost match between `FSDP2+TP` and `SimpleFSDP+TP`.
- Trace: the first stream is FSDP, the second stream is TP.
- Loss: the loss curves almost match between `FSDP2+TP+EP` and `SimpleFSDP+TP+EP`.
- Trace: the first stream is FSDP communication (on the dp_shard_cp dim), the second stream is TP communication, the third stream is all-to-all for token dispatch, and the fourth stream is FSDP communication (on the dp_shard_mod_ep dim).
- Loss: the loss curves almost match between `FSDP2(HSDP)+TP+EP` and `SimpleFSDP(HSDP)+TP+EP`.
- Trace: the first stream is FSDP communication (on the dp_shard_cp dim), the second stream is TP communication, the third stream is all-to-all for token dispatch, the fourth stream is FSDP communication (on the dp_shard_mod_ep dim), and the fifth stream is DDP communication.
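As an aside, here is a minimal sketch of how two shard meshes with different dim names can be constructed, mirroring the `dp_shard_cp` and `dp_shard_mod_ep` dims seen in the traces (the sizes and layout are illustrative and not torchtitan's actual mesh construction):

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 GPUs, illustrative layout: dense params shard over all 8 ranks ("dp_shard_cp"),
# while MoE expert params shard over 2 ranks ("dp_shard_mod_ep") with EP degree 4.
dense_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard_cp",))
moe_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_shard_mod_ep", "ep"))
```

Collectives issued against differently named meshes like these are presumably why the profiler shows the two FSDP communications on separate streams.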
What is not working:
- `self.input_splits = num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1)` in expert_parallel.py, after confirming with @xmfan LINK; the `_A2A` class is used as the temporary fix for the AC leak... LINK
- `all_to_all_single_autograd`'s input, which is converted using `tolist()`. For some magical reason, I found that if the input to all_to_all_single is passed as `output_split_sizes.tolist()` instead of `output_split_sizes`, there is no graph break.... After looking into the tlparse, `output_split_sizes` is still treated as a tensor, but when copying data out of `output_split_sizes`, the Triton code will do an additional `.item()`. LINK (see the sketch after this list)
- `reshard_after_forward` is not working, so the behavior of the two differs. cc @anijain2305
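For context on the second item, here is a minimal sketch of the split-size handling around `all_to_all_single_autograd` (assumed names and argument order, not the actual expert_parallel.py code):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def token_dispatch(tokens, num_tokens_per_expert, ep_group, ep_degree):
    # Per-rank send sizes derived from per-expert token counts (still GPU tensors).
    input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
    output_splits = torch.empty_like(input_splits)
    # Exchange split sizes so every rank knows how many tokens it will receive.
    dist.all_to_all_single(output_splits, input_splits, group=ep_group)

    # The .tolist() calls force a device-to-host sync to get Python ints; this is
    # where a graph break under torch.compile would normally be expected.
    return funcol.all_to_all_single_autograd(
        tokens,
        output_splits.tolist(),  # output split sizes
        input_splits.tolist(),   # input split sizes
        ep_group,
    )
```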