
Conversation

@NuojCheng (Collaborator) commented Nov 1, 2025

Description

New replacement PR: #2783

In this PR we make the following changes:

  1. Add sharding constraints/reshards so that explicit sharding runs bug-free (see the sketch below);
  2. Add attention tests under explicit sharding, covering expert parallelism, context parallelism, and a mixture of the two;
  3. Refactor the initialization of logical names/sharding constraints in attention.py and attention_mla.py.
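
For reviewers less familiar with explicit sharding, here is a minimal sketch of the difference between an auto-mode constraint and an explicit-mode reshard. It assumes the explicit-sharding API of recent JAX releases (jax.make_mesh with AxisType.Explicit, jax.sharding.use_mesh, jax.sharding.reshard); the mesh, axis names, and shapes are illustrative and are not the ones MaxText uses.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType

# Illustrative 1-D mesh over all local devices; MaxText's meshes and
# logical-to-physical axis rules are more involved than this.
mesh = jax.make_mesh((jax.device_count(),), ("fsdp",), axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
  x = jnp.ones((8 * jax.device_count(), 16))
  # In explicit mode the sharding is part of the array's type, so we state the
  # layout with a reshard rather than leaving it to the partitioner, which is
  # what a with_sharding_constraint hint does in auto mode.
  x = jax.sharding.reshard(x, P("fsdp", None))
  print(x.sharding)  # batch dimension sharded over "fsdp"
```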

Tests

Model_name: deepseek3-test
Topology: v5p-8
JAX==0.8.0
TL;DR:

  • Auto Sharding: Maintains performance parity with the main branch (deviations ≤ 0.3%).
  • Explicit Sharding: While this introduces a slight expected overhead in standard cases, it provides significant improvements for TP_transpose sharding.

Performance Impact Table

| Configuration | Main Peak Memory (MB) | Main Step Time (ms) | Auto Peak Memory (MB) | Auto Peak Mem Change (%) | Auto Step Time (ms) | Auto Step Time Change (%) | Explicit Peak Memory (MB) | Explicit Peak Mem Change (%) | Explicit Step Time (ms) | Explicit Step Time Change (%) |
|---|---|---|---|---|---|---|---|---|---|---|
| FSDP4 | 68748 | 1822 | 68748 | 0% | 1822 | 0% | 68961 | 0.31% | 1870 | 2.63% |
| FSDP2+TP2 | 69447 | 1427 | 69445 | 0% | 1431 | 0.28% | 69616 | 0.24% | 1796 | 25.86% |
| FSDP2+TPT2 | 93835 | 10382 | 93836 | 0% | 10385 | 0.03% | 69372 | -26.07% | 2179 | -79.01% |
| FSDP2+CP2 | 69296 | 2340 | 69476 | 0.26% | 2340 | 0% | 69453 | 0.23% | 2395 | 2.35% |
| FSDP2+EP2 (fsdp) | 68653 | 1718 | 68656 | 0% | 1719 | 0.06% | 68846 | 0.28% | 1823 | 6.11% |
| FSDP2+EP2 (context) | 69240 | 2293 | 69240 | 0% | 2293 | 0% | 69263 | 0.03% | 2506 | 9.29% |

Full FSDP

smoke_train model_name=deepseek3-test per_device_batch_size=1

FSDP + TP

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_parallelism=2

FSDP + TP_transpose

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_transpose_parallelism=2

FSDP + CP

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_context_parallelism=2 dataset_type=synthetic

FSDP + EP(fsdp)

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_expert_parallelism=2  expert_shard_attention_option=fsdp

FSDP + EP(context)

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_expert_parallelism=2  expert_shard_attention_option=context dataset_type=synthetic

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng added the draft (Draft PR) label Nov 1, 2025
@NuojCheng force-pushed the chengnuojin-explicit-deepseek branch 8 times, most recently from 97cc3bd to 0b4a8b5 on November 6, 2025 19:30
@NuojCheng force-pushed the chengnuojin-explicit-deepseek branch 11 times, most recently from 7b57bd2 to c01b965 on November 18, 2025 02:41
@NuojCheng added the gemini-review label and removed the draft (Draft PR) label Nov 18, 2025
@NuojCheng marked this pull request as ready for review November 18, 2025 03:08
@github-actions bot commented:

🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment:

📋 Review Summary

This pull request introduces explicit sharding for the DeepSeek model, which is a significant and valuable enhancement. The changes are extensive, refactoring various layers (attention, moe, llama2, etc.) to support explicit sharding, which improves code structure and maintainability by centralizing sharding logic. The addition of comprehensive tests for different parallelism configurations with explicit sharding is also a great inclusion.

🔍 General Feedback

  • The refactoring to use dataclasses for logical axis names is a good pattern that makes the sharding logic cleaner and easier to follow across different layers.
  • I've identified one critical copy-paste error in the MoE layer and a minor inconsistency in the test setup. Once these are addressed, this PR will be in excellent shape.

@NuojCheng NuojCheng force-pushed the chengnuojin-explicit-deepseek branch 2 times, most recently from 5e86195 to 81b117a Compare November 18, 2025 05:38
@RissyRan (Collaborator) left a comment:

Thanks @NuojCheng for the change! Let's try to merge after the 11/24 timeline? Merging earlier may potentially introduce chaos for the current perf numbers. We made a similar decision for the NNX migration.

Review is not blocked for sure.

cc @suexu1025 @parmitam

@NuojCheng force-pushed the chengnuojin-explicit-deepseek branch from 1a711c4 to 679a5ef on November 25, 2025 02:37
@RissyRan (Collaborator) left a comment:

Thanks Nuojin for the change!

This is a huge PR, which I haven't finished reading yet. Do you think it's possible to:

  1. split this PR into one for onboarding explicit sharding and one for the refactor?
  2. I see you have tested auto mode and explicit mode, which is great! Have you tested before/after this change for auto mode? We don't need to test all shardings like you have done already, just ensure there is no perf regression for auto mode for DS v3 with basic FSDP, TP, and EP sharding.


contract_ind = tuple(range(0, len(norm_axis)))

# [B, S, E] -> [B, S, num_exp]
Collaborator:

Why do we need to shard the embedding dimension to be expert?

@NuojCheng (author):

The output_sharding here specifies the logical axes of the matmul output, which is supposed to be B x S x NumExpert.

Collaborator:

I would have to think more but I think we want to leave the NumExpert dimension unsharded. The batch dimension should still be sharded by EP here (batch is sharded by EP until after the a2a)

Collaborator:

Yes, this is inside of Gating. So NumExpert starts to be sharded after the 1st A2A and before the 2nd A2A.
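
To make the thread above concrete, here is a minimal sketch of the gating matmul being discussed, with the batch dimension sharded on the expert axis and the NumExpert output dimension left unsharded. It assumes jnp.einsum's out_sharding argument from JAX's explicit-sharding API; the mesh, names, and shapes are illustrative, not MaxText's actual gating code.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType

# Illustrative expert-parallel mesh over all local devices.
mesh = jax.make_mesh((jax.device_count(),), ("expert",), axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
  batch, seq, emb, num_exp = 2 * jax.device_count(), 8, 16, 8
  # Activations enter gating with batch sharded by EP (until after the a2a).
  x = jax.sharding.reshard(jnp.ones((batch, seq, emb)), P("expert", None, None))
  w_gate = jnp.ones((emb, num_exp))  # gating weights, replicated

  # [B, S, E] x [E, num_exp] -> [B, S, num_exp]: keep batch sharded on
  # "expert", leave the num_exp dimension unsharded.
  logits = jnp.einsum("bse,en->bsn", x, w_gate, out_sharding=P("expert", None, None))
  print(logits.sharding)  # batch dimension sharded over "expert"
```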

if self.config.model_name.startswith("deepseek3"):
  pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
else:
  # pre_bias_logits is None for non-DeepSeek v3 models
Collaborator:

Nit: let's remove this comment, as the if/else already indicates this logic, right?

@NuojCheng (author):

sg

Collaborator:

Still there?

# mlp_no_fsdp axis
if self.config.fsdp_shard_on_exp:
  # special sharding for dsv3
  wi_kernel_axes = ("embed_no_exp", None, "mlp")
Collaborator:

      wi_kernel_axes = ("embed_no_exp", None, "mlp")
      wo_kernel_axes = ("embed_no_exp", "mlp", None)

Can be moved to quantization block?

@NuojCheng (author):

I think wi_kernel_axes and wi_kernel_axes_sp are used for different purposes. The latter one is only used for sparse_matmul. For non-quantized + dense_matmul we still use wi_kernel_axes.

Collaborator:

OK I was a little bit confused. If so, could we rename them to wi_kernel_axes_sparse and wi_kernel_axes_dense for readability when creating them in one place?
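
A short sketch of the renaming being proposed is below; the tuples and the sparse layout are placeholders rather than the actual MaxText axes.

```python
# Hypothetical illustration of the proposed naming: define both axis tuples in
# one place and select by matmul style, instead of an "_sp" suffix.
wi_kernel_axes_dense = ("embed_no_exp", None, "mlp")
wi_kernel_axes_sparse = ("exp", "embed_no_exp", "mlp")  # assumed sparse_matmul layout


def pick_wi_kernel_axes(sparse_matmul: bool):
  """Return the kernel logical axes for the chosen matmul implementation."""
  return wi_kernel_axes_sparse if sparse_matmul else wi_kernel_axes_dense
```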

@NuojCheng (author) commented Dec 2, 2025

Thank you Ran for the reviews! @RissyRan

  1. split this PR into one for onboarding explicit sharding and one for the refactor?

I can do that.

  2. I see you have tested auto mode and explicit mode, which is great! Have you tested before/after this change for auto mode? We don't need to test all shardings like you have done already, just ensure there is no perf regression for auto mode for DS v3 with basic FSDP, TP, and EP sharding.

The performance test results between main and shard_mode=auto are included in the description table, under the columns Auto Peak Mem Change (%) and Auto Step Time Change (%).



@dataclasses.dataclass(frozen=True)
class AttentionMLALogicalNames:
Collaborator:

temperature_tuning: bool = False,
temperature_tuning_scale: float = 0.1,
temperature_tuning_floor_scale: float = 8192.0,
# Shard the query activation as the same as the key and value.
Collaborator:

<3

decoder_positions = reordered_batch["inputs_position"]
# apply attention with sharding
with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
  if cfg_cp.expert_shard_attention_option == EP_AS_CONTEXT:
Collaborator:

Oof, we generally prefer to avoid logic in tests, but I think I understand why this is here. Note this is a case that hints we should be modifying the logical axis rules at config parsing time instead of at runtime.
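
As a rough illustration of that suggestion (the function, rule names, and rule values below are hypothetical, not MaxText's actual config code), the branch could be resolved once while the config is parsed so that tests only consume the finished rules:

```python
# Hypothetical: fold the expert_shard_attention_option branch into config
# parsing, so callers and tests never branch on it at runtime.
def finalize_logical_axis_rules(raw_rules, expert_shard_attention_option):
  rules = dict(raw_rules)
  if expert_shard_attention_option == "context":
    # Assumed tweak: also shard the sequence-like logical axis over "expert".
    rules["activation_length"] = ("context", "expert")
  return tuple(rules.items())


cfg_rules = finalize_logical_axis_rules(
    {"activation_length": ("context",)}, expert_shard_attention_option="context")
```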


@RissyRan (Collaborator) left a comment:

Thank you! A few minor comments.


def setup_sharding(self, model_mode):
  self.logical_names = self._get_logical_names(model_mode)
  self.inputs_sharding = self._create_sharding(self.logical_names.inputs)
Collaborator:

Why do those shardings use _create_sharding instead of maybe_shard_with_logical, which seems to return None in auto mode?

@NuojCheng (author):

These are some sharding-related refactoring components. Let's pause them for now until we have a design doc and more discussion on them.
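
For context, a minimal sketch of the kind of mode-dependent helper this thread is about; the name, signature, and behavior here are assumptions for illustration, not MaxText's actual maybe_shard_with_logical.

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec as P


def maybe_shard_with_logical_sketch(x, logical_axes, mesh, rules, shard_mode):
  """Hypothetical helper: reshard in explicit mode, constrain in auto mode.

  `rules` is a simplified dict from logical axis name to mesh axis (or None).
  Assumes jax.sharding.use_mesh(mesh) is active when shard_mode == "explicit".
  """
  spec = P(*(rules.get(name) for name in logical_axes))
  if shard_mode == "explicit":
    # Explicit mode: the sharding becomes part of the array's type.
    return jax.sharding.reshard(x, spec)
  # Auto mode: a constraint the partitioner is free to propagate around.
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
```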

This is a common optimization to simplify sharding by excluding redundant axes.
Copied from jax._src.core:
https://github.com/jax-ml/jax/blob/db65f4548d10493613ee4155fd463a17ab5680ddc/jax/_src/core.py#L2139-L2148
Collaborator:

This link is not found? https://screenshot.googleplex.com/48Vou6KgByUuGkQ You could generate a permanent link instead.

@NuojCheng (author):

Thank you for catching it. This component is supposed to be in another PR, #2737. I removed this part in the new PR #2783.


@NuojCheng closed this Dec 4, 2025