
Commit 74f4484

adding classes
Signed-off-by: Amit Raj <[email protected]>
1 parent 1e27fe7 commit 74f4484

File tree

1 file changed (+24 −23 lines)

QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py

Lines changed: 24 additions & 23 deletions
@@ -216,37 +216,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.gate(hidden_states)
-
+
+        # Compute routing logits for selecting experts
+        router_logits = self.gate(hidden_states)  # Shape: (batch * seq_len, num_experts)
+
+        # Compute routing probabilities and select top-k experts per token
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(hidden_states.dtype)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # Normalize weights
+        routing_weights = routing_weights.to(hidden_states.dtype)  # Ensure correct dtype
 
+        # Initialize final output tensor
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
         )
+        # One-hot encode selected experts (batch * seq_len, top_k, num_experts)
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts)
+        expert_mask = expert_mask.to(hidden_states.dtype)  # Ensure dtype matches for efficient computation
+
+        # Compute all expert outputs in parallel (batch * seq_len, num_experts, hidden_dim)
+        expert_outputs = torch.stack([self.experts[i](hidden_states) for i in range(self.num_experts)], dim=1)
+
+        # Efficient expert selection using matrix multiplication
+        selected_expert_outputs = torch.einsum("bte,beh->bth", expert_mask, expert_outputs)
 
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            expert_mask_tr = expert_mask[expert_idx].transpose(0, 1)
-            current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None])
-            current_hidden_states = torch.where(
-                (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None],
-                current_hidden_states,
-                torch.tensor(0.0),
-            )
-            final_hidden_states = final_hidden_states + current_hidden_states
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-        return final_hidden_states, router_logits
 
+        # Multiply by routing weights and sum over top_k experts
+        final_hidden_states = (selected_expert_outputs * routing_weights.unsqueeze(-1)).sum(dim=1)
+
+        # Reshape back to original dimensions
+        final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
+
+        return final_hidden_states, router_logits
 
 class QeffMixtralDecoderLayer(MixtralDecoderLayer):
     """
