@@ -216,37 +216,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.gate(hidden_states)
-
+
+        # Compute routing logits for selecting experts
+        router_logits = self.gate(hidden_states)  # Shape: (batch * seq_len, num_experts)
+
+        # Compute routing probabilities and select top-k experts per token
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(hidden_states.dtype)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # Normalize weights
+        routing_weights = routing_weights.to(hidden_states.dtype)  # Ensure correct dtype

+        # Initialize final output tensor
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
         )
+        # One-hot encode selected experts (batch * seq_len, top_k, num_experts)
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts)
+        expert_mask = expert_mask.to(hidden_states.dtype)  # Ensure dtype matches for efficient computation
+
+        # Compute all expert outputs in parallel (batch * seq_len, num_experts, hidden_dim)
+        expert_outputs = torch.stack([self.experts[i](hidden_states) for i in range(self.num_experts)], dim=1)
+
+        # Efficient expert selection using matrix multiplication
+        selected_expert_outputs = torch.einsum("bte,beh->bth", expert_mask, expert_outputs)

-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            expert_mask_tr = expert_mask[expert_idx].transpose(0, 1)
-            current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None])
-            current_hidden_states = torch.where(
-                (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None],
-                current_hidden_states,
-                torch.tensor(0.0),
-            )
-            final_hidden_states = final_hidden_states + current_hidden_states
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-        return final_hidden_states, router_logits
+        # Multiply by routing weights and sum over top_k experts
+        final_hidden_states = (selected_expert_outputs * routing_weights.unsqueeze(-1)).sum(dim=1)
+
+        # Reshape back to original dimensions
+        final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
+
+        return final_hidden_states, router_logits


 class QeffMixtralDecoderLayer(MixtralDecoderLayer):
     """
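And a minimal sketch, again with assumed shapes, checking that the einsum-based selection in the diff is equivalent to directly gathering each token's selected expert outputs, which is what the removed masking loop computed.

```python
import torch

tokens, num_experts, top_k, hidden_dim = 6, 4, 2, 8  # assumed sizes

expert_outputs = torch.randn(tokens, num_experts, hidden_dim)      # stand-in for stacked expert outputs
selected_experts = torch.randint(0, num_experts, (tokens, top_k))  # stand-in for top-k indices
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).float()

# Einsum path from the diff: (tokens, top_k, experts) x (tokens, experts, hidden) -> (tokens, top_k, hidden)
via_einsum = torch.einsum("bte,beh->bth", expert_mask, expert_outputs)

# Reference path: gather each selected expert's output directly
index = selected_experts.unsqueeze(-1).expand(-1, -1, hidden_dim)
via_gather = torch.gather(expert_outputs, dim=1, index=index)

assert torch.allclose(via_einsum, via_gather)
```

Note that both the old loop and the new dense path run every expert on every token; the apparent payoff of the rewrite is a static, branch-free graph (no Python loop or `torch.where`), at the cost of materializing all expert outputs.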