     EMBED,
     EP_AS_CONTEXT,
     HEAD,
+    Q_LORA_UP_PROJ,
     KV_BATCH,
     KV_BATCH_NO_EXP,
     KV_HEAD,
     KV_HEAD_DIM,
+    KV_LORA_UP_PROJ,
     LENGTH,
     LENGTH_NO_EXP,
     MODEL_MODE_PREFILL,
@@ -389,6 +391,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
           weight_dtype=self.weight_dtype,
           quant=self.quant,
           matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
     else:
@@ -403,6 +406,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
           weight_dtype=self.weight_dtype,
           quant=self.quant,
           matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
       self.q_norm = RMSNorm(
@@ -423,6 +427,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
           weight_dtype=self.weight_dtype,
           quant=self.quant,
           matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
 
@@ -437,6 +442,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         weight_dtype=self.weight_dtype,
         quant=self.quant,
         matmul_precision=self.config.matmul_precision,
+        shard_mode=self.config.shard_mode,
         rngs=self.rngs,
     )
     self.kv_norm = RMSNorm(
@@ -460,6 +466,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         weight_dtype=self.weight_dtype,
         quant=self.quant,
         matmul_precision=self.config.matmul_precision,
+        shard_mode=self.config.shard_mode,
         rngs=self.rngs,
     )
 
@@ -498,6 +505,18 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
 
   def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array:
     """Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0."""
+    # Select the logical axis names for the query and wq_a outputs based on model mode.
+    if model_mode == MODEL_MODE_PREFILL:
+      query_logical_name = self.prefill_query_axis_names
+      wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ)
+    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
+      query_logical_name = self.ep_query_axis_names
+      wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ)
+    else:
+      query_logical_name = self.query_axis_names
+      wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ)
+    query_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(query_logical_name))
+    wqa_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wqa_logical_name))
     # Set softmax scaling.
     self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
     self.softmax_scale = self.qk_head_dim ** -0.5
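The two `NamedSharding(...)` lines added above convert Flax logical axis names into a concrete device placement that can then be passed as an `out_sharding`. A minimal standalone sketch of that pattern, assuming a toy mesh with axes `"data"` and `"model"` and hypothetical logical names and rules (in MaxText the names are the imported constants such as `KV_BATCH` and `Q_LORA_UP_PROJ`, resolved through the project's own logical axis rules):

```python
import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding

# Toy mesh: all local devices arranged as ("data", "model").
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("data", "model"))

# Hypothetical logical-to-mesh rules; the real rules come from the MaxText config.
rules = (("activation_kv_batch", "data"), ("activation_length", None), ("q_lora_up_proj", "model"))

with nn.logical_axis_rules(rules):
  # logical_to_mesh_axes resolves logical names to a PartitionSpec over mesh axes,
  # and NamedSharding binds that spec to the mesh, ready to use as out_sharding.
  spec = nn.logical_to_mesh_axes(("activation_kv_batch", "activation_length", "q_lora_up_proj"))
  wqa_out_sharding = NamedSharding(mesh, spec)
  print(wqa_out_sharding.spec)  # PartitionSpec('data', None, 'model')
```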
@@ -506,47 +525,49 @@ def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_m
       self.softmax_scale = self.softmax_scale * mscale * mscale
 
     if self.q_lora_rank == 0:
-      q = self.query(inputs_q)
+      q = self.query(inputs_q, out_sharding=query_sharding)
     else:
       # LoRA path
-      low_rank_q = self.wq_a(inputs_q)  # [B, L, q_lora_rank]
+      low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
       low_rank_q = self.q_norm(low_rank_q)  # RMSNorm on low rank
-      q = self.wq_b(low_rank_q)  # [B, L, n_heads * qk_head_dim]
+      q = self.wq_b(low_rank_q, out_sharding=query_sharding)  # [B, L, n_heads * qk_head_dim]
 
     # Split into non-positional and rotary parts.
     q_nope, q_pe = jnp.split(q, [self.qk_nope_head_dim], axis=-1)
+    q_nope = self._maybe_shard_with_logical(q_nope, query_logical_name)
     q_pe = self.apply_rotary_embedding(q_pe, inputs_positions=inputs_positions)
+    q_pe = self._maybe_shard_with_logical(q_pe, query_logical_name)
     # Query projection is scaled by self.softmax_scale to be consistent with the MaxText implementation.
     # DeepSeek v3 was doing it in the attention score computation.
     query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale
+    query = self._maybe_shard_with_logical(query, query_logical_name)
+    return query
 
+  def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
+    """get (key,value) pair from mla"""
     if model_mode == MODEL_MODE_PREFILL:
-      query = nn.with_logical_constraint(query, self.prefill_query_axis_names)
+      key_logical_name = self.prefill_key_axis_names
+      value_logical_name = self.prefill_value_axis_names
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      query = nn.with_logical_constraint(query, self.ep_query_axis_names)
+      key_logical_name = self.ep_key_axis_names
+      value_logical_name = self.ep_value_axis_names
     else:
-      query = nn.with_logical_constraint(query, self.query_axis_names)
-    return query
+      key_logical_name = self.key_axis_names
+      value_logical_name = self.value_axis_names
 
-  def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
-    """get (key,value) pair from mla"""
-    kv_out = self.wkv_b(low_rank_main)
+    wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(key_logical_name))
+    kv_out = self.wkv_b(low_rank_main, out_sharding=wkva_out_sharding)
 
     # Split kv_out into key_nope and value parts.
     key_nope, value = jnp.split(kv_out, [self.qk_nope_head_dim], axis=-1)
     key_rope = jnp.broadcast_to(key_rope, (key_nope.shape[0], key_nope.shape[1], self.num_query_heads, key_rope.shape[3]))
+    key_nope = self._maybe_shard_with_logical(key_nope, key_logical_name)
+    key_rope = self._maybe_shard_with_logical(key_rope, key_logical_name)
 
     key = jnp.concatenate([key_nope, key_rope], axis=-1)
 
-    if model_mode == MODEL_MODE_PREFILL:
-      key = nn.with_logical_constraint(key, self.prefill_key_axis_names)
-      value = nn.with_logical_constraint(value, self.prefill_value_axis_names)
-    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      key = nn.with_logical_constraint(key, self.ep_key_axis_names)
-      value = nn.with_logical_constraint(value, self.ep_value_axis_names)
-    else:
-      key = nn.with_logical_constraint(key, self.key_axis_names)
-      value = nn.with_logical_constraint(value, self.value_axis_names)
+    key = self._maybe_shard_with_logical(key, key_logical_name)
+    value = self._maybe_shard_with_logical(value, value_logical_name)
     return key, value
 
   def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
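One detail from the hunk above that is easy to miss: `mla_query_projection` folds `softmax_scale` into the query itself, while DeepSeek v3 applies it when computing attention scores. A quick self-contained check (sizes and the single flattened head are illustrative only) that both placements give identical logits:

```python
import jax
import jax.numpy as jnp

B, L, qk_head_dim = 2, 8, 192  # illustrative; qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
softmax_scale = qk_head_dim ** -0.5

kq, kk = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(kq, (B, L, qk_head_dim))
key = jax.random.normal(kk, (B, L, qk_head_dim))

# Scaling the query up front, as mla_query_projection does ...
logits_prescaled = jnp.einsum("bqd,bkd->bqk", query * softmax_scale, key)
# ... matches scaling the scores afterwards, as DeepSeek v3 does.
logits_postscaled = jnp.einsum("bqd,bkd->bqk", query, key) * softmax_scale
print(jnp.allclose(logits_prescaled, logits_postscaled))  # True
```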
@@ -637,7 +658,14 @@ def update_mla_kv_caches(self, low_rank_main, key_rope, decoder_segment_ids, mod
 
   def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segment_ids, model_mode, previous_chunk):
     """MLA key/value projection with integrated rotary embedding."""
-    low_rank = self.wkv_a(inputs)
+    if model_mode == MODEL_MODE_PREFILL:
+      wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ)
+    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
+      wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ)
+    else:
+      wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
+    wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wka_logical_name))
+    low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
     low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
     low_rank_main = self.kv_norm(low_rank_main)
 
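For reference, the `jnp.split` just above separates the `wkv_a` output into the `kv_lora_rank`-wide latent (which is then RMS-normalized by `kv_norm`) and the rotary slice. A shape-only sketch with made-up sizes:

```python
import jax.numpy as jnp

B, L = 2, 8                               # illustrative batch and sequence length
kv_lora_rank, qk_rope_head_dim = 512, 64  # illustrative MLA latent and rope widths

low_rank = jnp.zeros((B, L, kv_lora_rank + qk_rope_head_dim))  # stand-in for the wkv_a output
# First kv_lora_rank channels form the latent; the remainder is the rotary part.
low_rank_main, low_rank_rope = jnp.split(low_rank, [kv_lora_rank], axis=-1)
print(low_rank_main.shape, low_rank_rope.shape)  # (2, 8, 512) (2, 8, 64)
```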
@@ -696,14 +724,17 @@ def __call__(
       MLA-attended outputs.
     """
     if model_mode == MODEL_MODE_PREFILL:
-      inputs_q = nn.with_logical_constraint(inputs_q, self.prefill_input_axis_names)
-      inputs_kv = nn.with_logical_constraint(inputs_kv, self.prefill_input_axis_names)
+      inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
+      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
+      out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV)
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      inputs_q = nn.with_logical_constraint(inputs_q, self.ep_input_axis_names)
-      inputs_kv = nn.with_logical_constraint(inputs_kv, self.ep_input_axis_names)
+      inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names)
+      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names)
+      out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV)
     else:
-      inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
-      inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names)
+      inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
+      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
+      out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
 
     query = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
     key, value, cached_values = self.mla_kv_projection(
@@ -724,10 +755,11 @@ def __call__(
     out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
 
     if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      out = nn.with_logical_constraint(out, self.ep_out_axis_names)
+      out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
     else:
-      out = nn.with_logical_constraint(out, self.out_axis_names)
+      out = self._maybe_shard_with_logical(out, self.out_axis_names)
 
-    out = self.out_projection(out)
+    out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(out_logical_name))
+    out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
     return out, kv_cache
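Every activation in these hunks is now routed through `self._maybe_shard_with_logical(...)`, whose body is not part of this diff. A hedged sketch of what such a helper could look like, written as a free function for illustration only and assuming `config.shard_mode` simply toggles between explicit `NamedSharding` placement and the pre-existing `nn.with_logical_constraint` path (the boolean argument stands in for whatever `shard_mode` actually encodes):

```python
import jax
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding


def maybe_shard_with_logical(x, logical_names, mesh: Mesh, explicit_sharding: bool):
  """Sketch only: attach a sharding hint to x derived from logical axis names."""
  if explicit_sharding:
    # Explicit path: resolve the logical names under the active logical_axis_rules
    # and pin the array to the resulting NamedSharding on the given mesh.
    spec = nn.logical_to_mesh_axes(logical_names)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
  # Legacy path: exactly what the removed nn.with_logical_constraint calls did.
  return nn.with_logical_constraint(x, logical_names)
```

If the real helper works roughly like this, the patch is behaviour-preserving whenever `shard_mode` keeps the legacy path and only changes placement when explicit shardings are requested.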