Commit 5e86195

support deepseek explicit shard mode
1 parent c0abc4c · commit 5e86195

16 files changed, +698 −383 lines changed

src/MaxText/common_types.py

Lines changed: 2 additions & 0 deletions
@@ -37,7 +37,9 @@
 PREFILL_LENGTH = "prefill_activation_length"
 Q_LENGTH = "activation_q_length"
 Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
+Q_LORA_UP_PROJ = "q_lora_up_proj"
 KV_LENGTH = "activation_kv_length"
+KV_LORA_UP_PROJ = "kv_lora_up_proj"
 EMBED = "activation_embed"
 HEAD = "activation_heads"
 PREFILL_KV_BATCH = "activation_prefill_kv_batch"
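
The two new names follow this file's existing pattern of logical axis labels (presumably for the MLA q/kv LoRA up-projection dimensions). As a reminder of how such labels become device placements, here is a minimal sketch using Flax's logical-axis rules; the "tensor" mesh axis and the rule table are made-up examples, not MaxText's defaults:

```python
import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

# A 1-D mesh over all local devices with a single, hypothetical "tensor" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("tensor",))

# Hypothetical rules: shard the LoRA up-projection dim, replicate the embed dim.
rules = (("q_lora_up_proj", "tensor"), ("activation_embed", None))

with nn.logical_axis_rules(rules):
  spec = nn.logical_to_mesh_axes(("activation_embed", "q_lora_up_proj"))
  print(NamedSharding(mesh, spec))  # spec resolves to PartitionSpec(None, 'tensor')
```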

src/MaxText/configs/types.py

Lines changed: 1 addition & 1 deletion
@@ -1848,7 +1848,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.packing:
       raise ValueError("For multimodal SFT, `packing` is not yet supported.")
     if self.shard_mode == ShardMode.EXPLICIT:
-      supported_decoders = {"simple", "simple_mlp", "llama2"}
+      supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"}
       if self.decoder_block.value not in supported_decoders:
         raise ValueError(
             f"Decoder '{self.decoder_block.value}' is not supported with 'explicit' sharding. "

src/MaxText/gradient_accumulation.py

Lines changed: 1 addition & 0 deletions
@@ -92,6 +92,7 @@ def convert_to_bf16(param):
   def accumulate_gradient(acc_grad_and_loss, data):
     ga_params = acc_grad_and_loss["ga_params"]
     (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True)
+    cur_batch_gradient = jax.tree.map(_maybe_shard_with_name, cur_batch_gradient, grad_shardings)
     acc_grad_and_loss["loss"] += aux["total_loss"]
     acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
     acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"]

src/MaxText/layers/attention_mla.py

Lines changed: 111 additions & 112 deletions
Large diffs are not rendered by default.

src/MaxText/layers/attentions.py

Lines changed: 91 additions & 123 deletions
Large diffs are not rendered by default.

src/MaxText/layers/deepseek.py

Lines changed: 37 additions & 12 deletions
@@ -16,8 +16,10 @@
 # pylint: disable=arguments-differ
 # pylint: disable=no-name-in-module

+from functools import partial
+
 from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding
 import jax.numpy as jnp

 from flax import linen as nn
@@ -31,6 +33,7 @@
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.inference import page_manager
+from MaxText.sharding import maybe_shard_with_logical
 from MaxText.common_types import MODEL_MODE_PREFILL

 # -----------------------------------------
@@ -66,8 +69,14 @@ def self_attention_with_norm(
     logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
   else:
     logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+  _maybe_shard_with_logical = partial(
+      maybe_shard_with_logical,
+      mesh=mesh,
+      shard_mode=cfg.shard_mode,
+  )
+  lnx_out_sharding = NamedSharding(mesh, nn.logical_to_mesh_axes(logical_axis_names))

-  lnx = nn.with_logical_constraint(lnx, logical_axis_names)
+  lnx = _maybe_shard_with_logical(lnx, logical_axis_names)

   attention_layer = attention_mla.mla_as_linen(
       config=cfg,
@@ -106,12 +115,13 @@ def self_attention_with_norm(
       decoder_segment_ids=decoder_segment_ids,
       deterministic=deterministic,
       model_mode=model_mode,
+      out_sharding=lnx_out_sharding,
       previous_chunk=previous_chunk,
       page_state=page_state,
       slot=slot,
   )

-  attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names)
+  attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names)
   intermediate_inputs = inputs + attention_lnx

   # Normalization
@@ -123,7 +133,7 @@ def self_attention_with_norm(
       kernel_axes=("norm",),
       epsilon=cfg.normalization_layer_epsilon,
   )(intermediate_inputs)
-  hidden_states = nn.with_logical_constraint(hidden_states, logical_axis_names)
+  hidden_states = _maybe_shard_with_logical(hidden_states, logical_axis_names)
   return hidden_states, intermediate_inputs


@@ -167,9 +177,14 @@ def __call__(
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
+      mlp_logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_mlp")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
-    inputs = nn.with_logical_constraint(inputs, logical_axis_names)
+      mlp_logical_axis_names = ("activation_batch", "activation_norm_length", "activation_mlp")
+    _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode)
+    lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names))
+    mlp_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names))
+    inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")

     hidden_states, intermediate_inputs = self_attention_with_norm(
@@ -196,12 +211,17 @@ def __call__(
         config=cfg,
         mesh=self.mesh,
         quant=self.quant,
-    )(hidden_states, deterministic=deterministic)
-    mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
+    )(
+        hidden_states,
+        deterministic=deterministic,
+        intermediate_sharding=mlp_intermediate_sharding,
+        out_sharding=lnx_out_sharding,
+    )
+    mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names)

     layer_output = mlp_lnx + intermediate_inputs
     layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
-    layer_output = nn.with_logical_constraint(
+    layer_output = _maybe_shard_with_logical(
         layer_output,
         logical_axis_names,
     )
@@ -234,9 +254,14 @@ def __call__(
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
+      mlp_logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_mlp")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
-    inputs = nn.with_logical_constraint(inputs, logical_axis_names)
+      mlp_logical_axis_names = ("activation_batch", "activation_norm_length", "activation_mlp")
+    _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode)
+    lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names))
+    lnx_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names))
+    inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")

     hidden_states, intermediate_inputs = self_attention_with_norm(
@@ -265,12 +290,12 @@ def __call__(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         quant=self.quant,
-    )(hidden_states)
-    mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
+    )(hidden_states, intermediate_sharding=lnx_intermediate_sharding, out_sharding=lnx_out_sharding)
+    mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names)

     layer_output = mlp_lnx + intermediate_inputs
     layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
-    layer_output = nn.with_logical_constraint(
+    layer_output = _maybe_shard_with_logical(
         layer_output,
         logical_axis_names,
     )
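
Throughout this file the Flax `nn.with_logical_constraint` calls are replaced by `maybe_shard_with_logical` (imported from MaxText.sharding above), partially applied with the mesh and the configured shard mode. That helper is not part of this commit, so the sketch below is only a guess at its shape: fall back to the usual logical constraint in auto mode, reshard to a concrete NamedSharding in explicit mode. ShardMode is stubbed here, and `jax.sharding.reshard` assumes a recent JAX release:

```python
import enum
import jax
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

class ShardMode(enum.Enum):  # stand-in for MaxText's ShardMode enum
  AUTO = "auto"
  EXPLICIT = "explicit"

def maybe_shard_with_logical(x, logical_axis_names, *, mesh: Mesh, shard_mode: ShardMode):
  """Constrain `x` to the given logical axes, honouring the shard mode."""
  if shard_mode == ShardMode.EXPLICIT:
    # Explicit mode: resolve logical names to a mesh PartitionSpec and reshard
    # the (explicitly typed) array to that concrete layout.
    spec = nn.logical_to_mesh_axes(logical_axis_names)
    return jax.sharding.reshard(x, NamedSharding(mesh, spec))
  # Auto mode: keep the familiar hint for the auto-sharding propagation pass.
  return nn.with_logical_constraint(x, logical_axis_names)
```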

src/MaxText/layers/embeddings.py

Lines changed: 35 additions & 2 deletions
@@ -263,9 +263,11 @@ def __init__(
       self,
       min_timescale: int,
       max_timescale: int,
+      mesh: Mesh,
       embedding_dims: int = 0,
       cast_as_fprop_dtype: bool = True,
       fprop_dtype: DType = jnp.bfloat16,
+      shard_mode: ShardMode = ShardMode.AUTO,
       # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
       # TODO: Remove when bridge no longer needed
       rope_linear_scaling_factor: float = 1.0,
@@ -285,9 +287,11 @@ def __init__(
     """
     self.min_timescale = min_timescale
     self.max_timescale = max_timescale
+    self.mesh = mesh
     self.embedding_dims = embedding_dims
     self.cast_as_fprop_dtype = cast_as_fprop_dtype
     self.fprop_dtype = fprop_dtype
+    self.shard_mode = shard_mode
     self.rope_linear_scaling_factor = rope_linear_scaling_factor

     if self.embedding_dims % 2:
@@ -384,6 +388,7 @@ def qwen3_next_rotary_embedding_as_linen(
     *,
     min_timescale: int,
     max_timescale: int,
+    mesh: Mesh,
     embedding_dims: int = 0,
     partial_rotary_factor: float = 0.25,
     cast_as_fprop_dtype: bool = True,
@@ -407,6 +412,7 @@ def qwen3_next_rotary_embedding_as_linen(
       Qwen3NextRotaryEmbedding,
       min_timescale=min_timescale,
       max_timescale=max_timescale,
+      mesh=mesh,
       embedding_dims=embedding_dims,
       partial_rotary_factor=partial_rotary_factor,
       cast_as_fprop_dtype=cast_as_fprop_dtype,
@@ -423,6 +429,7 @@ def __init__(
       self,
       min_timescale: int,
       max_timescale: int,
+      mesh: Mesh,
       embedding_dims: int = 0,
       cast_as_fprop_dtype: bool = True,
       fprop_dtype: DType = jnp.bfloat16,
@@ -447,6 +454,7 @@ def __init__(
     super().__init__(
         min_timescale=min_timescale,
         max_timescale=max_timescale,
+        mesh=mesh,
        embedding_dims=self.rotary_dim,
        cast_as_fprop_dtype=cast_as_fprop_dtype,
        fprop_dtype=fprop_dtype,
@@ -478,10 +486,12 @@ def __init__(
       self,
       min_timescale: int,
       max_timescale: int,
+      mesh: Mesh,
       embedding_dims: int = 0,
       cast_as_fprop_dtype: bool = True,
       fprop_dtype: DType = jnp.bfloat16,
       use_scale: bool = True,
+      shard_mode: ShardMode = ShardMode.AUTO,
       # Not used in LLaMARotaryEmbedding but passed in by nnx.bridge.to_linen.
       # TODO: Remove when bridge no longer needed
       rngs: nnx.Rngs = None,
@@ -505,6 +515,8 @@ def __init__(
        embedding_dims=embedding_dims,
        cast_as_fprop_dtype=cast_as_fprop_dtype,
        fprop_dtype=fprop_dtype,
+        mesh=mesh,
+        shard_mode=shard_mode,
        rngs=rngs,
     )

@@ -613,6 +625,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.
 def yarn_rotary_embedding_as_linen(
     *,
     embedding_dims: int,
+    mesh: Mesh,
     max_position_embeddings: int = 4096 * 4,
     original_max_position_embeddings: int = 4096,
     beta_fast: float = 32,
@@ -625,6 +638,7 @@ def yarn_rotary_embedding_as_linen(
     interleave: bool = True,
     truncate: bool = True,
     attention_scaling: bool = False,
+    shard_mode: ShardMode = ShardMode.AUTO,
 ):
   """Initializes the YarnRotaryEmbedding module and returns it as a Linen module.

@@ -643,6 +657,7 @@ def yarn_rotary_embedding_as_linen(
   return nnx_wrappers.to_linen(
       YarnRotaryEmbedding,
       embedding_dims=embedding_dims,
+      mesh=mesh,
       max_position_embeddings=max_position_embeddings,
       original_max_position_embeddings=original_max_position_embeddings,
       beta_fast=beta_fast,
@@ -656,6 +671,7 @@ def yarn_rotary_embedding_as_linen(
       interleave=interleave,
       truncate=truncate,
       attention_scaling=attention_scaling,
+      shard_mode=shard_mode,
   )


@@ -685,6 +701,7 @@ class YarnRotaryEmbedding(nnx.Module):
   def __init__(
       self,
       embedding_dims: int,
+      mesh: Mesh,
       max_position_embeddings: int = 4096 * 4,
       original_max_position_embeddings: int = 4096,
       beta_fast: float = 32,
@@ -693,6 +710,7 @@ def __init__(
       rope_factor: float = 40,
       cast_as_fprop_dtype: bool = True,
       fprop_dtype: DType = jnp.bfloat16,
+      shard_mode: ShardMode = ShardMode.AUTO,
       interleave=True,
       truncate=True,
       attention_scaling=False,
@@ -712,6 +730,13 @@ def __init__(
     self.fprop_dtype = fprop_dtype
     self.interleave = interleave
     self.truncate = truncate
+    self.mesh = mesh
+    self.shard_mode = shard_mode
+    self.freqs_sharding = (
+        NamedSharding(mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", "q_heads")))
+        if shard_mode == ShardMode.EXPLICIT
+        else None
+    )
     self.attention_scaling = attention_scaling

     if self.embedding_dims % 2:
@@ -811,7 +836,8 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     # Lookup the precomputed frequencies using the position indices.
     # self.freqs_cis has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
     # After indexing, shape becomes [B, S, half_dim]; we then add an axis for the heads.
-    freqs = jnp.take(self.freqs_cis, position, axis=0)  # shape: [B, S, half_dim]
+    # freqs = jnp.take(self.freqs_cis, position, axis=0)  # shape: [B, S, half_dim]
+    freqs = self.freqs_cis.at[position].get(out_sharding=self.freqs_sharding)
     freqs = freqs[:, :, jnp.newaxis, :]  # shape: [B, S, 1, half_dim]

     if self.interleave:
@@ -828,7 +854,14 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:

     inputs_complex = first_half + 1j * second_half  # shape: [B, S, N, half_dim]
     # Apply the rotary transformation via complex multiplication.
-    rotated = inputs_complex * freqs  # shape: [B, S, N, half_dim]
+    rotated_sharding = (
+        NamedSharding(self.mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", None, None)))
+        if self.shard_mode == ShardMode.EXPLICIT
+        else None
+    )
+    rotated = jnp.einsum(
+        "ijkl, ijml->ijkl", inputs_complex, freqs, out_sharding=rotated_sharding
+    )  # shape: [B, S, N, half_dim]
     # Convert the complex result back to a real tensor.
     # Split the complex number into its real and imaginary parts.
     # [real1, real2, ..., img1, img2, ...]
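
Both rewrites in YarnRotaryEmbedding route the computation through ops that take an `out_sharding` argument (a gather via `.at[...].get`, and the broadcasting complex multiply re-expressed as an einsum), so the result's layout is stated rather than inferred when the mesh axes are explicit. A toy, self-contained demonstration of the two hooks, assuming a recent JAX with explicit-sharding support and a local device count that divides the batch dimension; the shapes and axis name are made up:

```python
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("data",),
                     axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
  table = jnp.ones((32, 8))  # e.g. a precomputed frequency table
  position = jax.device_put(jnp.zeros((8, 16), dtype=jnp.int32),
                            NamedSharding(mesh, P("data", None)))

  # Gather with an explicit output layout: [batch, seq, 8], batch-sharded.
  freqs = table.at[position].get(
      out_sharding=NamedSharding(mesh, P("data", None, None)))

  # Broadcasted multiply written as einsum so its output layout can be stated.
  x = jax.device_put(jnp.ones((8, 16, 2, 8)),  # [batch, seq, heads, dim]
                     NamedSharding(mesh, P("data", None, None, None)))
  rotated = jnp.einsum("bsnd,bsmd->bsnd", x, freqs[:, :, None, :],
                       out_sharding=NamedSharding(mesh, P("data", None, None, None)))
  print(rotated.shape)  # (8, 16, 2, 8)
```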
