
Commit e80f509

vbaddi authored and quic-akuruvil committed
Add MOE updates from Ann
Signed-off-by: Ann Kuruvilla <[email protected]>
1 parent f6df7b0 commit e80f509

File tree

1 file changed: +100 -38 lines

QEfficient/transformers/models/deepseek_v3/modeling_deepseek_v3.py

+100 -38
@@ -1,4 +1,3 @@
-import math
 from typing import List, Optional, Tuple, Union

 import torch
@@ -11,25 +10,25 @@

 # Assuming these are imported from the original DeepseekV3 code
 from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
-    DeepseekV3Config,
-    DeepseekV3RMSNorm,
-    DeepseekV3MLP,
-    DeepseekV3MoE,
-    rotate_half,
-    repeat_kv,
     DeepseekV3Attention,
+    DeepseekV3Config,
     DeepseekV3DecoderLayer,
-    DeepseekV3Model,
     DeepseekV3ForCausalLM,
-    DeepseekV3PreTrainedModel,
+    DeepseekV3Model,
+    DeepseekV3MoE,
     logger,
+    repeat_kv,
+    rotate_half,
 )
+
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask

+
 class QEffDeepseekV3RotaryEmbedding(nn.Module):
     """
     Adapted from DeepseekV3RotaryEmbedding with static sin/cos caches like QEffLlamaRotaryEmbedding.
     """
+
     def __init__(self, config: DeepseekV3Config, device=None):
         super().__init__()
         if config.rope_scaling is not None:
@@ -46,9 +45,7 @@ def __init__(self, config: DeepseekV3Config, device=None):

         # Precompute static sin/cos caches
         self._set_cos_sin_cache(
-            seq_len=self.original_max_seq_len,
-            device=self.inv_freq.device,
-            dtype=torch.get_default_dtype()
+            seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
         )

     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -63,12 +60,51 @@ def forward(self, x, position_ids):
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
+
         # Use position_ids to slice the precomputed caches
         cos = self.cos_cached[position_ids] * self.attention_scaling
         sin = self.sin_cached[position_ids] * self.attention_scaling
         return cos.to(x.dtype), sin.to(x.dtype)

+
+def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    # Slice cos and sin using position_ids if they are larger (e.g., precomputed caches)
+    if cos.shape[-2] > q.shape[-2]:
+        cos = cos[:, position_ids, :]
+        sin = sin[:, position_ids, :]
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    b, h, s, d = q.shape
+    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+    b, h, s, d = k.shape
+    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Adapted from DeepseekV3's apply_rotary_pos_emb for QEff compatibility with position_ids slicing."""
     # Slice cos and sin using position_ids if they are larger (e.g., precomputed caches)
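Note (not part of the commit): the new apply_rotary_pos_emb_interleave above first de-interleaves the last dimension of q and k, so that even-indexed features land in the first half and odd-indexed features in the second half, which is the layout rotate_half expects. A minimal sketch of that reordering, with made-up shapes:

import torch

b, h, s, d = 1, 1, 1, 8
q = torch.arange(d, dtype=torch.float32).view(b, h, s, d)  # last dim: [0, 1, 2, 3, 4, 5, 6, 7]

# Group interleaved (even, odd) pairs, swap the pair axis inside each group, then flatten:
# even indices fill the first half of the last dim, odd indices the second half.
q_deinterleaved = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
print(q_deinterleaved.flatten().tolist())  # [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]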
@@ -81,8 +117,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed.to(q.dtype), k_embed.to(k.dtype)

+
 class QEffDeepseekV3Attention(DeepseekV3Attention):
     """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling."""
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -99,7 +137,7 @@ def forward(
         batch_size, seq_length = hidden_states.shape[:-1]
         query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
         key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
-        breakpoint()
+
         q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
         q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

@@ -111,12 +149,15 @@ def forward(

         k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
         cos, sin = position_embeddings
-        breakpoint()
-        query_states, key_states = qeff_apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids)
-        key_states = key_states.expand(*k_pass.shape[:-1], -1)
+        if self.config.rope_interleave:
+            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+        else:
+            q_rot, k_rot = qeff_apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+
+        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

-        # query_states = torch.cat((q_pass, q_rot), dim=-1)
-        # key_states = torch.cat((k_pass, k_rot), dim=-1)
+        query_states = torch.cat((q_pass, q_rot), dim=-1)
+        key_states = torch.cat((k_pass, k_rot), dim=-1)

         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
@@ -141,6 +182,8 @@ def forward(
             attn_weights = None

         return attn_output, attn_weights, past_key_value
+
+
 class QEffDeepseekV3MoE(DeepseekV3MoE):
     def forward(self, hidden_states):
         residuals = hidden_states
@@ -159,21 +202,28 @@ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weig
         for expert_idx in range(len(self.experts)):
             expert = self.experts[expert_idx]
             mask = expert_mask[expert_idx]
-            token_indices, weight_indices = torch.where(mask)
-
-            if token_indices.numel() > 0:
-                expert_weights = topk_weights[token_indices, weight_indices]
-                expert_input = hidden_states[token_indices]
-                expert_output = expert(expert_input)
-                weighted_output = expert_output * expert_weights.unsqueeze(-1)
-                final_hidden_states.index_add_(0, token_indices, weighted_output)
+            # token_indices, weight_indices = torch.where(mask)
+
+            # if token_indices.numel() > 0:
+            if torch.sum(mask).item() > 0:
+                # expert_weights = topk_weights[token_indices, weight_indices]
+                # expert_input = hidden_states[token_indices]
+                # expert_output = expert(expert_input)
+                expert_output = expert(hidden_states) * (((topk_weights * mask).sum(1))[:, None])
+                # weighted_output = expert_output * expert_weights.unsqueeze(-1)
+                # final_hidden_states.index_add_(0, token_indices, weighted_output)
+                expert_output = torch.where(
+                    (topk_weights * mask).sum(1).to(torch.bool)[:, None],
+                    expert_output,
+                    torch.tensor(0.0),
+                )
+                final_hidden_states = final_hidden_states + expert_output
         return final_hidden_states.type(hidden_states.dtype)
-
-class QEffDeepseekV3DecoderLayer(DeepseekV3DecoderLayer):
-

-
+
+class QEffDeepseekV3DecoderLayer(DeepseekV3DecoderLayer):
     """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling."""
+
     def forward(
         self,
         hidden_states: torch.Tensor,
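Note (not part of the commit): the rewritten QEffDeepseekV3MoE.moe runs every expert densely over all tokens and zeroes out unrouted rows via the summed routing weights, instead of gathering routed tokens with torch.where/index_add_; this avoids data-dependent shapes, which appears to be the point for the QEff export path. A small self-contained check of the equivalence for a single expert, using a toy stand-in expert and made-up shapes:

import torch

tokens, hidden, top_k = 4, 8, 2
hidden_states = torch.randn(tokens, hidden)
topk_weights = torch.rand(tokens, top_k)

# mask[t, j] == 1 when slot j of token t routes to this expert, as expert_mask[expert_idx] does.
mask = torch.zeros(tokens, top_k)
mask[0, 1] = 1.0
mask[2, 0] = 1.0

expert = torch.nn.Linear(hidden, hidden)  # toy stand-in for a DeepseekV3 expert MLP

with torch.no_grad():
    # Original formulation: gather routed tokens, scale by their routing weight, scatter-add back.
    token_indices, weight_indices = torch.where(mask.bool())
    reference = torch.zeros(tokens, hidden)
    reference.index_add_(
        0,
        token_indices,
        expert(hidden_states[token_indices]) * topk_weights[token_indices, weight_indices][:, None],
    )

    # Dense formulation from the diff: run the expert on all tokens, weight each row by the summed
    # routing weight for this expert, then zero out rows that were not routed here.
    per_token_weight = (topk_weights * mask).sum(1)[:, None]
    dense = torch.where(per_token_weight.to(torch.bool), expert(hidden_states) * per_token_weight, torch.tensor(0.0))

assert torch.allclose(reference, dense, atol=1e-6)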
@@ -217,8 +267,10 @@ def forward(

         return outputs

+
 class QEffDeepseekV3Model(DeepseekV3Model):
     """Adapted DeepseekV3Model with batch_index and QEff rotary embedding."""
+
     def __init__(self, config: DeepseekV3Config):
         super().__init__(config)
         self.__qeff_init__()
@@ -241,6 +293,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
+        breakpoint()
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -252,18 +305,24 @@ def forward(
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
             use_cache = False

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

         if use_cache and not isinstance(past_key_values, Cache):
-            past_key_values = DynamicCache() if past_key_values is None else DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values = (
+                DynamicCache() if past_key_values is None else DynamicCache.from_legacy_cache(past_key_values)
+            )

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
@@ -320,16 +379,17 @@ def forward(
             all_hidden_states += (hidden_states,)

         next_cache = next_decoder_cache if use_cache else None
+        next_cache = next_cache.to_legacy_cache()
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
+
     def _update_causal_mask(
         self,
         attention_mask: torch.Tensor,
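Note (not part of the commit): the added to_legacy_cache() call converts the DynamicCache built during the forward pass back into the legacy tuple-of-(key, value)-per-layer format before it is returned. A minimal round-trip sketch, assuming a transformers version that still exposes from_legacy_cache/to_legacy_cache (as the surrounding code already requires):

import torch
from transformers.cache_utils import DynamicCache

# One layer of (key, value) states: [batch, num_heads, seq_len, head_dim].
legacy = ((torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)),)

cache = DynamicCache.from_legacy_cache(legacy)  # tuple format -> Cache object
round_trip = cache.to_legacy_cache()            # Cache object -> tuple format
assert torch.equal(round_trip[0][0], legacy[0][0])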
@@ -360,7 +420,7 @@ def _update_causal_mask(
            ):
                return None

-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype, _ = input_tensor.dtype, input_tensor.device
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -397,8 +457,10 @@

         return causal_mask

+
 class QEffDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
     """Adapted DeepseekV3ForCausalLM with batch_index and QEff optimizations."""
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -461,4 +523,4 @@ def forward(
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
