@@ -112,28 +112,6 @@ def eager_attention_forward(
112
112
113
113
class QEffQwen3MoeSparseMoeBlock (Qwen3MoeSparseMoeBlock ):
114
114
def forward (self , hidden_states : torch .Tensor ) -> tuple [torch .Tensor , torch .Tensor ]:
115
- # # breakpoint()
116
- # B, S, D = hidden_states.shape # [1, 8, 2304]
117
- # hidden_states = hidden_states.reshape(-1, D) # [8, 2304]
118
- # T = hidden_states.size(0) # 8 tokens
119
- # router_logits = self.gate(hidden_states) # [8, 8]
120
- # probs = F.softmax(router_logits, dim=-1) # [8, 8]
121
-
122
- # topk_scores, topk_indices = torch.topk(probs, self.top_k, dim=-1) # [8, top_k] → topk_k is 2 for Grok1
123
- # topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # normalize per-token
124
- # topk_scores = topk_scores.to(hidden_states.dtype) # [8, top_k]
125
- # route = torch.zeros((T, self.num_experts), dtype=hidden_states.dtype)
126
- # route.scatter_(1, topk_indices, topk_scores) # [8, num_experts]
127
- # final_output = torch.zeros_like(hidden_states) # [8, 2304]
128
-
129
- # for e, expert in enumerate(self.experts):
130
- # scores = route[:, e].unsqueeze(1) # [8, 1]
131
- # masked_out = torch.where(
132
- # scores > 0, expert(hidden_states) * scores, 0.0
133
- # ) # # [8, 2304] × [8, 1] → [8, 2304]
134
- # final_output += masked_out # accumulate expert outputs
135
- # return final_output.reshape(B, S, D), router_logits # ([1, 8, 2304], [8, num_experts])
136
-
137
115
B , S , H = hidden_states .shape
138
116
T = B * S
139
117
x = hidden_states .view (T , H )
@@ -145,119 +123,40 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
145
123
top_w /= top_w .sum (- 1 , keepdim = True )
146
124
top_w = top_w .to (x .dtype )
147
125
148
- # Create 2 expert idx based on the topk
149
- expert1_idx , expert2_idx , expert3_idx , expert4_idx , expert5_idx , expert6_idx , expert7_idx , expert8_idx = (
150
- top_i [:, 0 ],
151
- top_i [:, 1 ],
152
- top_i [:, 2 ],
153
- top_i [:, 3 ],
154
- top_i [:, 4 ],
155
- top_i [:, 5 ],
156
- top_i [:, 6 ],
157
- top_i [:, 7 ],
158
- ) # [T]
159
- weight1 , weight2 , weight3 , weight4 , weight5 , weight6 , weight7 , weight8 = (
160
- top_w [:, 0 ],
161
- top_w [:, 1 ],
162
- top_w [:, 2 ],
163
- top_w [:, 3 ],
164
- top_w [:, 4 ],
165
- top_w [:, 5 ],
166
- top_w [:, 6 ],
167
- top_w [:, 7 ],
168
- ) # [T]
169
-
170
- Inter = 768
171
- upgate1 = x .new_zeros ((T , Inter ))
172
- upgate2 = x .new_zeros ((T , Inter ))
173
- upgate3 = x .new_zeros ((T , Inter ))
174
- upgate4 = x .new_zeros ((T , Inter ))
175
- upgate5 = x .new_zeros ((T , Inter ))
176
- upgate6 = x .new_zeros ((T , Inter ))
177
- upgate7 = x .new_zeros ((T , Inter ))
178
- upgate8 = x .new_zeros ((T , Inter ))
179
-
180
- expert_out1 = x .new_zeros ((T , H ))
181
- expert_out2 = x .new_zeros ((T , H ))
182
- expert_out3 = x .new_zeros ((T , H ))
183
- expert_out4 = x .new_zeros ((T , H ))
184
- expert_out5 = x .new_zeros ((T , H ))
185
- expert_out6 = x .new_zeros ((T , H ))
186
- expert_out7 = x .new_zeros ((T , H ))
187
- expert_out8 = x .new_zeros ((T , H ))
126
+ expert_idx = []
127
+ weights = []
128
+ for i in range (self .top_k ):
129
+ expert_idx .append (top_i [:, i ])
130
+ weights .append (top_w [:, i ])
131
+
132
+ # I = self.config.ffn_dim
133
+ Inter = 768 # TODO: Find a way to identify from config # Intermediate Size
134
+ upgate = []
135
+ expert_out = []
136
+ for i in range (self .top_k ):
137
+ upgate .append (x .new_zeros ((T , Inter )))
138
+ expert_out .append (x .new_zeros ((T , H )))
188
139
189
140
for e in range (self .num_experts ):
190
141
exp = self .experts [e ]
191
- mask1 = (expert1_idx == e ).unsqueeze (1 ) # [T, 1]
192
- mask2 = (expert2_idx == e ).unsqueeze (1 ) # [T, 1]
193
- mask3 = (expert3_idx == e ).unsqueeze (1 ) # [T, 1]
194
- mask4 = (expert4_idx == e ).unsqueeze (1 ) # [T, 1]
195
- mask5 = (expert5_idx == e ).unsqueeze (1 ) # [T, 1]
196
- mask6 = (expert6_idx == e ).unsqueeze (1 ) # [T, 1]
197
- mask7 = (expert7_idx == e ).unsqueeze (1 ) # [T, 1]
198
- mask8 = (expert8_idx == e ).unsqueeze (1 ) # [T, 1]
199
-
200
- # breakpoint()
142
+ mask = []
143
+ for i in range (self .top_k ):
144
+ mask .append ((expert_idx [i ] == e ).unsqueeze (1 ))
201
145
hidden_gate = (exp .act_fn (exp .gate_proj (x ))) * exp .up_proj (x )
202
- # hidden_gate=exp.down_proj(hidden_gate)
203
-
204
- # Accumulate weighted contributions
205
- upgate1 += torch .where (mask1 , hidden_gate , torch .zeros_like (upgate1 ))
206
- upgate2 += torch .where (mask2 , hidden_gate , torch .zeros_like (upgate2 ))
207
- upgate3 += torch .where (mask3 , hidden_gate , torch .zeros_like (upgate3 ))
208
- upgate4 += torch .where (mask4 , hidden_gate , torch .zeros_like (upgate4 ))
209
- upgate5 += torch .where (mask5 , hidden_gate , torch .zeros_like (upgate5 ))
210
- upgate6 += torch .where (mask6 , hidden_gate , torch .zeros_like (upgate6 ))
211
- upgate7 += torch .where (mask7 , hidden_gate , torch .zeros_like (upgate7 ))
212
- upgate8 += torch .where (mask8 , hidden_gate , torch .zeros_like (upgate8 ))
146
+ for i in range (self .top_k ):
147
+ upgate [i ] += torch .where (mask [i ], hidden_gate , torch .zeros_like (upgate [i ]))
213
148
214
149
for e in range (self .num_experts ):
215
150
exp = self .experts [e ]
216
- mask1 = (expert1_idx == e ).unsqueeze (1 )
217
- mask2 = (expert2_idx == e ).unsqueeze (1 )
218
- mask3 = (expert3_idx == e ).unsqueeze (1 ) # [T, 1]
219
- mask4 = (expert4_idx == e ).unsqueeze (1 ) # [T, 1]
220
- mask5 = (expert5_idx == e ).unsqueeze (1 ) # [T, 1]
221
- mask6 = (expert6_idx == e ).unsqueeze (1 ) # [T, 1]
222
- mask7 = (expert7_idx == e ).unsqueeze (1 ) # [T, 1]
223
- mask8 = (expert8_idx == e ).unsqueeze (1 )
224
- # breakpoint()
225
- expert_out1 += torch .where (
226
- mask1 , exp .down_proj (upgate1 ) * weight1 .unsqueeze (1 ), torch .zeros_like (expert_out1 )
227
- )
228
- expert_out2 += torch .where (
229
- mask2 , exp .down_proj (upgate2 ) * weight2 .unsqueeze (1 ), torch .zeros_like (expert_out2 )
230
- )
231
- expert_out3 += torch .where (
232
- mask3 , exp .down_proj (upgate3 ) * weight3 .unsqueeze (1 ), torch .zeros_like (expert_out3 )
233
- )
234
- expert_out4 += torch .where (
235
- mask4 , exp .down_proj (upgate4 ) * weight4 .unsqueeze (1 ), torch .zeros_like (expert_out4 )
236
- )
237
- expert_out5 += torch .where (
238
- mask5 , exp .down_proj (upgate5 ) * weight5 .unsqueeze (1 ), torch .zeros_like (expert_out5 )
239
- )
240
- expert_out6 += torch .where (
241
- mask6 , exp .down_proj (upgate6 ) * weight6 .unsqueeze (1 ), torch .zeros_like (expert_out6 )
242
- )
243
- expert_out7 += torch .where (
244
- mask7 , exp .down_proj (upgate7 ) * weight7 .unsqueeze (1 ), torch .zeros_like (expert_out7 )
245
- )
246
- expert_out8 += torch .where (
247
- mask8 , exp .down_proj (upgate8 ) * weight8 .unsqueeze (1 ), torch .zeros_like (expert_out8 )
248
- )
249
-
250
- expert_out = (
251
- expert_out1
252
- + expert_out2
253
- + expert_out3
254
- + expert_out4
255
- + expert_out5
256
- + expert_out6
257
- + expert_out7
258
- + expert_out8
259
- )
260
- return expert_out .view (B , S , H ), router_logits
151
+ mask = []
152
+ for i in range (self .top_k ):
153
+ mask .append ((expert_idx [i ] == e ).unsqueeze (1 ))
154
+ expert_out [i ] += torch .where (
155
+ mask [i ], exp .down_proj (upgate [i ]) * (weights [i ].unsqueeze (1 )), torch .zeros_like (expert_out [i ])
156
+ )
157
+
158
+ expert_out_sum = sum (expert_out )
159
+ return expert_out_sum .view (B , S , H ), router_logits
261
160
262
161
263
162
class QEffQwen3MoeAttention (Qwen3MoeAttention ):
0 commit comments