@@ -263,9 +263,11 @@ def __init__(
263263 self ,
264264 min_timescale : int ,
265265 max_timescale : int ,
266+ mesh : Mesh ,
266267 embedding_dims : int = 0 ,
267268 cast_as_fprop_dtype : bool = True ,
268269 fprop_dtype : DType = jnp .bfloat16 ,
270+ shard_mode : ShardMode = ShardMode .AUTO ,
269271 # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
270272 # TODO: Remove when bridge no longer needed
271273 rope_linear_scaling_factor : float = 1.0 ,
@@ -285,9 +287,11 @@ def __init__(
285287 """
286288 self .min_timescale = min_timescale
287289 self .max_timescale = max_timescale
290+ self .mesh = mesh
288291 self .embedding_dims = embedding_dims
289292 self .cast_as_fprop_dtype = cast_as_fprop_dtype
290293 self .fprop_dtype = fprop_dtype
294+ self .shard_mode = shard_mode
291295 self .rope_linear_scaling_factor = rope_linear_scaling_factor
292296
293297 if self .embedding_dims % 2 :
@@ -384,6 +388,7 @@ def qwen3_next_rotary_embedding_as_linen(
384388 * ,
385389 min_timescale : int ,
386390 max_timescale : int ,
391+ mesh : Mesh ,
387392 embedding_dims : int = 0 ,
388393 partial_rotary_factor : float = 0.25 ,
389394 cast_as_fprop_dtype : bool = True ,
@@ -407,6 +412,7 @@ def qwen3_next_rotary_embedding_as_linen(
407412 Qwen3NextRotaryEmbedding ,
408413 min_timescale = min_timescale ,
409414 max_timescale = max_timescale ,
415+ mesh = mesh ,
410416 embedding_dims = embedding_dims ,
411417 partial_rotary_factor = partial_rotary_factor ,
412418 cast_as_fprop_dtype = cast_as_fprop_dtype ,
@@ -423,6 +429,7 @@ def __init__(
423429 self ,
424430 min_timescale : int ,
425431 max_timescale : int ,
432+ mesh : Mesh ,
426433 embedding_dims : int = 0 ,
427434 cast_as_fprop_dtype : bool = True ,
428435 fprop_dtype : DType = jnp .bfloat16 ,
@@ -447,6 +454,7 @@ def __init__(
447454 super ().__init__ (
448455 min_timescale = min_timescale ,
449456 max_timescale = max_timescale ,
457+ mesh = mesh ,
450458 embedding_dims = self .rotary_dim ,
451459 cast_as_fprop_dtype = cast_as_fprop_dtype ,
452460 fprop_dtype = fprop_dtype ,
@@ -478,10 +486,12 @@ def __init__(
478486 self ,
479487 min_timescale : int ,
480488 max_timescale : int ,
489+ mesh : Mesh ,
481490 embedding_dims : int = 0 ,
482491 cast_as_fprop_dtype : bool = True ,
483492 fprop_dtype : DType = jnp .bfloat16 ,
484493 use_scale : bool = True ,
494+ shard_mode : ShardMode = ShardMode .AUTO ,
485495 # Not used in LLaMARotaryEmbedding but passed in by nnx.bridge.to_linen.
486496 # TODO: Remove when bridge no longer needed
487497 rngs : nnx .Rngs = None ,
@@ -505,6 +515,8 @@ def __init__(
505515 embedding_dims = embedding_dims ,
506516 cast_as_fprop_dtype = cast_as_fprop_dtype ,
507517 fprop_dtype = fprop_dtype ,
518+ mesh = mesh ,
519+ shard_mode = shard_mode ,
508520 rngs = rngs ,
509521 )
510522
@@ -613,6 +625,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.
613625def yarn_rotary_embedding_as_linen (
614626 * ,
615627 embedding_dims : int ,
628+ mesh : Mesh ,
616629 max_position_embeddings : int = 4096 * 4 ,
617630 original_max_position_embeddings : int = 4096 ,
618631 beta_fast : float = 32 ,
@@ -625,6 +638,7 @@ def yarn_rotary_embedding_as_linen(
625638 interleave : bool = True ,
626639 truncate : bool = True ,
627640 attention_scaling : bool = False ,
641+ shard_mode : ShardMode = ShardMode .AUTO ,
628642):
629643 """Initializes the YarnRotaryEmbedding module and returns it as a Linen module.
630644
@@ -643,6 +657,7 @@ def yarn_rotary_embedding_as_linen(
643657 return nnx_wrappers .to_linen (
644658 YarnRotaryEmbedding ,
645659 embedding_dims = embedding_dims ,
660+ mesh = mesh ,
646661 max_position_embeddings = max_position_embeddings ,
647662 original_max_position_embeddings = original_max_position_embeddings ,
648663 beta_fast = beta_fast ,
@@ -656,6 +671,7 @@ def yarn_rotary_embedding_as_linen(
656671 interleave = interleave ,
657672 truncate = truncate ,
658673 attention_scaling = attention_scaling ,
674+ shard_mode = shard_mode ,
659675 )
660676
661677
@@ -685,6 +701,7 @@ class YarnRotaryEmbedding(nnx.Module):
685701 def __init__ (
686702 self ,
687703 embedding_dims : int ,
704+ mesh : Mesh ,
688705 max_position_embeddings : int = 4096 * 4 ,
689706 original_max_position_embeddings : int = 4096 ,
690707 beta_fast : float = 32 ,
@@ -693,6 +710,7 @@ def __init__(
693710 rope_factor : float = 40 ,
694711 cast_as_fprop_dtype : bool = True ,
695712 fprop_dtype : DType = jnp .bfloat16 ,
713+ shard_mode : ShardMode = ShardMode .AUTO ,
696714 interleave = True ,
697715 truncate = True ,
698716 attention_scaling = False ,
@@ -712,6 +730,13 @@ def __init__(
712730 self .fprop_dtype = fprop_dtype
713731 self .interleave = interleave
714732 self .truncate = truncate
733+ self .mesh = mesh
734+ self .shard_mode = shard_mode
735+ self .freqs_sharding = (
736+ NamedSharding (mesh , nn .logical_to_mesh_axes (("activation_batch" , "activation_length_no_exp" , "q_heads" )))
737+ if shard_mode == ShardMode .EXPLICIT
738+ else None
739+ )
715740 self .attention_scaling = attention_scaling
716741
717742 if self .embedding_dims % 2 :
@@ -811,7 +836,8 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
811836 # Lookup the precomputed frequencies using the position indices.
812837 # self.freqs_cis has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
813838 # After indexing, shape becomes [B, S, half_dim]; we then add an axis for the heads.
814- freqs = jnp .take (self .freqs_cis , position , axis = 0 ) # shape: [B, S, half_dim]
839+ # freqs = jnp.take(self.freqs_cis, position, axis=0) # shape: [B, S, half_dim]
840+ freqs = self .freqs_cis .at [position ].get (out_sharding = self .freqs_sharding )
815841 freqs = freqs [:, :, jnp .newaxis , :] # shape: [B, S, 1, half_dim]
816842
817843 if self .interleave :
@@ -828,7 +854,14 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
828854
829855 inputs_complex = first_half + 1j * second_half # shape: [B, S, N, half_dim]
830856 # Apply the rotary transformation via complex multiplication.
831- rotated = inputs_complex * freqs # shape: [B, S, N, half_dim]
857+ rotated_sharding = (
858+ NamedSharding (self .mesh , nn .logical_to_mesh_axes (("activation_batch" , "activation_length_no_exp" , None , None )))
859+ if self .shard_mode == ShardMode .EXPLICIT
860+ else None
861+ )
862+ rotated = jnp .einsum (
863+ "ijkl, ijml->ijkl" , inputs_complex , freqs , out_sharding = rotated_sharding
864+ ) # shape: [B, S, N, half_dim]
832865 # Convert the complex result back to a real tensor.
833866 # Split the complex number into its real and imaginary parts.
834867 # [real1, real2, ..., img1, img2, ...]
0 commit comments