@NuojCheng (Collaborator) commented Dec 3, 2025

Description

This PR introduces explicit sharding support for the DeepSeek model in MaxText. This allows for more granular control over how tensors are distributed across devices, which may lead to better performance and scalability.

This PR is a simplified version of #2579.

Key Changes:

  • Enabled DeepSeek for Explicit Sharding: The deepseek decoder is now supported in ShardMode.EXPLICIT.
  • Introduced shard_mode: A shard_mode parameter has been added to attention, embedding, and MoE layers to control sharding behavior, enabling more flexible and explicit sharding configurations (see the sketch after this list).
  • Refactored Sharding Logic: The existing sharding logic has been lightly refactored to use the new sharding utilities and NamedSharding objects, making the sharding implementation more explicit and maintainable.
  • Updated Tests: The test suite has been updated to include tests for explicit sharding, ensuring the correctness and robustness of the implementation.
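For illustration, here is a rough sketch of how a shard_mode flag can gate an explicit output sharding. The standalone ShardMode enum and the helper name below are simplified stand-ins rather than the actual MaxText definitions; the pattern mirrors the rotated_sharding snippet quoted later in this thread.

```python
import enum

from flax import linen as nn
from jax.sharding import Mesh, NamedSharding


class ShardMode(enum.Enum):
  # Stand-in for MaxText's ShardMode enum; the real definition lives in MaxText.
  AUTO = "auto"
  EXPLICIT = "explicit"


def maybe_explicit_out_sharding(mesh: Mesh, shard_mode: ShardMode, logical_axes: tuple):
  """Return a NamedSharding for a layer output only in explicit mode.

  Returning None in auto mode leaves placement decisions to the compiler.
  """
  if shard_mode == ShardMode.EXPLICIT:
    # logical_to_mesh_axes resolves logical names (e.g. "activation_batch")
    # to mesh axes using the logical-axis rules currently in scope.
    return NamedSharding(mesh, nn.logical_to_mesh_axes(logical_axes))
  return None
```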

Detailed Changes by File:

  • src/MaxText/common_types.py: Added Q_LORA_UP_PROJ and KV_LORA_UP_PROJ to support LoRA projections with explicit sharding.
  • src/MaxText/configs/types.py: Enabled the deepseek decoder for explicit sharding.
  • src/MaxText/layers/attention_mla.py: Updated to use out_sharding and _maybe_shard_with_logical for more explicit control over sharding in MLA.
  • src/MaxText/layers/attentions.py: Added the shard_mode parameter to RotaryEmbedding and its variants. Refactored input and projection sharding to be more explicit.
  • src/MaxText/layers/deepseek.py: Integrated explicit sharding into the DeepSeekDecoderLayer by using _maybe_shard_with_logical and passing out_sharding and intermediate_sharding to sub-layers.
  • src/MaxText/layers/embeddings.py: Added the shard_mode parameter to all embedding classes to allow for explicit sharding configuration.
  • src/MaxText/layers/moe.py: Added shard_mode to GateLogit and updated the MoeBlock to handle explicit sharding for weights and activations.
  • src/MaxText/sharding.py: Introduced new utility functions for explicit sharding (see the sketch after this list).
  • tests/*: Updated various tests to include test cases for explicit sharding and to pass the new shard_mode and mesh parameters where necessary.
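As a rough illustration of the kind of utility this refactor is built around, the sketch below shows one possible shape for a "maybe shard with logical names" helper. It is an assumption about the API shape, not the actual MaxText implementation, and it reuses the illustrative ShardMode enum from the sketch above.

```python
import jax
from flax import linen as nn
from jax.sharding import NamedSharding


def maybe_shard_with_logical(x, logical_axis_names, mesh, shard_mode):
  """Illustrative helper: constrain x to the sharding implied by logical axis names.

  In explicit mode the constraint is a concrete NamedSharding on the given mesh;
  in auto mode it falls back to Flax's logical constraint, which the compiler
  may treat as a hint.
  """
  if shard_mode == ShardMode.EXPLICIT:
    spec = nn.logical_to_mesh_axes(logical_axis_names)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
  return nn.with_logical_constraint(x, logical_axis_names)
```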

Tests

Model_name: deepseek3-test
Topology: v5p-8
JAX==0.8.1
TL;DR:

  • Auto Sharding: Maintains performance parity with the main branch (deviations ≤ 0.3%).
  • Explicit Sharding: Introduces a small, expected overhead in the TP case but yields a large improvement for TP_transpose sharding.

Performance Impact Table

| Configuration | Main Peak Memory (MB) | Main Step Time (ms) | Auto Peak Memory (MB) | Auto Peak Mem Change (%) | Auto Step Time (ms) | Auto Step Time Change (%) | Explicit Peak Memory (MB) | Explicit Peak Mem Change (%) | Explicit Step Time (ms) | Explicit Step Time Change (%) |
|---|---|---|---|---|---|---|---|---|---|---|
| FSDP4 | 68748 | 2238 | 68748 | 0% | 2238 | 0% | 68961 | 0.31% | 2284 | 2.01% |
| FSDP2+TP2 | 69447 | 2009 | 69448 | 0% | 2014 | 0.25% | 69617 | 0.24% | 2378 | 18.37% |
| FSDP2+TPT2 | 93835 | 10742 | 93836 | 0% | 10743 | 0.01% | 69373 | -26.07% | 2714 | -74.68% |

Full FSDP

smoke_train model_name=deepseek3-test per_device_batch_size=1

FSDP + TP

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_parallelism=2

FSDP + TP_transpose

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_transpose_parallelism=2

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@RissyRan (Collaborator) left a comment:

Thanks for simplifying it! Are you planning to merge this one instead?

If you could directly apply a few minor comments from the old PR (just published), we should be good to go.

@NuojCheng (Collaborator, Author) commented Dec 4, 2025

Reply to @RissyRan:

> Are you planning to merge this one instead?

Yes, let's merge this one instead; the sharding refactoring parts have been removed here.

> If you could directly apply a few minor comments from the old PR (just published), we should be good to go.

Thank you for the comments! The concerns from the previous PR are all related to the sharding refactoring, which is not included in this PR.


@github-actions bot left a comment:

📋 Review Summary

This PR introduces explicit sharding for the DeepSeek model, which is a significant and valuable feature for performance and scalability. The changes are extensive and well-structured. I've identified a couple of minor logical issues and potential improvements in the implementation.

🔍 General Feedback

  • The introduction of shard_mode and the refactoring to use _maybe_shard_with_logical provide a clean way to handle explicit sharding.
  • The tests have been updated comprehensively, which is great to see.
  • Pay close attention to copy-paste errors, especially when dealing with similar variables like inputs_q and inputs_kv.

Overall, this is a solid contribution. Once the minor issues are addressed, this PR will be in great shape.

query = self._maybe_shard_with_logical(query, query_logical_name)
return query

def mla_get_key_value(self, low_rank_main, key_rope, model_mode):

🟡 The `mla_query_projection` function has a logic issue. The `return query` statement is inside the `if/else` block, making the following constraints unreachable. It seems like the return statement was misplaced during refactoring.

I've moved the return statement to the end of the function and removed the now-unnecessary final _maybe_shard_with_logical on query as it's already sharded correctly within the conditional blocks.

@NuojCheng (Collaborator, Author) replied:

I don't think the `return query` on line 544 is inside an if/else block; the suggestion is a bit misleading.

Comment on lines +998 to 999
input_axis_names = self.ep_input_axis_names
elif model_mode == MODEL_MODE_TRAIN:

🟡 There appears to be a copy-paste error here. `inputs_kv` is being sharded using `inputs_q`'s sharding information. It should use its own.

@NuojCheng (Collaborator, Author) replied:

Good catch! Updated.

Comment on lines +879 to +882
rotated_sharding = (
NamedSharding(self.mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", None, None)))
if self.shard_mode == ShardMode.EXPLICIT
else None

🟢 Using `jnp.einsum` for complex multiplication might be less clear and potentially less performant than direct element-wise multiplication with `*`. The previous implementation used `*`, which is more idiomatic for this operation.

@NuojCheng (Collaborator, Author) replied:

Good idea. I will use jnp.broadcast_to and jnp.multiply to replace the jnp.einsum.
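For reference, a minimal sketch of what a broadcast-and-multiply formulation can look like for a rotate-half style rotary embedding; the shapes and names are illustrative and not taken from the MaxText code.

```python
import jax.numpy as jnp


def rotate_half(x):
  # Split the last dimension in two, swap the halves, and negate the new first half.
  x1, x2 = jnp.split(x, 2, axis=-1)
  return jnp.concatenate([-x2, x1], axis=-1)


def apply_rotary_embedding(x, cos, sin):
  # x: [batch, seq, heads, head_dim]; cos/sin: [seq, head_dim].
  # Broadcast the tables to the activation shape explicitly, then combine with
  # element-wise multiplies instead of an einsum.
  cos = jnp.broadcast_to(cos[None, :, None, :], x.shape)
  sin = jnp.broadcast_to(sin[None, :, None, :], x.shape)
  return jnp.multiply(x, cos) + jnp.multiply(rotate_half(x), sin)
```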

@NuojCheng force-pushed the chengnuojin-explicit-deepseek-split branch from 9d4d6b4 to 8d2efc0 on December 5, 2025, 23:50
@NuojCheng force-pushed the chengnuojin-explicit-deepseek-split branch from 8d2efc0 to 3dc9988 on December 6, 2025, 01:05