
Commit 59047c1

Add Shardy sharding rules and fix an argument-passing issue in the backward pass
Signed-off-by: Ming Huang <[email protected]>
1 parent 8cec022 commit 59047c1

File tree

  • transformer_engine/jax/cpp_extensions

1 file changed (+71, −6 lines)

transformer_engine/jax/cpp_extensions/router.py

Lines changed: 71 additions & 6 deletions
@@ -8,6 +8,7 @@
 import jax.numpy as jnp
 from jax import dtypes, ffi
 from jax.interpreters.mlir import ir
+from jax.experimental.custom_partitioning import SdyShardingRule
 
 from .base import BasePrimitive, register_primitive
 
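Note on the new import: `SdyShardingRule` (from `jax.experimental.custom_partitioning`) describes, for JAX's Shardy partitioner, how each dimension of a custom-partitioned primitive's operands and results maps onto named factors; dimensions that share a factor must be sharded the same way. The sketch below is not TE code: it attaches the same rule structure to a toy elementwise op via `def_partition(..., sharding_rule=...)`, which is presumably the hook into which TE's `BasePrimitive`/`register_primitive` machinery feeds the `shardy_sharding_rule` methods added later in this diff. The op, helper names, and factor labels are all illustrative.

# Minimal sketch, not TE code: the same SdyShardingRule structure attached to a
# toy elementwise op. All names (toy_router_scale, _partition, factor labels)
# are illustrative.
from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule


@custom_partitioning
def toy_router_scale(logits, expert_bias):
    # logits: (batch, seqlen, experts); expert_bias: (experts,)
    return logits * 2.0 + expert_bias


def _partition(mesh, arg_shapes, result_shape):
    # Elementwise op: every shard can be computed locally, so reuse the
    # operand/result shardings unchanged.
    arg_shardings = tuple(a.sharding for a in arg_shapes)
    return (
        mesh,
        lambda logits, bias: logits * 2.0 + bias,
        result_shape.sharding,
        arg_shardings,
    )


def _infer_sharding(mesh, arg_shapes, result_shape):
    # Fallback used by the GSPMD partitioner: the output follows the logits operand.
    return arg_shapes[0].sharding


toy_router_scale.def_partition(
    partition=_partition,
    infer_sharding_from_operands=_infer_sharding,
    # Same shape as the rules added in this commit: one tuple of factor names
    # per operand and per result; dimensions sharing a factor name must be
    # sharded the same way.
    sharding_rule=SdyShardingRule(
        (("batch", "seqlen", "experts"), ("experts",)),
        (("batch", "seqlen", "experts"),),
    ),
)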
@@ -106,7 +107,7 @@ def impl(
         score_function,
     ):
         assert FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is not None
-        return FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind(
+        (probs, routing_map, intermediate_output) = FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind(
             logits,
             expert_bias,
             topk=topk,
@@ -116,7 +117,7 @@ def impl(
             scaling_factor=scaling_factor,
             score_function=score_function,
         )
-
+        return probs, routing_map, intermediate_output
     @staticmethod
     def batcher(
         batched_args,
@@ -178,6 +179,40 @@ def partition(
         )
         return mesh, impl, out_shardings, arg_shardings
 
+    @staticmethod
+    def shardy_sharding_rule(
+        topk,
+        use_pre_softmax,
+        num_groups,
+        group_topk,
+        scaling_factor,
+        score_function,
+        mesh,
+        operand_types,
+        result_types,
+    ):
+        del (
+            topk,
+            use_pre_softmax,
+            num_groups,
+            group_topk,
+            scaling_factor,
+            score_function,
+            mesh,
+            result_types,
+        )
+
+        prefix = "Router_"
+        logits_spec = (prefix + "batch", prefix + "seqlen", prefix + "experts")
+        expert_bias_spec = (prefix + "experts",)
+
+        output_spec = (prefix + "batch", prefix + "seqlen", prefix + "experts")
+
+        return SdyShardingRule(
+            (logits_spec, expert_bias_spec),
+            (output_spec, output_spec, output_spec),
+        )
+
 
 register_primitive(FusedTopkWithScoreFunctionFwdPrimitive)
 
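A hedged usage sketch of how the forward rule participates in sharding propagation once Shardy is active; the same pattern applies to the backward rule in the next hunk. It assumes the `toy_router_scale` op from the sketch above and that the installed JAX still exposes the `jax_use_shardy_partitioner` flag (recent releases enable Shardy by default, making the `config.update` call redundant). The `Router_`/`RouterBwd_` prefixes in the commit appear only to namespace the factor labels within each rule; what drives propagation is which dimensions share a factor.

# Hedged usage sketch: driving the rule above through jit on a named mesh.
# Assumes toy_router_scale from the previous sketch; the flag below is JAX's
# switch from GSPMD to Shardy and may already default to True.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update("jax_use_shardy_partitioner", True)

mesh = Mesh(jax.devices()[:1], ("tp",))  # single-device mesh keeps this runnable anywhere

logits = jax.device_put(
    jnp.ones((4, 128, 64), jnp.float32), NamedSharding(mesh, P(None, None, "tp"))
)
bias = jax.device_put(jnp.zeros((64,), jnp.float32), NamedSharding(mesh, P("tp")))

# Both "experts" dimensions share one factor in the rule, so Shardy can
# propagate the tp sharding from the operands through the custom op's output.
out = jax.jit(toy_router_scale)(logits, bias)
print(out.sharding)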
@@ -317,6 +352,36 @@ def partition(
         )
         return mesh, impl, out_shardings, arg_shardings
 
+    @staticmethod
+    def shardy_sharding_rule(
+        topk,
+        use_pre_softmax,
+        scaling_factor,
+        score_function,
+        mesh,
+        operand_types,
+        result_types,
+    ):
+        del (
+            topk,
+            use_pre_softmax,
+            scaling_factor,
+            score_function,
+            mesh,
+            result_types,
+        )
+
+        prefix = "RouterBwd_"
+
+        input_spec = (prefix + "batch", prefix + "seqlen", prefix + "experts")
+
+        output_spec = (prefix + "batch", prefix + "seqlen", prefix + "experts")
+
+        return SdyShardingRule(
+            (input_spec, input_spec, input_spec),
+            (output_spec,),
+        )
+
 
 register_primitive(FusedTopkWithScoreFunctionBwdPrimitive)
 
@@ -360,8 +425,8 @@ def fused_topk_with_score_function_bwd(
         routing_map,
         intermediate_output,
         grad_probs,
-        topk,
-        use_pre_softmax,
-        scaling_factor,
-        score_function,
+        topk=topk,
+        use_pre_softmax=use_pre_softmax,
+        scaling_factor=scaling_factor,
+        score_function=score_function,
     )
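This last hunk is the arguments issue from the commit title: the backward wrapper was passing `topk`, `use_pre_softmax`, `scaling_factor`, and `score_function` positionally into what the context suggests is the backward primitive's `bind` call. The toy sketch below (not TE code; it assumes `jax.extend.core.Primitive`, exposed as `jax.core.Primitive` on older JAX) shows why the fix matters: `bind` takes array operands positionally and static parameters strictly as keywords.

# Minimal sketch, not TE code: parameters of Primitive.bind must be keywords;
# positional values become extra operands. toy_scale_p and `factor` are made up.
import jax.numpy as jnp
from jax.extend.core import Primitive

toy_scale_p = Primitive("toy_scale")
toy_scale_p.def_impl(lambda x, *, factor: x * factor)
toy_scale_p.def_abstract_eval(lambda x, *, factor: x)  # shape/dtype unchanged

x = jnp.ones((4,))
print(toy_scale_p.bind(x, factor=3.0))  # parameter by keyword: [3. 3. 3. 3.]
# toy_scale_p.bind(x, 3.0) would instead treat 3.0 as a second array operand and
# fail because the required `factor` parameter is missing.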
