Onboard explicit sharding to deepseek [split version] #2783
base: main
Conversation
Force-pushed from 1279302 to 9d4d6b4
RissyRan left a comment:
Thanks for simplifying it! Are you planning to merge this one instead?
If you could directly apply a few minor comments from the old PR (just published), we should be good to go.
Reply to @RissyRan:
Yes, let's merge this one instead. The sharding refactoring parts are removed here.
Thank you for the comments! The concerns from the previous PR are all related to the sharding refactoring, which is not included in this PR.
🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR introduces explicit sharding for the DeepSeek model, which is a significant and valuable feature for performance and scalability. The changes are extensive and well-structured. I've identified a couple of minor logical issues and potential improvements in the implementation.
🔍 General Feedback
- The introduction of `shard_mode` and the refactoring to use `_maybe_shard_with_logical` is a clean way to handle explicit sharding.
- The tests have been updated comprehensively, which is great to see.
- Pay close attention to copy-paste errors, especially when dealing with similar variables like `inputs_q` and `inputs_kv`.
Overall, this is a solid contribution. Once the minor issues are addressed, this PR will be in great shape.
```python
    query = self._maybe_shard_with_logical(query, query_logical_name)
    return query

  def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
```
I've moved the return statement to the end of the function and removed the now-unnecessary final `_maybe_shard_with_logical` on `query`, as it's already sharded correctly within the conditional blocks.
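For readers without the full diff, a minimal sketch of the shape being described; the method body, helper names, and logical names here are assumptions modeled on the surrounding diff, not the PR's exact code:

```python
def mla_get_query(self, inputs_q, model_mode):
  # Hypothetical sketch: each branch applies its own sharding via
  # _maybe_shard_with_logical, so no final re-shard is needed.
  if model_mode == MODEL_MODE_PREFILL:
    query = self.query_projection(inputs_q)
    query = self._maybe_shard_with_logical(query, prefill_query_logical_name)
  else:
    query = self.query_projection(inputs_q)
    query = self._maybe_shard_with_logical(query, query_logical_name)
  # Single return at the end of the function.
  return query
```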
I don't think the `return query` in line 544 is inside an if-else statement. It is a little bit misleading.
```python
      input_axis_names = self.ep_input_axis_names
    elif model_mode == MODEL_MODE_TRAIN:
```
Good catch! Updated.
```python
rotated_sharding = (
    NamedSharding(self.mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", None, None)))
    if self.shard_mode == ShardMode.EXPLICIT
    else None
)
```
Good idea. I will use `jnp.broadcast_to` and `jnp.multiply` to replace the `jnp.einsum`.
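For illustration, the pattern being swapped looks roughly like this; the shapes and the `bl,blhd->blhd` contraction are made-up stand-ins, not the actual rotary-embedding math:

```python
import jax.numpy as jnp

x = jnp.ones((2, 8, 4, 16))   # (batch, length, heads, head_dim)
scale = jnp.ones((2, 8))      # per-(batch, position) factor

# einsum formulation:
out_einsum = jnp.einsum("bl,blhd->blhd", scale, x)

# broadcast_to + multiply formulation; elementwise ops are often easier
# to reason about under explicit sharding than an einsum contraction:
out_mul = jnp.multiply(jnp.broadcast_to(scale[:, :, None, None], x.shape), x)

assert jnp.allclose(out_einsum, out_mul)
```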
Force-pushed from 9d4d6b4 to 8d2efc0
Force-pushed from 8d2efc0 to 3dc9988
Description
This PR introduces explicit sharding support for the DeepSeek model in MaxText. This allows for more granular control over how tensors are distributed across devices, which may lead to better performance and scalability.
This PR is a simplified version based on #2579.
Key Changes:
- The `deepseek` decoder is now supported in `ShardMode.EXPLICIT`.
- `shard_mode`: a `shard_mode` parameter has been added to various layers, including attention, embeddings, and MoE, to control sharding behavior. This allows for more flexible and explicit sharding configurations (see the sketch after this list).
- Sharding is now expressed with `NamedSharding` objects, making the sharding implementation more explicit and maintainable.
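To make the `shard_mode` mechanism concrete, here is a minimal, hypothetical sketch of how such a gate can work. It approximates, but is not, the PR's `_maybe_shard_with_logical`; the `ShardMode` enum values and function signature are assumptions:

```python
import enum

import jax
from flax import linen as nn
from jax.sharding import NamedSharding


class ShardMode(enum.Enum):
  AUTO = "auto"
  EXPLICIT = "explicit"


def maybe_shard_with_logical(x, logical_names, mesh, shard_mode):
  """Constrain x to the given logical axis names only in explicit mode."""
  if shard_mode != ShardMode.EXPLICIT:
    return x  # auto mode: leave placement decisions to the compiler
  # Translate logical axis names (e.g. "activation_batch") into mesh axes
  # using the active flax logical-axis rules.
  spec = nn.logical_to_mesh_axes(logical_names)
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
```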
Detailed Changes by File:
- `src/MaxText/common_types.py`: Added `Q_LORA_UP_PROJ` and `KV_LORA_UP_PROJ` to support LoRA projections with explicit sharding.
- `src/MaxText/configs/types.py`: Enabled the `deepseek` decoder for explicit sharding.
- `src/MaxText/layers/attention_mla.py`: Updated to use `out_sharding` and `_maybe_shard_with_logical` for more explicit control over sharding in MLA.
- `src/MaxText/layers/attentions.py`: Added the `shard_mode` parameter to `RotaryEmbedding` and its variants. Refactored input and projection sharding to be more explicit.
- `src/MaxText/layers/deepseek.py`: Integrated explicit sharding into the `DeepSeekDecoderLayer` by using `_maybe_shard_with_logical` and passing `out_sharding` and `intermediate_sharding` to sub-layers.
- `src/MaxText/layers/embeddings.py`: Added the `shard_mode` parameter to all embedding classes to allow for explicit sharding configuration.
- `src/MaxText/layers/moe.py`: Added `shard_mode` to `GateLogit` and updated the `MoeBlock` to handle explicit sharding for weights and activations.
- `src/MaxText/sharding.py`: Introduced new utility functions for explicit sharding (a rough sketch of the kind of helper involved follows this list).
- `tests/*`: Updated various tests to include test cases for explicit sharding and to pass the new `shard_mode` and `mesh` parameters where necessary.
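The new `sharding.py` utilities are not shown in this excerpt; as a rough idea, here is a sketch that builds a `NamedSharding` from logical axis names, mirroring the `rotated_sharding` snippet above. The mesh shape, the `fsdp` axis name, and the axis rules are assumptions:

```python
import numpy as np

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

# Hypothetical 1D mesh named "fsdp" over all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))


def named_sharding_from_logical(mesh, logical_names):
  """Sketch: map logical axis names to a NamedSharding on the mesh."""
  return NamedSharding(mesh, nn.logical_to_mesh_axes(logical_names))


with nn.logical_axis_rules((("activation_batch", "fsdp"),)):
  sharding = named_sharding_from_logical(mesh, ("activation_batch", None))
  x = jax.device_put(jnp.ones((8, 1024)), sharding)
```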
Tests
Model_name: deepseek3-test
Topology: v5p-8
JAX==0.8.1
TL;DR:
Performance Impact Table
| Sharding setup | `shard_mode=auto` | `shard_mode=explicit` |
| --- | --- | --- |
| Full FSDP | xprof | xprof |
| FSDP + TP | xprof | xprof |
| FSDP + TP_transpose | xprof | xprof |

Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [ ] Added the `gemini-review` label.