- import math
from typing import List, Optional, Tuple, Union

import torch

# Assuming these are imported from the original DeepseekV3 code
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
-     DeepseekV3Config,
-     DeepseekV3RMSNorm,
-     DeepseekV3MLP,
-     DeepseekV3MoE,
-     rotate_half,
-     repeat_kv,
    DeepseekV3Attention,
+     DeepseekV3Config,
    DeepseekV3DecoderLayer,
-     DeepseekV3Model,
    DeepseekV3ForCausalLM,
-     DeepseekV3PreTrainedModel,
+     DeepseekV3Model,
+     DeepseekV3MoE,
    logger,
+     repeat_kv,
+     rotate_half,
)
+
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask

+
class QEffDeepseekV3RotaryEmbedding(nn.Module):
    """
    Adapted from DeepseekV3RotaryEmbedding with static sin/cos caches like QEffLlamaRotaryEmbedding.
    """
+
    def __init__(self, config: DeepseekV3Config, device=None):
        super().__init__()
        if config.rope_scaling is not None:
@@ -46,9 +45,7 @@ def __init__(self, config: DeepseekV3Config, device=None):

        # Precompute static sin/cos caches
        self._set_cos_sin_cache(
-             seq_len=self.original_max_seq_len,
-             device=self.inv_freq.device,
-             dtype=torch.get_default_dtype()
+             seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -63,12 +60,51 @@ def forward(self, x, position_ids):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
+
        # Use position_ids to slice the precomputed caches
        cos = self.cos_cached[position_ids] * self.attention_scaling
        sin = self.sin_cached[position_ids] * self.attention_scaling
        return cos.to(x.dtype), sin.to(x.dtype)

+
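# --- Illustrative sketch, not part of the diff: what a static cos/sin cache of the kind
# _set_cos_sin_cache precomputes (its body is not shown in this diff) plausibly looks like,
# and how the forward above gathers it with position_ids. build_rope_cache and the
# max_seq_len / head_dim / rope_theta values are hypothetical demo assumptions.
import torch

def build_rope_cache(max_seq_len=16, head_dim=8, rope_theta=10000.0):
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)            # [max_seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)     # [max_seq_len, head_dim]
    return emb.cos(), emb.sin()

cos_cached, sin_cached = build_rope_cache()
position_ids = torch.tensor([[3, 4, 5]])        # e.g. a decode step with a KV-cache offset
cos = cos_cached[position_ids]                  # [1, 3, head_dim], gathered per position
sin = sin_cached[position_ids]
print(cos.shape, sin.shape)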
+ def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`):
+             The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+             used to pass offsetted position ids when working with a KV-cache.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     # Slice cos and sin using position_ids if they are larger (e.g., precomputed caches)
+     if cos.shape[-2] > q.shape[-2]:
+         cos = cos[:, position_ids, :]
+         sin = sin[:, position_ids, :]
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+
+     b, h, s, d = q.shape
+     q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+     b, h, s, d = k.shape
+     k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
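# --- Illustrative sketch, not part of the diff: what the view/transpose/reshape in
# apply_rotary_pos_emb_interleave does to the last dimension. With interleaved RoPE the
# head dimension stores pairs (x0, x1, x2, x3, ...); the reshuffle regroups them into the
# half-split layout (x0, x2, ..., x1, x3, ...) that rotate_half expects. The tensor sizes
# below are arbitrary demo values.
import torch

b, h, s, d = 1, 1, 1, 6
q = torch.arange(d).float().view(b, h, s, d)    # last dim: [0, 1, 2, 3, 4, 5]
q_deinterleaved = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
print(q_deinterleaved)                          # last dim: [0, 2, 4, 1, 3, 5]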
def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Adapted from DeepseekV3's apply_rotary_pos_emb for QEff compatibility with position_ids slicing."""
    # Slice cos and sin using position_ids if they are larger (e.g., precomputed caches)
@@ -81,8 +117,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(q.dtype), k_embed.to(k.dtype)

+
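# --- Illustrative sketch, not part of the diff: the rotation that qeff_apply_rotary_pos_emb
# applies, q_embed = q * cos + rotate_half(q) * sin. rotate_half itself is imported from the
# upstream modeling file; rotate_half_demo below is only a minimal stand-in, and the shapes
# are assumed demo values. With a zero angle (sin = 0, cos = 1) the rotation is the identity.
import torch

def rotate_half_demo(x):
    # Split the head dimension in two halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

q = torch.randn(1, 2, 3, 8)                     # [batch, heads, seq, head_dim]
cos = torch.ones(1, 3, 8).unsqueeze(1)          # unsqueeze_dim=1 broadcasts over heads
sin = torch.zeros(1, 3, 8).unsqueeze(1)         # zero angle -> identity rotation
q_embed = (q * cos) + (rotate_half_demo(q) * sin)
print(torch.allclose(q_embed, q))               # True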
class QEffDeepseekV3Attention(DeepseekV3Attention):
    """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling."""
+
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -99,7 +137,7 @@ def forward(
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
-         breakpoint()
+
        q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

@@ -111,12 +149,15 @@ def forward(

        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
        cos, sin = position_embeddings
-         breakpoint()
-         query_states, key_states = qeff_apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids)
-         key_states = key_states.expand(*k_pass.shape[:-1], -1)
+         if self.config.rope_interleave:
+             q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+         else:
+             q_rot, k_rot = qeff_apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids)
+
+         k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

-         # query_states = torch.cat((q_pass, q_rot), dim=-1)
-         # key_states = torch.cat((k_pass, k_rot), dim=-1)
+         query_states = torch.cat((q_pass, q_rot), dim=-1)
+         key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
@@ -141,6 +182,8 @@ def forward(
        attn_weights = None

        return attn_output, attn_weights, past_key_value
+
+
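# --- Illustrative sketch, not part of the diff: the shape bookkeeping in the attention
# forward above. DeepseekV3's MLA keeps a per-head "nope" part and a single shared "rope"
# part for the keys; after RoPE the shared k_rot is expanded across heads and concatenated
# with k_pass. The dimensions below are small assumed values, not the real config.
import torch

batch, heads, seq = 1, 4, 5
qk_nope_head_dim, qk_rope_head_dim = 16, 8

q_pass = torch.randn(batch, heads, seq, qk_nope_head_dim)
q_rot = torch.randn(batch, heads, seq, qk_rope_head_dim)
k_pass = torch.randn(batch, heads, seq, qk_nope_head_dim)
k_rot = torch.randn(batch, 1, seq, qk_rope_head_dim)      # one shared rope head

k_rot = k_rot.expand(*k_pass.shape[:-1], -1)              # broadcast to [batch, heads, seq, rope_dim]
query_states = torch.cat((q_pass, q_rot), dim=-1)         # [batch, heads, seq, nope + rope]
key_states = torch.cat((k_pass, k_rot), dim=-1)
print(query_states.shape, key_states.shape)               # both [1, 4, 5, 24]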
class QEffDeepseekV3MoE(DeepseekV3MoE):
    def forward(self, hidden_states):
        residuals = hidden_states
@@ -159,21 +202,28 @@ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weig
        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]
            mask = expert_mask[expert_idx]
-             token_indices, weight_indices = torch.where(mask)
-
-             if token_indices.numel() > 0:
-                 expert_weights = topk_weights[token_indices, weight_indices]
-                 expert_input = hidden_states[token_indices]
-                 expert_output = expert(expert_input)
-                 weighted_output = expert_output * expert_weights.unsqueeze(-1)
-                 final_hidden_states.index_add_(0, token_indices, weighted_output)
+             # token_indices, weight_indices = torch.where(mask)
+
+             # if token_indices.numel() > 0:
+             if torch.sum(mask).item() > 0:
+                 # expert_weights = topk_weights[token_indices, weight_indices]
+                 # expert_input = hidden_states[token_indices]
+                 # expert_output = expert(expert_input)
+                 expert_output = expert(hidden_states) * (((topk_weights * mask).sum(1))[:, None])
+                 # weighted_output = expert_output * expert_weights.unsqueeze(-1)
+                 # final_hidden_states.index_add_(0, token_indices, weighted_output)
+                 expert_output = torch.where(
+                     (topk_weights * mask).sum(1).to(torch.bool)[:, None],
+                     expert_output,
+                     torch.tensor(0.0),
+                 )
+                 final_hidden_states = final_hidden_states + expert_output
        return final_hidden_states.type(hidden_states.dtype)
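# --- Illustrative sketch, not part of the diff: why the rewritten moe loop above matches the
# original gather/index_add_ version. Instead of selecting routed tokens with torch.where
# (data-dependent shapes), every expert runs on all tokens and its output is scaled by the
# per-token routing weight, which is zero for tokens not routed to it. The sizes and the toy
# expert weight matrices are assumptions for the demo, not the real model.
import torch

tokens, experts, top_k, dim = 6, 4, 2, 8
hidden_states = torch.randn(tokens, dim)
topk_indices = torch.randint(0, experts, (tokens, top_k))
topk_weights = torch.rand(tokens, top_k)
expert_mats = [torch.randn(dim, dim) for _ in range(experts)]   # toy stand-ins for expert MLPs

# One-hot routing mask per expert: [experts, tokens, top_k]
expert_mask = torch.nn.functional.one_hot(topk_indices, experts).permute(2, 0, 1)

# Original formulation: gather routed tokens, run the expert, scatter-add the weighted output.
out_gather = torch.zeros(tokens, dim)
for e in range(experts):
    token_idx, slot_idx = torch.where(expert_mask[e])
    if token_idx.numel() > 0:
        w = topk_weights[token_idx, slot_idx]
        out_gather.index_add_(0, token_idx, (hidden_states[token_idx] @ expert_mats[e]) * w[:, None])

# Static-shape formulation: run every expert on all tokens, weight by (topk_weights * mask).sum(1).
out_dense = torch.zeros(tokens, dim)
for e in range(experts):
    per_token_w = (topk_weights * expert_mask[e]).sum(1)[:, None]   # zero if not routed to expert e
    out_dense = out_dense + (hidden_states @ expert_mats[e]) * per_token_w

print(torch.allclose(out_gather, out_dense, atol=1e-5))             # True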
-
- class QEffDeepseekV3DecoderLayer(DeepseekV3DecoderLayer):
-

-
+
+ class QEffDeepseekV3DecoderLayer(DeepseekV3DecoderLayer):
    """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling."""
+
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -217,8 +267,10 @@ def forward(

        return outputs

+
class QEffDeepseekV3Model(DeepseekV3Model):
    """Adapted DeepseekV3Model with batch_index and QEff rotary embedding."""
+
    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.__qeff_init__()
@@ -241,6 +293,7 @@ def forward(
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -252,18 +305,24 @@ def forward(
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
-             logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and not isinstance(past_key_values, Cache):
-             past_key_values = DynamicCache() if past_key_values is None else DynamicCache.from_legacy_cache(past_key_values)
+             past_key_values = (
+                 DynamicCache() if past_key_values is None else DynamicCache.from_legacy_cache(past_key_values)
+             )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
@@ -320,16 +379,17 @@ def forward(
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
+         next_cache = next_cache.to_legacy_cache()
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
+
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
-
+
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
@@ -360,7 +420,7 @@ def _update_causal_mask(
        ):
            return None

-         dtype, device = input_tensor.dtype, input_tensor.device
+         dtype, _ = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
@@ -397,8 +457,10 @@ def _update_causal_mask(

        return causal_mask

+
class QEffDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
    """Adapted DeepseekV3ForCausalLM with batch_index and QEff optimizations."""
+
    def forward(
        self,
        input_ids: torch.LongTensor = None,
@@ -461,4 +523,4 @@ def forward(
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
-         )
+         )