@@ -939,23 +939,26 @@ def grouped_topk(hidden_states: torch.Tensor,
939
939
else :
940
940
raise ValueError (f"Unsupported scoring function: { scoring_func } " )
941
941
942
+ num_token = scores .shape [0 ]
942
943
if e_score_correction_bias is not None :
943
944
# Store original scores before applying correction bias. We use biased
944
945
# scores for expert selection but original scores for routing weights
945
946
original_scores = scores
946
947
scores = scores + e_score_correction_bias .unsqueeze (0 )
947
-
948
- num_token = scores .shape [0 ]
949
- group_scores = scores .view (num_token , num_expert_group ,
950
- - 1 ).max (dim = - 1 ).values # [n, n_group]
948
+ group_scores = (scores .view (num_token , num_expert_group ,
949
+ - 1 ).topk (2 , dim = - 1 )[0 ].sum (dim = - 1 ))
950
+ else :
951
+ group_scores = scores .view (num_token , num_expert_group ,
952
+ - 1 ).max (dim = - 1 ).values # [n, n_group]
951
953
group_idx = torch .topk (group_scores , k = topk_group , dim = - 1 ,
952
954
sorted = False )[1 ] # [n, top_k_group]
953
955
group_mask = torch .zeros_like (group_scores ) # [n, n_group]
954
956
group_mask .scatter_ (1 , group_idx , 1 ) # [n, n_group]
955
957
score_mask = group_mask .unsqueeze (- 1 ).expand (
956
958
num_token , num_expert_group ,
957
959
scores .shape [- 1 ] // num_expert_group ).reshape (num_token , - 1 ) # [n, e]
958
- tmp_scores = scores .masked_fill (~ score_mask .bool (), 0.0 ) # [n, e]
960
+ tmp_scores = scores .masked_fill (~ score_mask .bool (),
961
+ float ("-inf" )) # [n, e]
959
962
960
963
if e_score_correction_bias is not None :
961
964
topk_ids = torch .topk (tmp_scores , k = topk , dim = - 1 , sorted = False )[1 ]
0 commit comments