@@ -275,9 +275,11 @@ def __init__(
275275 self ,
276276 min_timescale : int ,
277277 max_timescale : int ,
278+ mesh : Mesh ,
278279 embedding_dims : int = 0 ,
279280 cast_as_fprop_dtype : bool = True ,
280281 fprop_dtype : DType = jnp .bfloat16 ,
282+ shard_mode : ShardMode = ShardMode .AUTO ,
281283 # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
282284 # TODO: Remove when bridge no longer needed
283285 rope_linear_scaling_factor : float = 1.0 ,
@@ -297,9 +299,11 @@ def __init__(
297299 """
298300 self .min_timescale = min_timescale
299301 self .max_timescale = max_timescale
302+ self .mesh = mesh
300303 self .embedding_dims = embedding_dims
301304 self .cast_as_fprop_dtype = cast_as_fprop_dtype
302305 self .fprop_dtype = fprop_dtype
306+ self .shard_mode = shard_mode
303307 self .rope_linear_scaling_factor = rope_linear_scaling_factor
304308
305309 if self .embedding_dims % 2 :
@@ -396,6 +400,7 @@ def qwen3_next_rotary_embedding_as_linen(
396400 * ,
397401 min_timescale : int ,
398402 max_timescale : int ,
403+ mesh : Mesh ,
399404 embedding_dims : int = 0 ,
400405 partial_rotary_factor : float = 0.25 ,
401406 cast_as_fprop_dtype : bool = True ,
@@ -419,6 +424,7 @@ def qwen3_next_rotary_embedding_as_linen(
419424 Qwen3NextRotaryEmbedding ,
420425 min_timescale = min_timescale ,
421426 max_timescale = max_timescale ,
427+ mesh = mesh ,
422428 embedding_dims = embedding_dims ,
423429 partial_rotary_factor = partial_rotary_factor ,
424430 cast_as_fprop_dtype = cast_as_fprop_dtype ,
@@ -435,6 +441,7 @@ def __init__(
435441 self ,
436442 min_timescale : int ,
437443 max_timescale : int ,
444+ mesh : Mesh ,
438445 embedding_dims : int = 0 ,
439446 cast_as_fprop_dtype : bool = True ,
440447 fprop_dtype : DType = jnp .bfloat16 ,
@@ -459,6 +466,7 @@ def __init__(
459466 super ().__init__ (
460467 min_timescale = min_timescale ,
461468 max_timescale = max_timescale ,
469+ mesh = mesh ,
462470 embedding_dims = self .rotary_dim ,
463471 cast_as_fprop_dtype = cast_as_fprop_dtype ,
464472 fprop_dtype = fprop_dtype ,
@@ -490,10 +498,12 @@ def __init__(
490498 self ,
491499 min_timescale : int ,
492500 max_timescale : int ,
501+ mesh : Mesh ,
493502 embedding_dims : int = 0 ,
494503 cast_as_fprop_dtype : bool = True ,
495504 fprop_dtype : DType = jnp .bfloat16 ,
496505 use_scale : bool = True ,
506+ shard_mode : ShardMode = ShardMode .AUTO ,
497507 # Not used in LLaMARotaryEmbedding but passed in by nnx.bridge.to_linen.
498508 # TODO: Remove when bridge no longer needed
499509 rngs : nnx .Rngs = None ,
@@ -517,6 +527,8 @@ def __init__(
517527 embedding_dims = embedding_dims ,
518528 cast_as_fprop_dtype = cast_as_fprop_dtype ,
519529 fprop_dtype = fprop_dtype ,
530+ mesh = mesh ,
531+ shard_mode = shard_mode ,
520532 rngs = rngs ,
521533 )
522534
@@ -625,6 +637,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.
625637def yarn_rotary_embedding_as_linen (
626638 * ,
627639 embedding_dims : int ,
640+ mesh : Mesh ,
628641 max_position_embeddings : int = 4096 * 4 ,
629642 original_max_position_embeddings : int = 4096 ,
630643 beta_fast : float = 32 ,
@@ -637,6 +650,7 @@ def yarn_rotary_embedding_as_linen(
637650 interleave : bool = True ,
638651 truncate : bool = True ,
639652 attention_scaling : bool = False ,
653+ shard_mode : ShardMode = ShardMode .AUTO ,
640654):
641655 """Initializes the YarnRotaryEmbedding module and returns it as a Linen module.
642656
@@ -655,6 +669,7 @@ def yarn_rotary_embedding_as_linen(
655669 return nnx_wrappers .to_linen (
656670 YarnRotaryEmbedding ,
657671 embedding_dims = embedding_dims ,
672+ mesh = mesh ,
658673 max_position_embeddings = max_position_embeddings ,
659674 original_max_position_embeddings = original_max_position_embeddings ,
660675 beta_fast = beta_fast ,
@@ -668,6 +683,7 @@ def yarn_rotary_embedding_as_linen(
668683 interleave = interleave ,
669684 truncate = truncate ,
670685 attention_scaling = attention_scaling ,
686+ shard_mode = shard_mode ,
671687 )
672688
673689
@@ -697,6 +713,7 @@ class YarnRotaryEmbedding(nnx.Module):
697713 def __init__ (
698714 self ,
699715 embedding_dims : int ,
716+ mesh : Mesh ,
700717 max_position_embeddings : int = 4096 * 4 ,
701718 original_max_position_embeddings : int = 4096 ,
702719 beta_fast : float = 32 ,
@@ -705,6 +722,7 @@ def __init__(
705722 rope_factor : float = 40 ,
706723 cast_as_fprop_dtype : bool = True ,
707724 fprop_dtype : DType = jnp .bfloat16 ,
725+ shard_mode : ShardMode = ShardMode .AUTO ,
708726 interleave = True ,
709727 truncate = True ,
710728 attention_scaling = False ,
@@ -724,6 +742,13 @@ def __init__(
724742 self .fprop_dtype = fprop_dtype
725743 self .interleave = interleave
726744 self .truncate = truncate
745+ self .mesh = mesh
746+ self .shard_mode = shard_mode
747+ self .freqs_sharding = (
748+ NamedSharding (mesh , nn .logical_to_mesh_axes (("activation_batch" , "activation_length_no_exp" , "q_heads" )))
749+ if shard_mode == ShardMode .EXPLICIT
750+ else None
751+ )
727752 self .attention_scaling = attention_scaling
728753
729754 if self .embedding_dims % 2 :
@@ -829,7 +854,8 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
829854 # Lookup the precomputed frequencies using the position indices.
830855 # self.freqs_cis has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
831856 # After indexing, shape becomes [B, S, half_dim]; we then add an axis for the heads.
832- freqs = jnp .take (self .freqs_cis , position , axis = 0 ) # shape: [B, S, half_dim]
857+ # freqs = jnp.take(self.freqs_cis, position, axis=0) # shape: [B, S, half_dim]
858+ freqs = self .freqs_cis .at [position ].get (out_sharding = self .freqs_sharding )
833859 freqs = freqs [:, :, jnp .newaxis , :] # shape: [B, S, 1, half_dim]
834860
835861 if self .interleave :
@@ -846,7 +872,14 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
846872
847873 inputs_complex = first_half + 1j * second_half # shape: [B, S, N, half_dim]
848874 # Apply the rotary transformation via complex multiplication.
849- rotated = inputs_complex * freqs # shape: [B, S, N, half_dim]
875+ rotated_sharding = (
876+ NamedSharding (self .mesh , nn .logical_to_mesh_axes (("activation_batch" , "activation_length_no_exp" , None , None )))
877+ if self .shard_mode == ShardMode .EXPLICIT
878+ else None
879+ )
880+ rotated = jnp .einsum (
881+ "ijkl, ijml->ijkl" , inputs_complex , freqs , out_sharding = rotated_sharding
882+ ) # shape: [B, S, N, half_dim]
850883 # Convert the complex result back to a real tensor.
851884 # Split the complex number into its real and imaginary parts.
852885 # [real1, real2, ..., img1, img2, ...]
0 commit comments