Onboard explicit sharding to deepseek #2579
Conversation
Force-pushed from 97cc3bd to 0b4a8b5.
Force-pushed from 7b57bd2 to c01b965.
🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces explicit sharding for the DeepSeek model, which is a significant and valuable enhancement. The changes are extensive, refactoring various layers (attention, moe, llama2, etc.) to support explicit sharding, which improves code structure and maintainability by centralizing sharding logic. The addition of comprehensive tests for different parallelism configurations with explicit sharding is also a great inclusion.
🔍 General Feedback
- The refactoring to use dataclasses for logical axis names is a good pattern that makes the sharding logic cleaner and easier to follow across different layers.
- I've identified one critical copy-paste error in the MoE layer and a minor inconsistency in the test setup. Once these are addressed, this PR will be in excellent shape.
Force-pushed from 5e86195 to 81b117a.
RissyRan left a comment:
Thanks @NuojCheng for the change! Let's try to merge after the 11/24 timeline? Merging earlier may potentially introduce chaos for the current perf. We made a similar decision for the NNX migration.
Review is not blocked, for sure.
Force-pushed from 81b117a to 1a711c4.
Force-pushed from 1a711c4 to 679a5ef.
RissyRan left a comment:
Thanks Nuojin for the change!
This is a huge PR, which I haven't finished reading yet. Do you think it's possible to:
- split this PR into the explicit-sharding onboarding and the refactor?
- I see you have tested auto mode and explicit mode, which is great! Have you tested before/after the change for auto mode? We don't need to test all the shardings you have already covered, just ensure there is no perf regression in auto mode for DS v3 with basic FSDP, TP, and EP sharding.
contract_ind = tuple(range(0, len(norm_axis)))
...
# [B, S, E] -> [B, S, num_exp]
Why do we need to shard embedding dimension to be expert?
The output_sharding here specifies the logical axes of the matmul output, which are supposed to be B x S x NumExpert.
I would have to think more, but I think we want to leave the NumExpert dimension unsharded. The batch dimension should still be sharded by EP here (batch is sharded by EP until after the a2a).
Yes, this is inside of Gating, so NumExpert will only start to be sharded after the 1st A2A and before the 2nd A2A.
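(To make the layout under discussion concrete, here is a minimal, hypothetical sketch — not the PR's code. It assumes a toy ("expert", "fsdp") mesh and uses `jax.lax.with_sharding_constraint` in place of MaxText's `output_sharding` mechanism, keeping batch sharded by EP and leaving NumExpert unsharded.)

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Assumed toy mesh for illustration; MaxText builds its mesh from config instead.
mesh = jax.make_mesh((jax.device_count(), 1), ("expert", "fsdp"))

@jax.jit
def gate_logits(x, w_gate):
  # x: [B, S, E] activations, w_gate: [E, num_experts] gating kernel
  logits = jnp.einsum("bse,en->bsn", x, w_gate)  # [B, S, num_experts]
  # Keep batch sharded by EP (it stays EP-sharded until after the first a2a)
  # and leave the NumExpert dimension unsharded.
  return jax.lax.with_sharding_constraint(
      logits, NamedSharding(mesh, P("expert", None, None)))

# Example usage with dummy shapes:
# gate_logits(jnp.ones((8, 16, 32)), jnp.ones((32, 4)))
```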
if self.config.model_name.startswith("deepseek3"):
  pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
else:
  # pre_bias_logits is None for non-DeepSeek v3 models
Nit: let's remove this comment, as the if/else already indicates this logic, right?
sg
Still there?
# mlp_no_fsdp axis
if self.config.fsdp_shard_on_exp:
  # special sharding for dsv3
  wi_kernel_axes = ("embed_no_exp", None, "mlp")
wi_kernel_axes = ("embed_no_exp", None, "mlp")
wo_kernel_axes = ("embed_no_exp", "mlp", None)
Can these be moved into the quantization block?
I think wi_kernel_axes and wi_kernel_axes_sp are used for different purposes. The latter is only used for sparse_matmul; for non-quantized + dense_matmul we still use wi_kernel_axes.
OK I was a little bit confused. If so, could we rename them to wi_kernel_axes_sparse and wi_kernel_axes_dense for readability when creating them in one place?
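(To make the rename suggestion concrete, a small hypothetical sketch; the helper name, the else-branch values, and the sparse-variant values are placeholders, not MaxText's actual axes.)

```python
from typing import Optional, Tuple

Axes = Tuple[Optional[str], ...]

def moe_kernel_axes(fsdp_shard_on_exp: bool) -> dict:
  """Hypothetical helper that defines both dense and sparse variants in one place."""
  if fsdp_shard_on_exp:
    # special sharding for dsv3 (axes taken from the snippet above)
    wi_dense: Axes = ("embed_no_exp", None, "mlp")
    wo_dense: Axes = ("embed_no_exp", "mlp", None)
  else:
    wi_dense = ("exp", "embed_no_exp", "mlp")   # placeholder default
    wo_dense = ("exp", "mlp", "embed_no_exp")   # placeholder default
  # sparse_matmul-only variants (the current *_sp axes); values are placeholders
  wi_sparse: Axes = ("exp", "embed_no_exp", "mlp")
  wo_sparse: Axes = ("exp", "mlp", "embed_no_exp")
  return {
      "wi_kernel_axes_dense": wi_dense,
      "wo_kernel_axes_dense": wo_dense,
      "wi_kernel_axes_sparse": wi_sparse,
      "wo_kernel_axes_sparse": wo_sparse,
  }
```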
Thank you Ran for the reviews! @RissyRan
I can do that.
The performance test results between main and
@dataclasses.dataclass(frozen=True)
class AttentionMLALogicalNames:
cc @richjames0
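(For readers unfamiliar with the pattern, a minimal hypothetical sketch of a frozen dataclass of logical axis names; field names and values here are illustrative, not the PR's exact definitions.)

```python
import dataclasses
from typing import Optional, Tuple

Axes = Tuple[Optional[str], ...]

@dataclasses.dataclass(frozen=True)
class AttentionMLALogicalNames:
  """Groups the logical axis names an MLA attention layer needs for one model_mode."""
  inputs: Axes = ("activation_batch", "activation_length", "activation_embed")
  query: Axes = ("activation_batch", "activation_length", "activation_heads", "activation_kv")
  out: Axes = ("activation_batch", "activation_length", "activation_embed")

# A layer can pick one name set per model_mode and derive shardings from it,
# e.g. something like self._create_sharding(names.query).
names = AttentionMLALogicalNames()
```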
temperature_tuning: bool = False,
temperature_tuning_scale: float = 0.1,
temperature_tuning_floor_scale: float = 8192.0,
# Shard the query activation as the same as the key and value.
<3
self.value = self.init_kv_w(inputs_kv_shape=inputs_kv_shape)
self.out = self.init_out_w(output_dim=inputs_q_shape[-1])
...
def _create_sharding(self, axis_names):
cc @richjames0
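(For context, a rough sketch of what a `_create_sharding`-style helper could do: resolving logical axis names to mesh axes through a rule table and wrapping them in a NamedSharding. The rule values and the mesh below are assumptions, not MaxText's actual implementation.)

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed subset of logical_axis_rules: logical name -> mesh axis (or None).
LOGICAL_RULES = {
    "activation_batch": "fsdp",
    "activation_length": None,
    "activation_embed": "tensor",
}

def create_sharding(mesh: Mesh, axis_names) -> NamedSharding:
  """Resolve logical axis names to mesh axes and build a NamedSharding."""
  mesh_axes = tuple(LOGICAL_RULES.get(name) for name in axis_names)
  return NamedSharding(mesh, P(*mesh_axes))

mesh = jax.make_mesh((jax.device_count(), 1), ("fsdp", "tensor"))
inputs_sharding = create_sharding(
    mesh, ("activation_batch", "activation_length", "activation_embed"))
```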
decoder_positions = reordered_batch["inputs_position"]
# apply attention with sharding
with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
  if cfg_cp.expert_shard_attention_option == EP_AS_CONTEXT:
Oof, we generally prefer to avoid logic in tests, but I think I understand why this is here. Note this is a case that hints we should be modifying the logical axis rules at config-parsing time instead of at runtime.
cc @richjames0
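(To illustrate the config-parsing-time idea, a hypothetical sketch: the rule format of (logical_name, mesh_axes) pairs follows MaxText's logical_axis_rules convention, but the helper name and the specific adjustment shown are purely illustrative.)

```python
def adjust_axis_rules_for_ep_as_context(logical_axis_rules):
  """Rewrite logical_axis_rules once, at config-parsing time, so tests and
  layers don't need to branch on expert_shard_attention_option at runtime."""
  adjusted = []
  for logical_name, mesh_axes in logical_axis_rules:
    # Purely illustrative adjustment: also shard the sequence dimension over
    # the expert axis when EP is repurposed as context parallelism.
    if logical_name == "activation_length":
      mesh_axes = list(mesh_axes) + ["expert"]
    adjusted.append((logical_name, mesh_axes))
  return adjusted

# e.g. during config parsing (hypothetical):
# if config.expert_shard_attention_option == EP_AS_CONTEXT:
#   config.logical_axis_rules = adjust_axis_rules_for_ep_as_context(config.logical_axis_rules)
```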
RissyRan left a comment:
Thank you! A few minor comments.
def setup_sharding(self, model_mode):
  self.logical_names = self._get_logical_names(model_mode)
  self.inputs_sharding = self._create_sharding(self.logical_names.inputs)
Why do these shardings use _create_sharding instead of maybe_shard_with_logical, which seems to return None in auto mode?
These are some sharding-related refactoring components. Let's pause them for now until we have a design doc and more discussion on it.
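(To anchor the question, a hypothetical sketch of the auto-vs-explicit behavior being discussed; the helper name and semantics are assumptions, not the PR's actual API.)

```python
from typing import Optional, Sequence
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def maybe_create_sharding(shard_mode: str, mesh: Mesh,
                          mesh_axes: Sequence[Optional[str]]) -> Optional[NamedSharding]:
  """Return a concrete NamedSharding in explicit mode; in auto mode return
  None so callers skip the annotation and leave layout to the compiler."""
  if shard_mode == "auto":
    return None
  return NamedSharding(mesh, P(*mesh_axes))
```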
This is a common optimization to simplify sharding by excluding redundant axes.
Copied from jax._src.core:
https://github.com/jax-ml/jax/blob/db65f4548d10493613ee4155fd463a17ab5680ddc/jax/_src/core.py#L2139-L2148
This link is not found? https://screenshot.googleplex.com/48Vou6KgByUuGkQ You could generate the permanent link instead.
Description
New replacement PR: #2783
In this PR we made the following changes:
- attention.py, attention_mla.py

Tests
Model_name: deepseek3-test
Topology: v5p-8
JAX==0.8.0
TL;DR:
Performance Impact Table

| Configuration | shard_mode=auto | shard_mode=explicit |
| --- | --- | --- |
| Full FSDP | xprof | xprof |
| FSDP + TP | xprof | xprof |
| FSDP + TP_transpose | xprof | xprof |
| FSDP + CP | xprof | xprof |
| FSDP + EP(fsdp) | xprof | xprof |
| FSDP + EP(context) | xprof | xprof |

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.