
Commit adf511a

support deepseek explicit shard mode

1 parent: 4c454f5

File tree

15 files changed: +701 −392 lines


src/MaxText/common_types.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -37,7 +37,9 @@
 PREFILL_LENGTH = "prefill_activation_length"
 Q_LENGTH = "activation_q_length"
 Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
+Q_LORA_UP_PROJ = "q_lora_up_proj"
 KV_LENGTH = "activation_kv_length"
+KV_LORA_UP_PROJ = "kv_lora_up_proj"
 EMBED = "activation_embed"
 HEAD = "activation_heads"
 PREFILL_KV_BATCH = "activation_prefill_kv_batch"
```
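These constants are logical axis names; they only affect sharding once a rule maps them onto mesh axes. A minimal sketch of that resolution with Flax (the rule set below is illustrative, not MaxText's actual configuration):

```python
# Sketch: resolving the new logical axis names to a PartitionSpec.
# The rule set here is an assumption for illustration, not MaxText's configuration.
from flax import linen as nn

rules = (
    ("activation_batch", "data"),
    ("q_lora_up_proj", "tensor"),   # hypothetical rule for the new axis name
    ("kv_lora_up_proj", "tensor"),  # hypothetical rule for the new axis name
)
spec = nn.logical_to_mesh_axes(("activation_batch", "q_lora_up_proj"), rules)
print(spec)  # PartitionSpec('data', 'tensor')
```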

src/MaxText/configs/types.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1889,7 +1889,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.packing:
       raise ValueError("For multimodal SFT, `packing` is not yet supported.")
     if self.shard_mode == ShardMode.EXPLICIT:
-      supported_decoders = {"simple", "simple_mlp", "llama2"}
+      supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"}
       if self.decoder_block.value not in supported_decoders:
         raise ValueError(
             f"Decoder '{self.decoder_block.value}' is not supported with 'explicit' sharding. "
```

src/MaxText/layers/attention_mla.py

Lines changed: 111 additions & 112 deletions
Large diffs are not rendered by default.

src/MaxText/layers/attentions.py

Lines changed: 91 additions & 123 deletions
Large diffs are not rendered by default.

src/MaxText/layers/deepseek.py

Lines changed: 37 additions & 12 deletions
```diff
@@ -16,8 +16,10 @@
 # pylint: disable=arguments-differ
 # pylint: disable=no-name-in-module
 
+from functools import partial
+
 from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding
 import jax.numpy as jnp
 
 from flax import linen as nn
```
```diff
@@ -31,6 +33,7 @@
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.inference import page_manager
+from MaxText.sharding import maybe_shard_with_logical
 from MaxText.common_types import MODEL_MODE_PREFILL
 
 # -----------------------------------------
```
```diff
@@ -66,8 +69,14 @@ def self_attention_with_norm(
     logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
   else:
     logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+  _maybe_shard_with_logical = partial(
+      maybe_shard_with_logical,
+      mesh=mesh,
+      shard_mode=cfg.shard_mode,
+  )
+  lnx_out_sharding = NamedSharding(mesh, nn.logical_to_mesh_axes(logical_axis_names))
 
-  lnx = nn.with_logical_constraint(lnx, logical_axis_names)
+  lnx = _maybe_shard_with_logical(lnx, logical_axis_names)
 
   attention_layer = attention_mla.mla_as_linen(
       config=cfg,
```
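The helper `maybe_shard_with_logical` imported from `MaxText.sharding` is not shown on this page. A minimal sketch of what such a dispatcher plausibly does, assuming it falls back to Flax logical constraints in auto mode and applies an explicit sharding constraint otherwise:

```python
# Sketch only: MaxText.sharding.maybe_shard_with_logical is not part of this diff.
# This is an assumed implementation of the dispatch idea, with shard_mode
# simplified to a string; the real helper may use a different primitive
# (e.g. jax.sharding.reshard) under explicit mode.
import jax
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding


def maybe_shard_with_logical(x: jax.Array, logical_axis_names, *, mesh: Mesh, shard_mode: str):
  """Constrains x via Flax logical axes (auto) or an explicit NamedSharding (explicit)."""
  if shard_mode == "explicit":
    spec = nn.logical_to_mesh_axes(logical_axis_names)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
  return nn.with_logical_constraint(x, logical_axis_names)
```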
```diff
@@ -106,12 +115,13 @@ def self_attention_with_norm(
       decoder_segment_ids=decoder_segment_ids,
       deterministic=deterministic,
       model_mode=model_mode,
+      out_sharding=lnx_out_sharding,
       previous_chunk=previous_chunk,
       page_state=page_state,
       slot=slot,
   )
 
-  attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names)
+  attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names)
   intermediate_inputs = inputs + attention_lnx
 
   # Normalization
```
```diff
@@ -123,7 +133,7 @@ def self_attention_with_norm(
       kernel_axes=("norm",),
       epsilon=cfg.normalization_layer_epsilon,
   )(intermediate_inputs)
-  hidden_states = nn.with_logical_constraint(hidden_states, logical_axis_names)
+  hidden_states = _maybe_shard_with_logical(hidden_states, logical_axis_names)
   return hidden_states, intermediate_inputs
```
```diff
@@ -169,9 +179,14 @@ def __call__(
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
+      mlp_logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_mlp")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
-    inputs = nn.with_logical_constraint(inputs, logical_axis_names)
+      mlp_logical_axis_names = ("activation_batch", "activation_norm_length", "activation_mlp")
+    _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode)
+    lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names))
+    mlp_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names))
+    inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
     hidden_states, intermediate_inputs = self_attention_with_norm(
```
```diff
@@ -198,12 +213,17 @@ def __call__(
         config=cfg,
         mesh=self.mesh,
         quant=self.quant,
-    )(hidden_states, deterministic=deterministic)
-    mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
+    )(
+        hidden_states,
+        deterministic=deterministic,
+        intermediate_sharding=mlp_intermediate_sharding,
+        out_sharding=lnx_out_sharding,
+    )
+    mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names)
 
     layer_output = mlp_lnx + intermediate_inputs
     layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
-    layer_output = nn.with_logical_constraint(
+    layer_output = _maybe_shard_with_logical(
         layer_output,
         logical_axis_names,
     )
```
```diff
@@ -238,9 +258,14 @@ def __call__(
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
+      mlp_logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_mlp")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
-    inputs = nn.with_logical_constraint(inputs, logical_axis_names)
+      mlp_logical_axis_names = ("activation_batch", "activation_norm_length", "activation_mlp")
+    _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode)
+    lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names))
+    lnx_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names))
+    inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
     hidden_states, intermediate_inputs = self_attention_with_norm(
```
```diff
@@ -269,12 +294,12 @@ def __call__(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         quant=self.quant,
-    )(hidden_states)
-    mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
+    )(hidden_states, intermediate_sharding=lnx_intermediate_sharding, out_sharding=lnx_out_sharding)
+    mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names)
 
     layer_output = mlp_lnx + intermediate_inputs
     layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
-    layer_output = nn.with_logical_constraint(
+    layer_output = _maybe_shard_with_logical(
         layer_output,
         logical_axis_names,
     )
```
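The recurring pattern in this file is to translate the same logical axis names used for `nn.with_logical_constraint` into a concrete `NamedSharding` that can be handed to an op as `out_sharding`. A minimal standalone sketch of that pattern (the mesh shape and logical-axis rules below are illustrative, not MaxText's defaults):

```python
# Sketch of building an out_sharding from logical axis names, mirroring
# NamedSharding(mesh, nn.logical_to_mesh_axes(...)) in the layer code above.
# The mesh layout and rules are assumptions for illustration.
import numpy as np
import jax
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "tensor"))
rules = (
    ("activation_batch", "data"),
    ("activation_embed", "tensor"),
)

with nn.logical_axis_rules(rules):
  spec = nn.logical_to_mesh_axes(("activation_batch", "activation_norm_length", "activation_embed"))
  lnx_out_sharding = NamedSharding(mesh, spec)

print(lnx_out_sharding.spec)  # PartitionSpec('data', None, 'tensor')
```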

src/MaxText/layers/embeddings.py

Lines changed: 35 additions & 2 deletions
```diff
@@ -275,9 +275,11 @@ def __init__(
     self,
     min_timescale: int,
     max_timescale: int,
+    mesh: Mesh,
     embedding_dims: int = 0,
     cast_as_fprop_dtype: bool = True,
     fprop_dtype: DType = jnp.bfloat16,
+    shard_mode: ShardMode = ShardMode.AUTO,
     # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
     # TODO: Remove when bridge no longer needed
     rope_linear_scaling_factor: float = 1.0,
@@ -297,9 +299,11 @@ def __init__(
     """
     self.min_timescale = min_timescale
     self.max_timescale = max_timescale
+    self.mesh = mesh
     self.embedding_dims = embedding_dims
     self.cast_as_fprop_dtype = cast_as_fprop_dtype
     self.fprop_dtype = fprop_dtype
+    self.shard_mode = shard_mode
     self.rope_linear_scaling_factor = rope_linear_scaling_factor
 
     if self.embedding_dims % 2:
```
```diff
@@ -396,6 +400,7 @@ def qwen3_next_rotary_embedding_as_linen(
     *,
     min_timescale: int,
     max_timescale: int,
+    mesh: Mesh,
     embedding_dims: int = 0,
     partial_rotary_factor: float = 0.25,
     cast_as_fprop_dtype: bool = True,
@@ -419,6 +424,7 @@ def qwen3_next_rotary_embedding_as_linen(
       Qwen3NextRotaryEmbedding,
       min_timescale=min_timescale,
       max_timescale=max_timescale,
+      mesh=mesh,
       embedding_dims=embedding_dims,
       partial_rotary_factor=partial_rotary_factor,
       cast_as_fprop_dtype=cast_as_fprop_dtype,
@@ -435,6 +441,7 @@ def __init__(
     self,
     min_timescale: int,
     max_timescale: int,
+    mesh: Mesh,
     embedding_dims: int = 0,
     cast_as_fprop_dtype: bool = True,
     fprop_dtype: DType = jnp.bfloat16,
@@ -459,6 +466,7 @@ def __init__(
     super().__init__(
         min_timescale=min_timescale,
         max_timescale=max_timescale,
+        mesh=mesh,
         embedding_dims=self.rotary_dim,
         cast_as_fprop_dtype=cast_as_fprop_dtype,
         fprop_dtype=fprop_dtype,
```
```diff
@@ -490,10 +498,12 @@ def __init__(
     self,
     min_timescale: int,
     max_timescale: int,
+    mesh: Mesh,
     embedding_dims: int = 0,
     cast_as_fprop_dtype: bool = True,
     fprop_dtype: DType = jnp.bfloat16,
     use_scale: bool = True,
+    shard_mode: ShardMode = ShardMode.AUTO,
     # Not used in LLaMARotaryEmbedding but passed in by nnx.bridge.to_linen.
     # TODO: Remove when bridge no longer needed
     rngs: nnx.Rngs = None,
@@ -517,6 +527,8 @@ def __init__(
         embedding_dims=embedding_dims,
         cast_as_fprop_dtype=cast_as_fprop_dtype,
         fprop_dtype=fprop_dtype,
+        mesh=mesh,
+        shard_mode=shard_mode,
         rngs=rngs,
     )
```
```diff
@@ -625,6 +637,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.
 def yarn_rotary_embedding_as_linen(
     *,
     embedding_dims: int,
+    mesh: Mesh,
     max_position_embeddings: int = 4096 * 4,
     original_max_position_embeddings: int = 4096,
     beta_fast: float = 32,
@@ -637,6 +650,7 @@ def yarn_rotary_embedding_as_linen(
     interleave: bool = True,
     truncate: bool = True,
     attention_scaling: bool = False,
+    shard_mode: ShardMode = ShardMode.AUTO,
 ):
   """Initializes the YarnRotaryEmbedding module and returns it as a Linen module.
@@ -655,6 +669,7 @@ def yarn_rotary_embedding_as_linen(
   return nnx_wrappers.to_linen(
       YarnRotaryEmbedding,
       embedding_dims=embedding_dims,
+      mesh=mesh,
       max_position_embeddings=max_position_embeddings,
       original_max_position_embeddings=original_max_position_embeddings,
       beta_fast=beta_fast,
@@ -668,6 +683,7 @@ def yarn_rotary_embedding_as_linen(
       interleave=interleave,
       truncate=truncate,
       attention_scaling=attention_scaling,
+      shard_mode=shard_mode,
   )
```
```diff
@@ -697,6 +713,7 @@ class YarnRotaryEmbedding(nnx.Module):
   def __init__(
       self,
       embedding_dims: int,
+      mesh: Mesh,
       max_position_embeddings: int = 4096 * 4,
       original_max_position_embeddings: int = 4096,
       beta_fast: float = 32,
@@ -705,6 +722,7 @@ def __init__(
       rope_factor: float = 40,
       cast_as_fprop_dtype: bool = True,
       fprop_dtype: DType = jnp.bfloat16,
+      shard_mode: ShardMode = ShardMode.AUTO,
       interleave=True,
       truncate=True,
       attention_scaling=False,
@@ -724,6 +742,13 @@ def __init__(
     self.fprop_dtype = fprop_dtype
     self.interleave = interleave
     self.truncate = truncate
+    self.mesh = mesh
+    self.shard_mode = shard_mode
+    self.freqs_sharding = (
+        NamedSharding(mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", "q_heads")))
+        if shard_mode == ShardMode.EXPLICIT
+        else None
+    )
     self.attention_scaling = attention_scaling
 
     if self.embedding_dims % 2:
```
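A hedged usage sketch of the updated factory: callers now have to supply the device mesh and may opt into explicit sharding. Only `embedding_dims`, `mesh`, and `shard_mode` are taken from the diff; the `ShardMode` import path is an assumption, since the enum is only referenced, not defined, on this page:

```python
# Illustrative call site for the new yarn_rotary_embedding_as_linen signature;
# everything not shown in the diff relies on defaults or is an assumption.
import numpy as np
import jax
from jax.sharding import Mesh

from MaxText.layers.embeddings import yarn_rotary_embedding_as_linen
# Assumed import location for ShardMode, based on its use in configs/types.py.
from MaxText.configs.types import ShardMode

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "tensor"))
rope = yarn_rotary_embedding_as_linen(
    embedding_dims=64,
    mesh=mesh,                      # new required argument
    shard_mode=ShardMode.EXPLICIT,  # new optional argument, defaults to ShardMode.AUTO
)
```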
```diff
@@ -829,7 +854,8 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     # Lookup the precomputed frequencies using the position indices.
     # self.freqs_cis has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
     # After indexing, shape becomes [B, S, half_dim]; we then add an axis for the heads.
-    freqs = jnp.take(self.freqs_cis, position, axis=0)  # shape: [B, S, half_dim]
+    # freqs = jnp.take(self.freqs_cis, position, axis=0)  # shape: [B, S, half_dim]
+    freqs = self.freqs_cis.at[position].get(out_sharding=self.freqs_sharding)
     freqs = freqs[:, :, jnp.newaxis, :]  # shape: [B, S, 1, half_dim]
 
     if self.interleave:
```
```diff
@@ -846,7 +872,14 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
 
     inputs_complex = first_half + 1j * second_half  # shape: [B, S, N, half_dim]
     # Apply the rotary transformation via complex multiplication.
-    rotated = inputs_complex * freqs  # shape: [B, S, N, half_dim]
+    rotated_sharding = (
+        NamedSharding(self.mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", None, None)))
+        if self.shard_mode == ShardMode.EXPLICIT
+        else None
+    )
+    rotated = jnp.einsum(
+        "ijkl, ijml->ijkl", inputs_complex, freqs, out_sharding=rotated_sharding
+    )  # shape: [B, S, N, half_dim]
     # Convert the complex result back to a real tensor.
     # Split the complex number into its real and imaginary parts.
     # [real1, real2, ..., img1, img2, ...]
```
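The switch from a broadcast multiply to `jnp.einsum` is what lets an `out_sharding` be attached to the rotation under explicit sharding; numerically the two are identical because the contracted `m` axis of `freqs` has size 1. A quick check in plain JAX (no sharding involved):

```python
# Verifies that einsum "ijkl,ijml->ijkl" with a size-1 m axis equals the old
# broadcast multiply used for the rotary rotation.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
B, S, N, H = 2, 4, 3, 8  # batch, sequence, heads, half_dim
inputs_complex = jax.random.normal(k1, (B, S, N, H)) + 1j * jax.random.normal(k2, (B, S, N, H))
freqs = jnp.exp(1j * jax.random.normal(key, (B, S, 1, H)))  # shape: [B, S, 1, half_dim]

rotated_mul = inputs_complex * freqs
rotated_einsum = jnp.einsum("ijkl,ijml->ijkl", inputs_complex, freqs)
assert jnp.allclose(rotated_mul, rotated_einsum)
```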
