Skip to content

Commit 7c51f36

Browse files
committed
Dynamic code for Moe Block
Signed-off-by: Dipankar Sarkar <[email protected]>
1 parent 1965328 commit 7c51f36

File tree

1 file changed

+27
-128
lines changed

1 file changed

+27
-128
lines changed

QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py

Lines changed: 27 additions & 128 deletions
Original file line number | Diff line number | Diff line change
@@ -112,28 +112,6 @@ def eager_attention_forward(
112112

113113
class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
114114
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
115-
# # breakpoint()
116-
# B, S, D = hidden_states.shape # [1, 8, 2304]
117-
# hidden_states = hidden_states.reshape(-1, D) # [8, 2304]
118-
# T = hidden_states.size(0) # 8 tokens
119-
# router_logits = self.gate(hidden_states) # [8, 8]
120-
# probs = F.softmax(router_logits, dim=-1) # [8, 8]
121-
122-
# topk_scores, topk_indices = torch.topk(probs, self.top_k, dim=-1) # [8, top_k] → topk_k is 2 for Grok1
123-
# topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # normalize per-token
124-
# topk_scores = topk_scores.to(hidden_states.dtype) # [8, top_k]
125-
# route = torch.zeros((T, self.num_experts), dtype=hidden_states.dtype)
126-
# route.scatter_(1, topk_indices, topk_scores) # [8, num_experts]
127-
# final_output = torch.zeros_like(hidden_states) # [8, 2304]
128-
129-
# for e, expert in enumerate(self.experts):
130-
# scores = route[:, e].unsqueeze(1) # [8, 1]
131-
# masked_out = torch.where(
132-
# scores > 0, expert(hidden_states) * scores, 0.0
133-
# ) # # [8, 2304] × [8, 1] → [8, 2304]
134-
# final_output += masked_out # accumulate expert outputs
135-
# return final_output.reshape(B, S, D), router_logits # ([1, 8, 2304], [8, num_experts])
136-
137115
B, S, H = hidden_states.shape
138116
T = B * S
139117
x = hidden_states.view(T, H)
@@ -145,119 +123,40 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
145123
top_w /= top_w.sum(-1, keepdim=True)
146124
top_w = top_w.to(x.dtype)
147125

148-
# Create 2 expert idx based on the topk
149-
expert1_idx, expert2_idx, expert3_idx, expert4_idx, expert5_idx, expert6_idx, expert7_idx, expert8_idx = (
150-
top_i[:, 0],
151-
top_i[:, 1],
152-
top_i[:, 2],
153-
top_i[:, 3],
154-
top_i[:, 4],
155-
top_i[:, 5],
156-
top_i[:, 6],
157-
top_i[:, 7],
158-
) # [T]
159-
weight1, weight2, weight3, weight4, weight5, weight6, weight7, weight8 = (
160-
top_w[:, 0],
161-
top_w[:, 1],
162-
top_w[:, 2],
163-
top_w[:, 3],
164-
top_w[:, 4],
165-
top_w[:, 5],
166-
top_w[:, 6],
167-
top_w[:, 7],
168-
) # [T]
169-
170-
Inter = 768
171-
upgate1 = x.new_zeros((T, Inter))
172-
upgate2 = x.new_zeros((T, Inter))
173-
upgate3 = x.new_zeros((T, Inter))
174-
upgate4 = x.new_zeros((T, Inter))
175-
upgate5 = x.new_zeros((T, Inter))
176-
upgate6 = x.new_zeros((T, Inter))
177-
upgate7 = x.new_zeros((T, Inter))
178-
upgate8 = x.new_zeros((T, Inter))
179-
180-
expert_out1 = x.new_zeros((T, H))
181-
expert_out2 = x.new_zeros((T, H))
182-
expert_out3 = x.new_zeros((T, H))
183-
expert_out4 = x.new_zeros((T, H))
184-
expert_out5 = x.new_zeros((T, H))
185-
expert_out6 = x.new_zeros((T, H))
186-
expert_out7 = x.new_zeros((T, H))
187-
expert_out8 = x.new_zeros((T, H))
126+
expert_idx = []
127+
weights = []
128+
for i in range(self.top_k):
129+
expert_idx.append(top_i[:, i])
130+
weights.append(top_w[:, i])
131+
132+
# I = self.config.ffn_dim
133+
Inter = 768 # TODO: Find a way to identify from config # Intermediate Size
134+
upgate = []
135+
expert_out = []
136+
for i in range(self.top_k):
137+
upgate.append(x.new_zeros((T, Inter)))
138+
expert_out.append(x.new_zeros((T, H)))
188139

189140
for e in range(self.num_experts):
190141
exp = self.experts[e]
191-
mask1 = (expert1_idx == e).unsqueeze(1) # [T, 1]
192-
mask2 = (expert2_idx == e).unsqueeze(1) # [T, 1]
193-
mask3 = (expert3_idx == e).unsqueeze(1) # [T, 1]
194-
mask4 = (expert4_idx == e).unsqueeze(1) # [T, 1]
195-
mask5 = (expert5_idx == e).unsqueeze(1) # [T, 1]
196-
mask6 = (expert6_idx == e).unsqueeze(1) # [T, 1]
197-
mask7 = (expert7_idx == e).unsqueeze(1) # [T, 1]
198-
mask8 = (expert8_idx == e).unsqueeze(1) # [T, 1]
199-
200-
# breakpoint()
142+
mask = []
143+
for i in range(self.top_k):
144+
mask.append((expert_idx[i] == e).unsqueeze(1))
201145
hidden_gate = (exp.act_fn(exp.gate_proj(x))) * exp.up_proj(x)
202-
# hidden_gate=exp.down_proj(hidden_gate)
203-
204-
# Accumulate weighted contributions
205-
upgate1 += torch.where(mask1, hidden_gate, torch.zeros_like(upgate1))
206-
upgate2 += torch.where(mask2, hidden_gate, torch.zeros_like(upgate2))
207-
upgate3 += torch.where(mask3, hidden_gate, torch.zeros_like(upgate3))
208-
upgate4 += torch.where(mask4, hidden_gate, torch.zeros_like(upgate4))
209-
upgate5 += torch.where(mask5, hidden_gate, torch.zeros_like(upgate5))
210-
upgate6 += torch.where(mask6, hidden_gate, torch.zeros_like(upgate6))
211-
upgate7 += torch.where(mask7, hidden_gate, torch.zeros_like(upgate7))
212-
upgate8 += torch.where(mask8, hidden_gate, torch.zeros_like(upgate8))
146+
for i in range(self.top_k):
147+
upgate[i] += torch.where(mask[i], hidden_gate, torch.zeros_like(upgate[i]))
213148

214149
for e in range(self.num_experts):
215150
exp = self.experts[e]
216-
mask1 = (expert1_idx == e).unsqueeze(1)
217-
mask2 = (expert2_idx == e).unsqueeze(1)
218-
mask3 = (expert3_idx == e).unsqueeze(1) # [T, 1]
219-
mask4 = (expert4_idx == e).unsqueeze(1) # [T, 1]
220-
mask5 = (expert5_idx == e).unsqueeze(1) # [T, 1]
221-
mask6 = (expert6_idx == e).unsqueeze(1) # [T, 1]
222-
mask7 = (expert7_idx == e).unsqueeze(1) # [T, 1]
223-
mask8 = (expert8_idx == e).unsqueeze(1)
224-
# breakpoint()
225-
expert_out1 += torch.where(
226-
mask1, exp.down_proj(upgate1) * weight1.unsqueeze(1), torch.zeros_like(expert_out1)
227-
)
228-
expert_out2 += torch.where(
229-
mask2, exp.down_proj(upgate2) * weight2.unsqueeze(1), torch.zeros_like(expert_out2)
230-
)
231-
expert_out3 += torch.where(
232-
mask3, exp.down_proj(upgate3) * weight3.unsqueeze(1), torch.zeros_like(expert_out3)
233-
)
234-
expert_out4 += torch.where(
235-
mask4, exp.down_proj(upgate4) * weight4.unsqueeze(1), torch.zeros_like(expert_out4)
236-
)
237-
expert_out5 += torch.where(
238-
mask5, exp.down_proj(upgate5) * weight5.unsqueeze(1), torch.zeros_like(expert_out5)
239-
)
240-
expert_out6 += torch.where(
241-
mask6, exp.down_proj(upgate6) * weight6.unsqueeze(1), torch.zeros_like(expert_out6)
242-
)
243-
expert_out7 += torch.where(
244-
mask7, exp.down_proj(upgate7) * weight7.unsqueeze(1), torch.zeros_like(expert_out7)
245-
)
246-
expert_out8 += torch.where(
247-
mask8, exp.down_proj(upgate8) * weight8.unsqueeze(1), torch.zeros_like(expert_out8)
248-
)
249-
250-
expert_out = (
251-
expert_out1
252-
+ expert_out2
253-
+ expert_out3
254-
+ expert_out4
255-
+ expert_out5
256-
+ expert_out6
257-
+ expert_out7
258-
+ expert_out8
259-
)
260-
return expert_out.view(B, S, H), router_logits
151+
mask = []
152+
for i in range(self.top_k):
153+
mask.append((expert_idx[i] == e).unsqueeze(1))
154+
expert_out[i] += torch.where(
155+
mask[i], exp.down_proj(upgate[i]) * (weights[i].unsqueeze(1)), torch.zeros_like(expert_out[i])
156+
)
157+
158+
expert_out_sum = sum(expert_out)
159+
return expert_out_sum.view(B, S, H), router_logits
261160

262161

263162
class QEffQwen3MoeAttention(Qwen3MoeAttention):

0 commit comments

Comments
 (0)