 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText.common_types import ShardMode
-from MaxText.sharding import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical, create_sharding
 from MaxText.kernels import megablox as mblx
 from MaxText.sharding import logical_to_mesh_axes
 from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
@@ -264,9 +264,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
 
     # [B, S, E] -> [B, S, num_exp]
     output_sharding = (
-        NamedSharding(
-            self.mesh, nn.logical_to_mesh_axes(("activation_batch_no_exp", "activation_length_no_exp", "activation_exp"))
-        )
+        create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", "activation_exp"))
         if self.shard_mode == ShardMode.EXPLICIT
         else None
     )
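The explicit NamedSharding construction is replaced by the new create_sharding helper imported from MaxText.sharding. Below is a minimal sketch of such a helper, assuming it is a thin wrapper around exactly the pattern this diff removes; the real MaxText.sharding.create_sharding may do more (validation, handling of missing mesh axes, etc.).

```python
# Sketch only: assumes create_sharding wraps the NamedSharding construction
# removed in this diff; not the actual MaxText.sharding implementation.
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding


def create_sharding(mesh: Mesh, axis_names) -> NamedSharding:
  # Resolve logical axis names to physical mesh axes via the active
  # logical-axis rules, then bind the resulting PartitionSpec to the mesh.
  return NamedSharding(mesh, nn.logical_to_mesh_axes(axis_names))
```

Centralizing this in one helper keeps every EXPLICIT-mode call site consistent and avoids constructing NamedSharding by hand in each layer module.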
@@ -505,7 +503,7 @@ def _get_logical_names(self, model_mode):
   def _create_sharding(self, axis_names):
     """Creates NamedSharding if shard_mode is EXPLICIT, otherwise None."""
     if self.config.shard_mode == ShardMode.EXPLICIT:
-      return NamedSharding(self.mesh, nn.logical_to_mesh_axes(axis_names))
+      return create_sharding(self.mesh, axis_names)
     return None
 
   def setup_sharding(self, model_mode):
@@ -1015,15 +1013,15 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
       output = output[: hs_shape[0]]
       return output
 
-    input_partition_pspec = nn.logical_to_mesh_axes(self.logical_names.inputs)
-    w0_bias_pspec = nn.logical_to_mesh_axes(self.logical_names.wi_bias)
-    w1_bias_pspec = nn.logical_to_mesh_axes(self.logical_names.wi_bias)
-    wo_bias_pspec = nn.logical_to_mesh_axes(self.logical_names.wo_bias)
-    gate_logits_pspec = nn.logical_to_mesh_axes(self.logical_names.gate)
-    pre_bias_logits_pspec = nn.logical_to_mesh_axes(self.logical_names.pre_bias)
-    w0_pspec = nn.logical_to_mesh_axes(self.logical_names.wi_kernel_sp)
-    w1_pspec = nn.logical_to_mesh_axes(self.logical_names.wi_kernel_sp)
-    wo_pspec = nn.logical_to_mesh_axes(self.logical_names.wo_kernel_sp)
+    input_partition_pspec = logical_to_mesh_axes(self.logical_names.inputs, self.mesh)
+    w0_bias_pspec = logical_to_mesh_axes(self.logical_names.wi_bias, self.mesh)
+    w1_bias_pspec = logical_to_mesh_axes(self.logical_names.wi_bias, self.mesh)
+    wo_bias_pspec = logical_to_mesh_axes(self.logical_names.wo_bias, self.mesh)
+    gate_logits_pspec = logical_to_mesh_axes(self.logical_names.gate, self.mesh)
+    pre_bias_logits_pspec = logical_to_mesh_axes(self.logical_names.pre_bias, self.mesh)
+    w0_pspec = logical_to_mesh_axes(self.logical_names.wi_kernel_sp, self.mesh)
+    w1_pspec = logical_to_mesh_axes(self.logical_names.wi_kernel_sp, self.mesh)
+    wo_pspec = logical_to_mesh_axes(self.logical_names.wo_kernel_sp, self.mesh)
 
     if isinstance(w0_kernel, aqt.QTensor):
       w0_pspec = aqt.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False)
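The partition specs for the grouped-matmul operands are now resolved by MaxText's own logical_to_mesh_axes, which takes the mesh as a second argument, instead of flax's context-based nn.logical_to_mesh_axes. The sketch below is only an assumption about why the mesh is passed (to prune mesh axes the given mesh does not actually shard over); the real helper in MaxText.sharding may behave differently.

```python
# Illustrative sketch, NOT the actual MaxText.sharding.logical_to_mesh_axes.
# Assumption: the mesh argument lets the helper drop mesh axes that are
# absent from (or trivial in) the given mesh, so the resulting PartitionSpec
# is always usable as a shard_map in_spec/out_spec on that mesh.
from flax import linen as nn
from jax.sharding import Mesh, PartitionSpec


def logical_to_mesh_axes_sketch(axis_names, mesh: Mesh) -> PartitionSpec:
  spec = nn.logical_to_mesh_axes(axis_names)  # logical names -> mesh axis names

  def prune(entry):
    if entry is None:
      return None
    axes = entry if isinstance(entry, (tuple, list)) else (entry,)
    kept = tuple(a for a in axes if mesh.shape.get(a, 1) > 1)
    return kept if kept else None

  return PartitionSpec(*(prune(e) for e in spec))
```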
@@ -1047,8 +1045,8 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
             wo_bias_pspec,
             None,
         ),
-        out_specs=(nn.logical_to_mesh_axes(self.logical_names.out)),
-        check_vma=False,
+        out_specs=(logical_to_mesh_axes(self.logical_names.out, self.mesh)),
+        check_vma=True,
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
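The shard_map wrapper's out_specs now goes through the same mesh-aware helper, and check_vma is flipped from False to True, so JAX again verifies how values vary across the manually mapped mesh axes instead of skipping that check. Here is a toy, self-contained sketch of the same decorator pattern, assuming a recent JAX where shard_map is exported at the top level and accepts check_vma (the successor of the older check_rep flag); the mesh, specs, and function are placeholders, not MaxText code.

```python
# Toy sketch of a shard_map-decorated per-shard function; placeholders only.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))


@partial(jax.shard_map, mesh=mesh, in_specs=P("data"), out_specs=P("data"), check_vma=True)
def per_shard_double(x):
  # Runs once per shard on its local slice of x; check_vma=True makes JAX
  # track and verify how each value varies across the "data" mesh axis.
  return 2.0 * x


out = per_shard_double(jnp.arange(4.0 * len(jax.devices())))
```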