@@ -10,20 +10,18 @@
 from typing import Optional, Tuple, Union

 import torch
-import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
     CodeGenBlock,
     CodeGenForCausalLM,
     CodeGenModel,
     apply_rotary_pos_emb,
-    logger,
 )

+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask

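Note: the import changes swap the stock transformers DynamicCache for QEffDynamicCache and drop the training-only imports (torch.utils.checkpoint, CrossEntropyLoss, logger). A minimal sketch of the assumed relationship, not the actual QEfficient source:

    from transformers.cache_utils import DynamicCache

    class QEffDynamicCache(DynamicCache):
        # Hypothetical override point: QEfficient presumably customizes how
        # cached keys/values are written while keeping DynamicCache's
        # signature, so call sites only need the one-line swap seen below.
        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            return super().update(key_states, value_states, layer_idx, cache_kwargs)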
@@ -133,7 +131,7 @@ def forward(
             "position_ids": position_ids,
             "batch_index": batch_index,
         }
-        pkv = DynamicCache()
+        pkv = QEffDynamicCache()
         pkv.key_cache.append(past_key_value[0])
         pkv.value_cache.append(past_key_value[1])
         key, value = pkv.update(key, value, 0, cache_kwargs)
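The hunk above adapts a legacy (key, value) past tuple to the Cache API by seeding layer 0 of a fresh cache and routing the new projections through update. A standalone illustration of the same pattern with stock DynamicCache (shapes and names are illustrative; stock DynamicCache ignores the position-aware cache_kwargs that QEffDynamicCache presumably consumes):

    import torch
    from transformers.cache_utils import DynamicCache

    past_key = torch.zeros(1, 16, 128, 64)    # (batch, heads, seq, head_dim), illustrative
    past_value = torch.zeros(1, 16, 128, 64)

    pkv = DynamicCache()
    pkv.key_cache.append(past_key)            # seed layer 0 from the legacy tuple
    pkv.value_cache.append(past_value)

    new_key = torch.randn(1, 16, 1, 64)       # current step's K/V projections
    new_value = torch.randn(1, 16, 1, 64)
    key, value = pkv.update(new_key, new_value, 0)   # merged K/V for attention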
@@ -261,14 +259,6 @@ def forward(

         output_shape = input_shape + (hidden_states.size(-1),)

-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                    "`use_cache=False`..."
-                )
-            use_cache = False
-
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

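Removing the gradient-checkpointing guard (and its use_cache warning) is consistent with the other deletions in this diff: the forward is being specialized for inference/export, where the training-only branch can never fire and would only add control flow to the traced graph.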
@@ -279,41 +269,17 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    position_ids,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                )
-            elif batch_index is not None:
-                outputs = block(
-                    hidden_states=hidden_states,
-                    layer_past=layer_past,
-                    batch_index=batch_index,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
-            else:
-                outputs = block(
-                    hidden_states=hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
+            outputs = block(
+                hidden_states=hidden_states,
+                layer_past=layer_past,
+                batch_index=batch_index,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )

             hidden_states = outputs[0]
             if use_cache is True:
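Collapsing the three call paths into one is valid only if the patched CodeGenBlock.forward accepts batch_index, so that passing batch_index=None reproduces the old else-branch. A hedged sketch of that assumed signature:

    # Assumed (not verified against the QEfficient source) signature of the
    # patched block that makes the single call path legal:
    def forward(
        self,
        hidden_states,
        layer_past=None,
        batch_index=None,          # new QEff parameter; ignored when None
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        ...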
@@ -398,25 +364,8 @@ def forward(
         hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
         lm_logits = self.lm_head(hidden_states)

-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-            loss = loss.to(hidden_states.dtype)
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return CausalLMOutputWithPast(
-            loss=loss,
+            loss=None,
             logits=lm_logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
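With the label/loss path gone, the causal-LM forward always returns loss=None along with the dataclass output, so callers read the logits (and cache) directly. Illustrative usage, with model and tensor names as placeholders:

    outputs = model(input_ids=input_ids, position_ids=position_ids, past_key_values=past)
    next_token = outputs.logits[:, -1, :].argmax(dim=-1)   # outputs.loss is always None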