 import jax.numpy as jnp
 from jax import dtypes, ffi
 from jax.interpreters.mlir import ir
+from jax.experimental.custom_partitioning import SdyShardingRule

 from .base import BasePrimitive, register_primitive

@@ -106,7 +107,7 @@ def impl(
         score_function,
     ):
         assert FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is not None
-        return FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind(
+        (probs, routing_map, intermediate_output) = FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind(
             logits,
             expert_bias,
             topk=topk,
@@ -116,7 +117,7 @@ def impl(
             scaling_factor=scaling_factor,
             score_function=score_function,
         )
-
+        return probs, routing_map, intermediate_output
     @staticmethod
     def batcher(
         batched_args,
@@ -178,6 +179,40 @@ def partition(
         )
         return mesh, impl, out_shardings, arg_shardings

+    @staticmethod
+    def shardy_sharding_rule(
+        topk,
+        use_pre_softmax,
+        num_groups,
+        group_topk,
+        scaling_factor,
+        score_function,
+        mesh,
+        operand_types,
+        result_types,
+    ):
+        del (
+            topk,
+            use_pre_softmax,
+            num_groups,
+            group_topk,
+            scaling_factor,
+            score_function,
+            mesh,
+            result_types,
+        )
+
+        prefix = "Router_"
+        logits_spec = (prefix + "batch", prefix + "seqlen", prefix + "experts")
+        expert_bias_spec = (prefix + "experts",)
+
+        output_spec = (prefix + "batch", prefix + "seqlen", prefix + "experts")
+
+        return SdyShardingRule(
+            (logits_spec, expert_bias_spec),
+            (output_spec, output_spec, output_spec),
+        )
+

 register_primitive(FusedTopkWithScoreFunctionFwdPrimitive)

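The two `shardy_sharding_rule` additions follow the `SdyShardingRule` convention from `jax.experimental.custom_partitioning`: each operand and each result gets a tuple of named factors, and dimensions that reuse the same factor name are forced to share a sharding. Below is a minimal, self-contained sketch of the same pattern, assuming a toy element-wise router; `toy_router`, `_infer_sharding`, and `_partition` are illustrative names (not code from this change), and `def_partition(..., sharding_rule=...)` is the standard `custom_partitioning` hook such a rule is meant to feed.

```python
import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule


@custom_partitioning
def toy_router(logits, expert_bias):
    # Element-wise stand-in for the fused kernel, so any sharding of
    # (batch, seqlen, experts) is valid shard-by-shard.
    return jax.nn.sigmoid(logits + expert_bias)


def _infer_sharding(mesh, arg_infos, result_infos):
    # The single output inherits the logits sharding.
    return arg_infos[0].sharding


def _partition(mesh, arg_infos, result_infos):
    arg_shardings = tuple(arg.sharding for arg in arg_infos)
    out_sharding = arg_infos[0].sharding
    return mesh, lambda l, b: jax.nn.sigmoid(l + b), out_sharding, arg_shardings


toy_router.def_partition(
    infer_sharding_from_operands=_infer_sharding,
    partition=_partition,
    # Same tuple-of-factor-names form as shardy_sharding_rule above; reusing
    # "experts" ties that dimension across logits, expert_bias, and the output.
    sharding_rule=SdyShardingRule(
        (("batch", "seqlen", "experts"), ("experts",)),
        (("batch", "seqlen", "experts"),),
    ),
)
```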
@@ -317,6 +352,36 @@ def partition(
         )
         return mesh, impl, out_shardings, arg_shardings

+    @staticmethod
+    def shardy_sharding_rule(
+        topk,
+        use_pre_softmax,
+        scaling_factor,
+        score_function,
+        mesh,
+        operand_types,
+        result_types,
+    ):
+        del (
+            topk,
+            use_pre_softmax,
+            scaling_factor,
+            score_function,
+            mesh,
+            result_types,
+        )
+
+        prefix = "RouterBwd_"
+
+        input_spec = (prefix + "batch", prefix + "seqlen", prefix + "experts")
+
+        output_spec = (prefix + "batch", prefix + "seqlen", prefix + "experts")
+
+        return SdyShardingRule(
+            (input_spec, input_spec, input_spec),
+            (output_spec,),
+        )
+

 register_primitive(FusedTopkWithScoreFunctionBwdPrimitive)

@@ -360,8 +425,8 @@ def fused_topk_with_score_function_bwd(
         routing_map,
         intermediate_output,
         grad_probs,
-        topk,
-        use_pre_softmax,
-        scaling_factor,
-        score_function,
+        topk=topk,
+        use_pre_softmax=use_pre_softmax,
+        scaling_factor=scaling_factor,
+        score_function=score_function,
     )
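The final hunk switches the backward bind call from positional to keyword arguments. `Primitive.bind` treats positional arguments as traced operands and keyword arguments as static parameters, so passing `topk` and friends positionally would hand them to the primitive as extra operands. A hedged illustration with a made-up primitive, assuming a recent JAX that exposes `Primitive` via `jax.extend.core`:

```python
import jax.numpy as jnp
from jax.extend.core import Primitive

toy_scale_p = Primitive("toy_scale")
# Static parameters arrive as keyword-only params in impl/abstract_eval.
toy_scale_p.def_impl(lambda x, *, scaling_factor: x * scaling_factor)
toy_scale_p.def_abstract_eval(lambda x, *, scaling_factor: x)

x = jnp.ones((2, 3))
out = toy_scale_p.bind(x, scaling_factor=2.0)  # keyword => static parameter
# toy_scale_p.bind(x, 2.0) would treat 2.0 as a second traced operand and fail,
# which is why the call above now spells out topk=..., scaling_factor=..., etc.
```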