Commit fe22e79

moe check_vma true

committed · 1 parent adf511a · commit fe22e79

File tree: 11 files changed (+39, -36 lines)

src/MaxText/configs/models/deepseek3-test.yml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ num_experts_per_tok: 8
 shared_experts: 1
 routed_scaling_factor: 2.5
 routed_score_func: "sigmoid"
-routed_bias: True
+routed_bias: False
 decoder_block: "deepseek"
 # MLA
 attention_type: "mla"

src/MaxText/configs/models/deepseek3-tiny.yml

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ num_experts_per_tok: 8
 shared_experts: 1
 routed_scaling_factor: 2.5
 routed_score_func: "sigmoid"
-routed_bias: True
+routed_bias: False
 decoder_block: "deepseek"
 # MLA
 attention_type: "mla"
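Both test configs flip routed_bias from True to False. For orientation only: in a DeepSeek-V3-style router (sigmoid scores plus routed_scaling_factor), a routing bias is commonly added to the scores for expert selection while the combine weights stay unbiased. The sketch below illustrates that general pattern; the function and argument names are invented and it is not a copy of MaxText's router.

import jax
import jax.numpy as jnp

def select_experts(scores, bias, num_experts_per_tok, use_routed_bias):
  # When enabled, the bias only influences *which* experts are chosen ...
  selection_scores = scores + bias if use_routed_bias else scores
  _, top_idx = jax.lax.top_k(selection_scores, num_experts_per_tok)
  # ... while the combine weights come from the unbiased scores.
  top_weights = jnp.take_along_axis(scores, top_idx, axis=-1)
  return top_idx, top_weights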

src/MaxText/kernels/megablox/backend.py

Lines changed: 2 additions & 2 deletions
@@ -522,7 +522,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
   }
   call_gmm = qpl.pallas_call(
       kernel,
-      out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
+      out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type, vma=set(["fsdp"])),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=2,
           in_specs=[

@@ -775,7 +775,7 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
   }
   call_gmm = qpl.pallas_call(
       kernel,
-      out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type),
+      out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type, vma=set()),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=2,
           in_specs=[
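These two out_shape edits declare, for each pallas_call result, the set of mesh axes the output varies over (its "vma"), which shard_map's check_vma uses to type-check values inside the manual region: the (m, n) result is declared as varying along the "fsdp" mesh axis, while the (num_actual_groups, k, n) result is declared as varying along no manual axis, i.e. replicated. A minimal sketch of the same idea, assuming a recent JAX where jax.ShapeDtypeStruct accepts the vma keyword (as this diff relies on); the shapes and dtype below are illustrative only.

import jax
import jax.numpy as jnp

# Output that is different on every "fsdp" shard inside the shard_map region.
per_shard_out = jax.ShapeDtypeStruct((512, 1024), jnp.bfloat16, vma=set(["fsdp"]))

# Output that is identical (replicated) across all manual mesh axes.
replicated_out = jax.ShapeDtypeStruct((8, 1024, 1024), jnp.bfloat16, vma=set())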

src/MaxText/layers/attention_mla.py

Lines changed: 2 additions & 3 deletions
@@ -24,7 +24,6 @@
 import jax.numpy as jnp

 from flax import nnx
-from flax import linen as nn

 from MaxText.common_types import (
     Array,

@@ -65,7 +64,7 @@
 from MaxText.layers.linears import DenseGeneral
 from MaxText.layers.normalizations import RMSNorm
 from MaxText.layers.quantizations import AqtQuantization as Quant
-from MaxText.sharding import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical, create_sharding


 @dataclasses.dataclass(frozen=True)

@@ -314,7 +313,7 @@ def __init__(
   def _create_sharding(self, axis_names):
     """Creates NamedSharding if shard_mode is EXPLICIT, otherwise None."""
     if self.config.shard_mode == ShardMode.EXPLICIT:
-      return NamedSharding(self.mesh, nn.logical_to_mesh_axes(axis_names))
+      return create_sharding(self.mesh, axis_names)
     return None

   def _get_logical_names(self, model_mode):
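This file shows the substitution that repeats through the rest of the commit: the inlined NamedSharding(mesh, nn.logical_to_mesh_axes(...)) is replaced by MaxText.sharding.create_sharding(mesh, axis_names). Reading the two sides of the diff, the helper behaves like the pattern it replaces; the sketch below is that inferred equivalence, not MaxText's actual implementation, which may add details such as filtering logical rules against the mesh.

from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

def create_sharding(mesh: Mesh, logical_axes) -> NamedSharding:
  # Resolve logical axis names (e.g. "activation_batch") to physical mesh axes
  # via the active flax logical-axis rules, then wrap them in a NamedSharding.
  return NamedSharding(mesh, nn.logical_to_mesh_axes(logical_axes))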

src/MaxText/layers/attentions.py

Lines changed: 1 addition & 1 deletion
@@ -474,7 +474,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
   def _create_sharding(self, axis_names):
     """Creates NamedSharding if shard_mode is EXPLICIT, otherwise None."""
     if self.config.shard_mode == ShardMode.EXPLICIT:
-      return NamedSharding(self.mesh, nn.logical_to_mesh_axes(axis_names))
+      return create_sharding(self.mesh, axis_names)
     return None

   def _get_logical_names(self, model_mode):

src/MaxText/layers/deepseek.py

Lines changed: 7 additions & 7 deletions
@@ -19,7 +19,7 @@
 from functools import partial

 from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import Mesh
 import jax.numpy as jnp

 from flax import linen as nn

@@ -33,7 +33,7 @@
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.inference import page_manager
-from MaxText.sharding import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical, create_sharding
 from MaxText.common_types import MODEL_MODE_PREFILL

 # -----------------------------------------

@@ -74,7 +74,7 @@ def self_attention_with_norm(
       mesh=mesh,
       shard_mode=cfg.shard_mode,
   )
-  lnx_out_sharding = NamedSharding(mesh, nn.logical_to_mesh_axes(logical_axis_names))
+  lnx_out_sharding = create_sharding(mesh, logical_axis_names)

   lnx = _maybe_shard_with_logical(lnx, logical_axis_names)


@@ -184,8 +184,8 @@ def __call__(
     logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
     mlp_logical_axis_names = ("activation_batch", "activation_norm_length", "activation_mlp")
     _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode)
-    lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names))
-    mlp_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names))
+    lnx_out_sharding = create_sharding(self.mesh, logical_axis_names)
+    mlp_intermediate_sharding = create_sharding(self.mesh, mlp_logical_axis_names)
     inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")


@@ -263,8 +263,8 @@ def __call__(
     logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
     mlp_logical_axis_names = ("activation_batch", "activation_norm_length", "activation_mlp")
     _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode)
-    lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names))
-    lnx_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names))
+    lnx_out_sharding = create_sharding(self.mesh, logical_axis_names)
+    lnx_intermediate_sharding = create_sharding(self.mesh, mlp_logical_axis_names)
     inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
src/MaxText/layers/embeddings.py

Lines changed: 3 additions & 3 deletions
@@ -26,7 +26,7 @@

 from MaxText import max_logging
 from MaxText import max_utils
-from MaxText.sharding import logical_to_mesh_axes
+from MaxText.sharding import logical_to_mesh_axes, create_sharding
 from MaxText.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType
 from MaxText.layers import nnx_wrappers
 from MaxText.layers.initializers import Initializer, default_embed_init, variable_to_logically_partitioned

@@ -745,7 +745,7 @@ def __init__(
     self.mesh = mesh
     self.shard_mode = shard_mode
     self.freqs_sharding = (
-        NamedSharding(mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", "q_heads")))
+        create_sharding(mesh, ("activation_batch", "activation_length_no_exp", "q_heads"))
         if shard_mode == ShardMode.EXPLICIT
         else None
     )

@@ -873,7 +873,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     inputs_complex = first_half + 1j * second_half  # shape: [B, S, N, half_dim]
     # Apply the rotary transformation via complex multiplication.
     rotated_sharding = (
-        NamedSharding(self.mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", None, None)))
+        create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", None, None))
         if self.shard_mode == ShardMode.EXPLICIT
         else None
     )

src/MaxText/layers/llama2.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def __init__(
   def _create_sharding(self, axis_names):
     """Creates NamedSharding if shard_mode is EXPLICIT, otherwise None."""
     if self.config.shard_mode == ShardMode.EXPLICIT:
-      return NamedSharding(self.mesh, nn.logical_to_mesh_axes(axis_names))
+      return create_sharding(self.mesh, axis_names)
     return None

   def _get_logical_names(self, model_mode):

src/MaxText/layers/moe.py

Lines changed: 14 additions & 16 deletions
@@ -33,7 +33,7 @@
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText.common_types import ShardMode
-from MaxText.sharding import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical, create_sharding
 from MaxText.kernels import megablox as mblx
 from MaxText.sharding import logical_to_mesh_axes
 from MaxText.layers import attentions, linears, nnx_wrappers, quantizations

@@ -264,9 +264,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.

     # [B, S, E] -> [B, S, num_exp]
     output_sharding = (
-        NamedSharding(
-            self.mesh, nn.logical_to_mesh_axes(("activation_batch_no_exp", "activation_length_no_exp", "activation_exp"))
-        )
+        create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", "activation_exp"))
         if self.shard_mode == ShardMode.EXPLICIT
         else None
     )

@@ -505,7 +503,7 @@ def _get_logical_names(self, model_mode):
   def _create_sharding(self, axis_names):
     """Creates NamedSharding if shard_mode is EXPLICIT, otherwise None."""
     if self.config.shard_mode == ShardMode.EXPLICIT:
-      return NamedSharding(self.mesh, nn.logical_to_mesh_axes(axis_names))
+      return create_sharding(self.mesh, axis_names)
     return None

   def setup_sharding(self, model_mode):

@@ -1015,15 +1013,15 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
       output = output[: hs_shape[0]]
       return output

-    input_partition_pspec = nn.logical_to_mesh_axes(self.logical_names.inputs)
-    w0_bias_pspec = nn.logical_to_mesh_axes(self.logical_names.wi_bias)
-    w1_bias_pspec = nn.logical_to_mesh_axes(self.logical_names.wi_bias)
-    wo_bias_pspec = nn.logical_to_mesh_axes(self.logical_names.wo_bias)
-    gate_logits_pspec = nn.logical_to_mesh_axes(self.logical_names.gate)
-    pre_bias_logits_pspec = nn.logical_to_mesh_axes(self.logical_names.pre_bias)
-    w0_pspec = nn.logical_to_mesh_axes(self.logical_names.wi_kernel_sp)
-    w1_pspec = nn.logical_to_mesh_axes(self.logical_names.wi_kernel_sp)
-    wo_pspec = nn.logical_to_mesh_axes(self.logical_names.wo_kernel_sp)
+    input_partition_pspec = logical_to_mesh_axes(self.logical_names.inputs, self.mesh)
+    w0_bias_pspec = logical_to_mesh_axes(self.logical_names.wi_bias, self.mesh)
+    w1_bias_pspec = logical_to_mesh_axes(self.logical_names.wi_bias, self.mesh)
+    wo_bias_pspec = logical_to_mesh_axes(self.logical_names.wo_bias, self.mesh)
+    gate_logits_pspec = logical_to_mesh_axes(self.logical_names.gate, self.mesh)
+    pre_bias_logits_pspec = logical_to_mesh_axes(self.logical_names.pre_bias, self.mesh)
+    w0_pspec = logical_to_mesh_axes(self.logical_names.wi_kernel_sp, self.mesh)
+    w1_pspec = logical_to_mesh_axes(self.logical_names.wi_kernel_sp, self.mesh)
+    wo_pspec = logical_to_mesh_axes(self.logical_names.wo_kernel_sp, self.mesh)

     if isinstance(w0_kernel, aqt.QTensor):
       w0_pspec = aqt.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False)

@@ -1047,8 +1045,8 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
             wo_bias_pspec,
             None,
         ),
-        out_specs=(nn.logical_to_mesh_axes(self.logical_names.out)),
-        check_vma=False,
+        out_specs=(logical_to_mesh_axes(self.logical_names.out, self.mesh)),
+        check_vma=True,
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
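This is the change the commit title refers to: the shard_map that wraps the grouped matmul now runs with check_vma=True, so JAX tracks which mesh axes every value in the manual region varies over and verifies that out_specs is consistent with that, instead of skipping the check. Below is a self-contained toy example of the mechanism, assuming a JAX version where jax.shard_map exposes check_vma (older releases call the flag check_rep); it is not MaxText's gmm wrapper.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("fsdp",))

def block_sum(x):
  # psum removes the "fsdp" variance: every shard ends up with the same total.
  return jax.lax.psum(x, "fsdp")

# With check_vma=True, JAX verifies that out_specs=P() (fully replicated) is
# consistent with the result really being invariant along "fsdp".
f = jax.shard_map(block_sum, mesh=mesh, in_specs=P("fsdp"), out_specs=P(), check_vma=True)
print(f(jnp.arange(2.0 * mesh.size)))  # same values on every shard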

src/MaxText/sharding.py

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ def maybe_shard_with_logical(inputs, logical_axes, mesh, shard_mode):
   """
   A wrapper of maybe_shard_with_name when logical axes are inputs
   """
+  if inputs is None:
+    return None
   named_sharding = create_sharding(mesh, logical_axes)
   return maybe_shard_with_name(inputs, named_sharding, shard_mode)

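The new early return lets callers hand an optional tensor to maybe_shard_with_logical without a special case at the call site. A hedged usage sketch; the logical axis names and shard mode below are illustrative, not taken from this commit.

import numpy as np
import jax
from jax.sharding import Mesh
from MaxText.common_types import ShardMode
from MaxText.sharding import maybe_shard_with_logical

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("fsdp",))

# An optional tensor that is None (e.g. disabled by config) now passes
# straight through instead of failing inside create_sharding.
out = maybe_shard_with_logical(None, ("exp", "mlp"), mesh, ShardMode.EXPLICIT)
assert out is None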
