
Commit 3dc9988

deepseek explicit split
1 parent cb47516 commit 3dc9988

File tree

14 files changed: +421 -109 lines


src/MaxText/common_types.py

Lines changed: 2 additions & 0 deletions
@@ -37,7 +37,9 @@
 PREFILL_LENGTH = "prefill_activation_length"
 Q_LENGTH = "activation_q_length"
 Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
+Q_LORA_UP_PROJ = "q_lora_up_proj"
 KV_LENGTH = "activation_kv_length"
+KV_LORA_UP_PROJ = "kv_lora_up_proj"
 EMBED = "activation_embed"
 HEAD = "activation_heads"
 PREFILL_KV_BATCH = "activation_prefill_kv_batch"
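The two new names label the low-rank (LoRA) bottleneck dimension of the MLA query and key/value down-projections so that dimension can be addressed by sharding rules. A minimal sketch of how such logical names resolve to a PartitionSpec and then a NamedSharding, using a made-up rule set and a trivial one-axis mesh (MaxText derives both from its config):

import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

# Hypothetical mesh and rules, for illustration only.
mesh = Mesh(np.array(jax.devices()).reshape(-1), ("data",))
rules = (
    ("activation_q_length", None),  # keep the sequence axis unsharded in this toy rule set
    ("q_lora_up_proj", None),       # new LoRA rank axis, replicated here
)

spec = nn.logical_to_mesh_axes(("activation_q_length", "q_lora_up_proj"), rules)
sharding = NamedSharding(mesh, spec)  # usable as an explicit out_sharding downstream
print(spec)                           # PartitionSpec(None, None)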

src/MaxText/configs/types.py

Lines changed: 1 addition & 1 deletion
@@ -1920,7 +1920,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.packing:
       raise ValueError("For multimodal SFT, `packing` is not yet supported.")
     if self.shard_mode == ShardMode.EXPLICIT:
-      supported_decoders = {"simple", "simple_mlp", "llama2"}
+      supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"}
       if self.decoder_block.value not in supported_decoders:
         raise ValueError(
             f"Decoder '{self.decoder_block.value}' is not supported with 'explicit' sharding. "

src/MaxText/layers/attention_mla.py

Lines changed: 61 additions & 29 deletions
@@ -38,10 +38,12 @@
     EMBED,
     EP_AS_CONTEXT,
     HEAD,
+    Q_LORA_UP_PROJ,
     KV_BATCH,
     KV_BATCH_NO_EXP,
     KV_HEAD,
     KV_HEAD_DIM,
+    KV_LORA_UP_PROJ,
     LENGTH,
     LENGTH_NO_EXP,
     MODEL_MODE_PREFILL,
@@ -389,6 +391,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
           weight_dtype=self.weight_dtype,
           quant=self.quant,
           matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
     else:
@@ -403,6 +406,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
           weight_dtype=self.weight_dtype,
           quant=self.quant,
           matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
       self.q_norm = RMSNorm(
@@ -423,6 +427,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
           weight_dtype=self.weight_dtype,
           quant=self.quant,
           matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
 
@@ -437,6 +442,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         weight_dtype=self.weight_dtype,
         quant=self.quant,
         matmul_precision=self.config.matmul_precision,
+        shard_mode=self.config.shard_mode,
         rngs=self.rngs,
     )
     self.kv_norm = RMSNorm(
@@ -460,6 +466,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         weight_dtype=self.weight_dtype,
         quant=self.quant,
         matmul_precision=self.config.matmul_precision,
+        shard_mode=self.config.shard_mode,
         rngs=self.rngs,
     )
 
@@ -498,6 +505,18 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
 
   def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array:
     """Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0."""
+    # specify query logical name
+    if model_mode == MODEL_MODE_PREFILL:
+      query_logical_name = self.prefill_query_axis_names
+      wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ)
+    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
+      query_logical_name = self.ep_query_axis_names
+      wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ)
+    else:
+      query_logical_name = self.query_axis_names
+      wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ)
+    query_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(query_logical_name))
+    wqa_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wqa_logical_name))
     # Set softmax scaling.
     self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
     self.softmax_scale = self.qk_head_dim**-0.5
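This branch introduces the pattern used throughout the commit: choose a logical axis-name tuple per model mode, translate it with nn.logical_to_mesh_axes, and wrap the result in a NamedSharding for the projection output. A compressed, stand-alone sketch of that selection logic, where the mode strings and some axis-name strings are placeholders for MaxText's MODEL_MODE_* and KV_BATCH / LENGTH_NO_EXP constants:

from flax import linen as nn
from jax.sharding import NamedSharding

def wqa_out_sharding_for_mode(model_mode, mesh, rules):
  """Hypothetical helper mirroring the branch above (axis names are stand-ins)."""
  if model_mode == "prefill":
    names = ("activation_prefill_kv_batch", "prefill_activation_length", "q_lora_up_proj")
  elif model_mode == "train_with_ep_as_context":
    names = ("kv_batch_no_exp", "length", "q_lora_up_proj")   # stand-ins for KV_BATCH_NO_EXP, LENGTH
  else:
    names = ("kv_batch", "length_no_exp", "q_lora_up_proj")   # stand-ins for KV_BATCH, LENGTH_NO_EXP
  return NamedSharding(mesh, nn.logical_to_mesh_axes(names, rules))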
@@ -506,47 +525,49 @@ def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_m
       self.softmax_scale = self.softmax_scale * mscale * mscale
 
     if self.q_lora_rank == 0:
-      q = self.query(inputs_q)
+      q = self.query(inputs_q, out_sharding=query_sharding)
     else:
       # LoRA path
-      low_rank_q = self.wq_a(inputs_q)  # [B, L, q_lora_rank]
+      low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
       low_rank_q = self.q_norm(low_rank_q)  # RMSNorm on low rank
-      q = self.wq_b(low_rank_q)  # [B, L, n_heads * qk_head_dim]
+      q = self.wq_b(low_rank_q, out_sharding=query_sharding)  # [B, L, n_heads * qk_head_dim]
 
     # Split into non-positional and rotary parts.
     q_nope, q_pe = jnp.split(q, [self.qk_nope_head_dim], axis=-1)
+    q_nope = self._maybe_shard_with_logical(q_nope, query_logical_name)
     q_pe = self.apply_rotary_embedding(q_pe, inputs_positions=inputs_positions)
+    q_pe = self._maybe_shard_with_logical(q_pe, query_logical_name)
     # Query projection is scaled by self.softmax_scale to be consistent MaxText implementation.
     # DeepSeek v3 was doing it in attention score computation.
     query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale
+    query = self._maybe_shard_with_logical(query, query_logical_name)
+    return query
 
+  def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
+    """get (key,value) pair from mla"""
     if model_mode == MODEL_MODE_PREFILL:
-      query = nn.with_logical_constraint(query, self.prefill_query_axis_names)
+      key_logical_name = self.prefill_key_axis_names
+      value_logical_name = self.prefill_value_axis_names
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      query = nn.with_logical_constraint(query, self.ep_query_axis_names)
+      key_logical_name = self.ep_key_axis_names
+      value_logical_name = self.ep_value_axis_names
     else:
-      query = nn.with_logical_constraint(query, self.query_axis_names)
-    return query
+      key_logical_name = self.key_axis_names
+      value_logical_name = self.value_axis_names
 
-  def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
-    """get (key,value) pair from mla"""
-    kv_out = self.wkv_b(low_rank_main)
+    wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(key_logical_name))
+    kv_out = self.wkv_b(low_rank_main, out_sharding=wkva_out_sharding)
 
     # Split kv_out into key_nope and value parts.
     key_nope, value = jnp.split(kv_out, [self.qk_nope_head_dim], axis=-1)
     key_rope = jnp.broadcast_to(key_rope, (key_nope.shape[0], key_nope.shape[1], self.num_query_heads, key_rope.shape[3]))
+    key_nope = self._maybe_shard_with_logical(key_nope, key_logical_name)
+    key_rope = self._maybe_shard_with_logical(key_rope, key_logical_name)
 
     key = jnp.concatenate([key_nope, key_rope], axis=-1)
 
-    if model_mode == MODEL_MODE_PREFILL:
-      key = nn.with_logical_constraint(key, self.prefill_key_axis_names)
-      value = nn.with_logical_constraint(value, self.prefill_value_axis_names)
-    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      key = nn.with_logical_constraint(key, self.ep_key_axis_names)
-      value = nn.with_logical_constraint(value, self.ep_value_axis_names)
-    else:
-      key = nn.with_logical_constraint(key, self.key_axis_names)
-      value = nn.with_logical_constraint(value, self.value_axis_names)
+    key = self._maybe_shard_with_logical(key, key_logical_name)
+    value = self._maybe_shard_with_logical(value, value_logical_name)
     return key, value
 
   def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
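A quick shape check of the key/value assembly above, with toy dimensions rather than DeepSeek's real head sizes: wkv_b's output is split at qk_nope_head_dim into the non-positional key part and the value, the shared rope key is broadcast across all query heads, and the final key concatenates the two parts.

import jax.numpy as jnp

B, L, H = 2, 4, 8                                       # toy batch, sequence, query heads
qk_nope_head_dim, v_head_dim, qk_rope_head_dim = 16, 16, 8

kv_out = jnp.zeros((B, L, H, qk_nope_head_dim + v_head_dim))
key_nope, value = jnp.split(kv_out, [qk_nope_head_dim], axis=-1)

key_rope = jnp.zeros((B, L, 1, qk_rope_head_dim))       # rope part shared across heads
key_rope = jnp.broadcast_to(key_rope, (B, L, H, qk_rope_head_dim))

key = jnp.concatenate([key_nope, key_rope], axis=-1)
print(key.shape, value.shape)                           # (2, 4, 8, 24) (2, 4, 8, 16)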
@@ -637,7 +658,14 @@ def update_mla_kv_caches(low_rank_main, key_rope, decoder_segment_ids, mod
 
   def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segment_ids, model_mode, previous_chunk):
     """MLA key/value projection with integrated rotary embedding."""
-    low_rank = self.wkv_a(inputs)
+    if model_mode == MODEL_MODE_PREFILL:
+      wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ)
+    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
+      wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ)
+    else:
+      wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
+    wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wka_logical_name))
+    low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
     low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
     low_rank_main = self.kv_norm(low_rank_main)
 
@@ -696,14 +724,17 @@ def __call__(
       MLA-attended outputs.
     """
     if model_mode == MODEL_MODE_PREFILL:
-      inputs_q = nn.with_logical_constraint(inputs_q, self.prefill_input_axis_names)
-      inputs_kv = nn.with_logical_constraint(inputs_kv, self.prefill_input_axis_names)
+      inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
+      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
+      out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV)
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      inputs_q = nn.with_logical_constraint(inputs_q, self.ep_input_axis_names)
-      inputs_kv = nn.with_logical_constraint(inputs_kv, self.ep_input_axis_names)
+      inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names)
+      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names)
+      out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV)
     else:
-      inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
-      inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names)
+      inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
+      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
+      out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
 
     query = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
     key, value, cached_values = self.mla_kv_projection(
@@ -724,10 +755,11 @@ def __call__(
     out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
 
     if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      out = nn.with_logical_constraint(out, self.ep_out_axis_names)
+      out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
     else:
-      out = nn.with_logical_constraint(out, self.out_axis_names)
+      out = self._maybe_shard_with_logical(out, self.out_axis_names)
 
-    out = self.out_projection(out)
+    out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(out_logical_name))
+    out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
     return out, kv_cache
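Across this file, nn.with_logical_constraint calls are replaced with self._maybe_shard_with_logical, which is what lets the layer honor config.shard_mode. The helper itself is not part of this diff; as a rough assumption about its intent, it might look something like the sketch below (name, signature, and the explicit-mode branch are guesses, not MaxText's actual code):

import jax
from flax import linen as nn
from jax.sharding import NamedSharding

def maybe_shard_with_logical(x, logical_axis_names, mesh, shard_mode, rules):
  # Assumed behavior: pin a concrete NamedSharding under explicit sharding,
  # otherwise fall back to a logical constraint and let auto-sharding decide.
  if shard_mode == "explicit":
    spec = nn.logical_to_mesh_axes(logical_axis_names, rules)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
  return nn.with_logical_constraint(x, logical_axis_names)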

src/MaxText/layers/attentions.py

Lines changed: 19 additions & 14 deletions
@@ -748,14 +748,17 @@ def init_rotary_embedding(self):
       rotary_embedding = LLaMARotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,
           max_timescale=self.config.rope_max_timescale,
+          mesh=self.mesh,
           embedding_dims=rope_embedding_dims,
           fprop_dtype=self.dtype,
           use_scale=rope_use_scale,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
     elif rope_type.startswith("yarn"):
       rotary_embedding = YarnRotaryEmbedding(
           max_position_embeddings=self.config.max_position_embeddings,
+          mesh=self.mesh,
           original_max_position_embeddings=self.config.original_max_position_embeddings,
           beta_fast=self.config.beta_fast,
           beta_slow=self.config.beta_slow,
@@ -766,16 +769,19 @@ def init_rotary_embedding(self):
           interleave=self.config.rope_interleave,
           truncate=self.config.rope_truncate,
           attention_scaling=self.config.rope_attention_scaling,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
     elif self.is_qwen3_next:
       rotary_embedding = Qwen3NextRotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,
           max_timescale=self.config.rope_max_timescale,
+          mesh=self.mesh,
           embedding_dims=self.config.head_dim,
           partial_rotary_factor=self.config.partial_rotary_factor,
           cast_as_fprop_dtype=True,
           fprop_dtype=self.config.dtype,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
     else:
@@ -792,9 +798,11 @@ def init_rotary_embedding(self):
       rotary_embedding = RotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,
           max_timescale=max_timescale,
+          mesh=self.mesh,
           embedding_dims=rope_embedding_dims,
           fprop_dtype=self.dtype,
           rope_linear_scaling_factor=rope_linear_scaling_factor,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
     return rotary_embedding
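Each rotary-embedding constructor above now receives mesh and shard_mode alongside its existing arguments, so the embeddings can make the same auto-versus-explicit sharding decisions as the projections. A toy NNX module showing only that constructor threading (the class and fields are illustrative, not the real LLaMARotaryEmbedding / YarnRotaryEmbedding / Qwen3NextRotaryEmbedding / RotaryEmbedding):

from flax import nnx

class ToyRotaryEmbedding(nnx.Module):
  """Illustration only: keep the mesh and shard_mode handed in by the caller."""

  def __init__(self, embedding_dims: int, mesh, shard_mode, *, rngs: nnx.Rngs):
    self.embedding_dims = embedding_dims
    self.mesh = mesh              # available for building NamedShardings later
    self.shard_mode = shard_mode  # toggles auto vs. explicit constraint behavior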
@@ -985,28 +993,25 @@ def __call__(
       output of shape `[batch, length, q_features]`.
     """
     if model_mode == MODEL_MODE_PREFILL:
-      inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
-      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
+      input_axis_names = self.prefill_input_axis_names
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names)
-      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names)
+      input_axis_names = self.ep_input_axis_names
     elif model_mode == MODEL_MODE_TRAIN:
-      inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
-      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
+      input_axis_names = self.input_axis_names
     else:
-      inputs_q = self._maybe_shard_with_logical(inputs_q, self.decode_input_axis_names)
-      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.decode_input_axis_names)
+      input_axis_names = self.decode_input_axis_names
+
+    inputs_q = self._maybe_shard_with_logical(inputs_q, input_axis_names)
+    inputs_kv = self._maybe_shard_with_logical(inputs_kv, input_axis_names)
+    qkv_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(input_axis_names))
 
     # apply projection.
     if self.config.fused_qkv:
       query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj")
     else:
-      query_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.query_axis_names))
-      query = self.query_projection(inputs_q, out_sharding=query_sharding)
-      key_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.key_axis_names))
-      key = self.kv_projection(inputs_kv, proj_name="key", out_sharding=key_sharding)
-      value_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.value_axis_names))
-      value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=value_sharding)
+      query = self.query_projection(inputs_q, out_sharding=qkv_sharding)
+      key = self.kv_projection(inputs_kv, proj_name="key", out_sharding=qkv_sharding)
+      value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=qkv_sharding)
 
     gate = None
     if self.is_qwen3_next: