1919 "map_score_function" ,
2020]
2121
22+
2223def map_score_function (score_function : str ) -> int :
2324 score_function_map = {"sigmoid" : 0 , "softmax" : 1 }
24- assert score_function in score_function_map , \
25- f"score_function must be 'sigmoid' or 'softmax', got { score_function } "
25+ assert (
26+ score_function in score_function_map
27+ ), f"score_function must be 'sigmoid' or 'softmax', got { score_function } "
2628 return score_function_map [score_function ]
2729
30+
2831class FusedTopkWithScoreFunctionFwdPrimitive (BasePrimitive ):
2932 """
3033 Fused TopK with Score Function Forward Primitive
3134 """
3235
3336 name = "te_fused_topk_with_score_function_forward_ffi"
3437 multiple_results = True # Returns (probs, routing_map, intermediate_output)
35- impl_static_args = (2 , 3 , 4 , 5 , 6 , 7 ,) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function,
38+ impl_static_args = (
39+ 2 ,
40+ 3 ,
41+ 4 ,
42+ 5 ,
43+ 6 ,
44+ 7 ,
45+ ) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function,
3646 inner_primitive = None
3747 outer_primitive = None
3848
@@ -52,7 +62,7 @@ def abstract(
        te_fused_topk_with_score_function_forward abstract
        """
        dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
-        assert len(logits_aval.shape) == 3 # (batch, seqlen, num_experts)
+        assert len(logits_aval.shape) == 3  # (batch, seqlen, num_experts)

        probs_aval = logits_aval.update(shape=logits_aval.shape, dtype=dtype)
        routing_map_aval = logits_aval.update(shape=logits_aval.shape, dtype=jnp.bool_)
@@ -78,14 +88,14 @@ def lowering(
7888 """
7989 logits_type = ir .RankedTensorType (logits .type )
8090 logits_shape = logits_type .shape
81- assert len (logits_shape ) == 3 # (batch, seqlen, num_experts)
91+ assert len (logits_shape ) == 3 # (batch, seqlen, num_experts)
8292 (batch , seqlen , num_experts ) = logits_shape
8393
8494 return ffi .ffi_lowering (FusedTopkWithScoreFunctionFwdPrimitive .name )(
8595 ctx ,
8696 logits ,
8797 expert_bias ,
-            num_tokens=batch*seqlen,
+            num_tokens=batch * seqlen,
            num_experts=num_experts,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
@@ -107,17 +117,20 @@ def impl(
        score_function,
    ):
        assert FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is not None
-        (probs, routing_map, intermediate_output) = FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind(
-            logits,
-            expert_bias,
-            topk=topk,
-            use_pre_softmax=use_pre_softmax,
-            num_groups=num_groups,
-            group_topk=group_topk,
-            scaling_factor=scaling_factor,
-            score_function=score_function,
+        (probs, routing_map, intermediate_output) = (
+            FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind(
+                logits,
+                expert_bias,
+                topk=topk,
+                use_pre_softmax=use_pre_softmax,
+                num_groups=num_groups,
+                group_topk=group_topk,
+                scaling_factor=scaling_factor,
+                score_function=score_function,
+            )
        )
        return probs, routing_map, intermediate_output
+
    @staticmethod
    def batcher(
        batched_args,
@@ -129,7 +142,9 @@ def batcher(
        scaling_factor,
        score_function,
    ):
-        raise NotImplementedError("Batcher not implemented for FusedTopkWithScoreFunctionFwdPrimitive")
+        raise NotImplementedError(
+            "Batcher not implemented for FusedTopkWithScoreFunctionFwdPrimitive"
+        )

    @staticmethod
    def infer_sharding_from_operands(
@@ -169,7 +184,8 @@ def partition(
        del result_infos
        out_shardings = (arg_infos[0].sharding, arg_infos[0].sharding, arg_infos[0].sharding)
        arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding)
-        impl = partial(FusedTopkWithScoreFunctionFwdPrimitive.impl,
+        impl = partial(
+            FusedTopkWithScoreFunctionFwdPrimitive.impl,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
            num_groups=num_groups,
@@ -261,15 +277,15 @@ def lowering(
261277 """
262278 intermediate_output_type = ir .RankedTensorType (intermediate_output .type )
263279 intermediate_output_shape = intermediate_output_type .shape
264- assert len (intermediate_output_shape ) == 3 # (batch, seqlen, num_experts)
280+ assert len (intermediate_output_shape ) == 3 # (batch, seqlen, num_experts)
265281 (batch , seqlen , num_experts ) = intermediate_output_shape
266282
267283 return ffi .ffi_lowering (FusedTopkWithScoreFunctionBwdPrimitive .name )(
268284 ctx ,
269285 routing_map ,
270286 intermediate_output ,
271287 grad_probs ,
-            num_tokens=batch*seqlen,
+            num_tokens=batch * seqlen,
            num_experts=num_experts,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
@@ -307,7 +323,9 @@ def batcher(
        scaling_factor,
        score_function,
    ):
-        raise NotImplementedError("Batcher not implemented for FusedTopkWithScoreFunctionBwdPrimitive")
+        raise NotImplementedError(
+            "Batcher not implemented for FusedTopkWithScoreFunctionBwdPrimitive"
+        )

    @staticmethod
    def infer_sharding_from_operands(
@@ -411,6 +429,7 @@ def fused_topk_with_score_function_fwd(
        score_function=score_function,
    )

+
def fused_topk_with_score_function_bwd(
    routing_map: jnp.ndarray,
    intermediate_output: jnp.ndarray,
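
For readers unfamiliar with how such primitive pairs are consumed, below is a minimal sketch of wiring the two public wrappers into a differentiable routing function with jax.custom_vjp. The topk_routing name, the exact argument order, and the assumption that fused_topk_with_score_function_bwd returns only the gradient with respect to logits (with expert_bias treated as non-trainable) are inferred from the signatures visible in this diff, not confirmed against the full module, which may already register its own VJP.

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7))
def topk_routing(
    logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk,
    scaling_factor, score_function,
):
    # Primal computation: keep only the outputs callers care about.
    probs, routing_map, _ = fused_topk_with_score_function_fwd(
        logits,
        expert_bias,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
    )
    return probs, routing_map


def _topk_routing_fwd(
    logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk,
    scaling_factor, score_function,
):
    probs, routing_map, intermediate_output = fused_topk_with_score_function_fwd(
        logits,
        expert_bias,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
    )
    # Residuals mirror the operands of the backward lowering above: the boolean
    # routing map and the forward kernel's intermediate scores. expert_bias is
    # kept only so the backward rule can emit a zero cotangent for it.
    return (probs, routing_map), (routing_map, intermediate_output, expert_bias)


def _topk_routing_bwd(
    topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function,
    residuals, cotangents,
):
    routing_map, intermediate_output, expert_bias = residuals
    grad_probs, _ = cotangents  # routing_map is boolean; its cotangent carries no signal
    grad_logits = fused_topk_with_score_function_bwd(
        routing_map,
        intermediate_output,
        grad_probs,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
    )
    return grad_logits, jnp.zeros_like(expert_bias)


topk_routing.defvjp(_topk_routing_fwd, _topk_routing_bwd)

With this wiring, jax.grad of a scalar loss over probs flows through the fused backward kernel rather than re-deriving the sigmoid/softmax top-k math in Python.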