@@ -54,6 +54,13 @@ def __init__(self, model_config_path):
        self.rms_norm_eps = read_config["rms_norm_eps"]
        self.vocab_size = read_config["vocab_size"]

+       if "num_key_value_heads" in read_config:
+           self.num_key_value_heads = read_config["num_key_value_heads"]
+           self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+       else:
+           self.num_key_value_heads = self.num_attention_heads
+           self.num_key_value_groups = 1
+
        self.rotary_embedding_base = 10000  # Constant used for pretrained models, leave as is unless retraining
        self.head_dim = self.hidden_size // self.num_attention_heads
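Note: num_key_value_groups is the number of query heads that share each K/V head, with multi-head-attention checkpoints falling back to one group per head. A minimal sketch of the same fallback logic, assuming illustrative Llama-2-70B-style values (64 query heads, 8 K/V heads) rather than numbers taken from this diff:

# Sketch of the grouping logic above; 64/8 are example values, not from this diff.
read_config = {"num_attention_heads": 64, "num_key_value_heads": 8}

num_attention_heads = read_config["num_attention_heads"]
num_key_value_heads = read_config.get("num_key_value_heads", num_attention_heads)
num_key_value_groups = num_attention_heads // num_key_value_heads  # query heads per K/V head

print(num_key_value_heads, num_key_value_groups)  # 8 8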
@@ -288,11 +295,23 @@ def __init__(self, config, tensors, key, sin, cos, index):
        self.index = index

        self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj")
-       self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".k_proj")
-       self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".v_proj")
+       self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_key_value_heads * self.config.head_dim, False, tensors, key + ".k_proj")
+       self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_key_value_heads * self.config.head_dim, False, tensors, key + ".v_proj")
        self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj")


+   def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+
+       # TODO: This seems inefficient. It should be possible to broadcast in the attention matmul to avoid building
+       # temporary K/V tensors like this
+
+       batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+       if n_rep == 1: return hidden_states
+
+       hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+       return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
    def fused(self, hidden_states, cache, buffer, input_layernorm, lora):

        bsz, q_len, _ = hidden_states.size()
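Note: a quick way to sanity-check repeat_kv is to compare it against torch.repeat_interleave along the head dimension. The standalone sketch below uses made-up example shapes (2 K/V heads repeated 4x) just to show the (batch, n_kv_heads, seq, head_dim) -> (batch, n_kv_heads * n_rep, seq, head_dim) expansion:

import torch

# Standalone copy of repeat_kv, exercised with example shapes.
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1: return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 2, 5, 8, dtype = torch.float16)
out = repeat_kv(kv, 4)
print(out.shape)                                           # torch.Size([1, 8, 5, 8])
print(torch.equal(out, kv.repeat_interleave(4, dim = 1)))  # True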
@@ -315,9 +334,9 @@ def fused(self, hidden_states, cache, buffer, input_layernorm, lora):

        # Project q, k, v, apply position embeddings to k and v, update cache

-       query_states = torch.empty((bsz, q_len, self.config.hidden_size), dtype = torch.float16, device = hidden_states.device)
-       key_states = torch.empty((bsz, q_len, self.config.hidden_size), dtype = torch.float16, device = hidden_states.device)
-       value_states = torch.empty((bsz, q_len, self.config.hidden_size), dtype = torch.float16, device = hidden_states.device)
+       query_states = torch.empty((bsz, q_len, self.config.num_attention_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
+       key_states = torch.empty((bsz, q_len, self.config.num_key_value_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
+       value_states = torch.empty((bsz, q_len, self.config.num_key_value_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)

        cuda_ext.exllama_ext.q4_attn(hidden_states,
                                     input_layernorm.weight,
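Note: for a concrete sense of the buffer sizes above, with illustrative Llama-2-70B-like dimensions (hidden_size 8192, 64 attention heads, head_dim 128, 8 K/V heads; assumed values, not taken from this diff) the query buffer keeps its full 8192-wide projection while the key/value buffers shrink to 1024 columns:

# Illustrative buffer widths, assuming 70B-style dimensions.
num_attention_heads = 64
num_key_value_heads = 8
head_dim = 128

q_width = num_attention_heads * head_dim   # 8192, same as hidden_size
kv_width = num_key_value_heads * head_dim  # 1024, 8x narrower than the old hidden_size-wide buffer
print(q_width, kv_width)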
@@ -333,6 +352,7 @@ def fused(self, hidden_states, cache, buffer, input_layernorm, lora):
                                     q_len,
                                     past_len,
                                     self.config.num_attention_heads,
+                                    self.config.num_key_value_heads,
                                     self.config.head_dim,
                                     cache.key_states[self.index],
                                     cache.value_states[self.index],
@@ -349,11 +369,16 @@ def fused(self, hidden_states, cache, buffer, input_layernorm, lora):
        key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len)
        value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len)

+       # Repeat K/V heads if n_kv_heads < n_heads
+
+       query_states.transpose_(1, 2)
+       key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
+       value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)
+
        # Attention
        # TODO: Figure out if we can use cublasHgemmStridedBatched() to do this matmul without reshaping. Torch uses
        # gemmStridedBatchedEx() internally, so it should be possible.

-       query_states.transpose_(1, 2)
        key_states.transpose_(2, 3)
        attn_weights = torch.matmul(query_states, key_states)
        attn_weights /= math.sqrt(self.config.head_dim)
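Note: the TODO about broadcasting has a straightforward PyTorch-level counterpart. Instead of materializing repeated K/V tensors, the query heads can be viewed as (n_kv_heads, n_rep) groups so the matmul broadcasts over the shared K/V head. A hedged sketch with made-up example shapes, shown only to illustrate the TODO (it is not what this diff does, and the fused q4_attn path is untouched):

import math
import torch

# Broadcasting alternative to repeat_kv, assuming example shapes:
# 8 query heads grouped over 2 K/V heads (n_rep = 4), seq len 5, head_dim 16.
bsz, n_heads, n_kv_heads, q_len, head_dim = 1, 8, 2, 5, 16
n_rep = n_heads // n_kv_heads

q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_kv_heads, q_len, head_dim)

# View queries as (bsz, n_kv_heads, n_rep, q_len, head_dim) so K broadcasts per group.
q_grouped = q.view(bsz, n_kv_heads, n_rep, q_len, head_dim)
attn = torch.matmul(q_grouped, k.unsqueeze(2).transpose(-1, -2)) / math.sqrt(head_dim)
attn = attn.view(bsz, n_heads, q_len, q_len)

# Reference: the repeat-then-matmul route used in this diff.
k_rep = k.repeat_interleave(n_rep, dim = 1)
ref = torch.matmul(q, k_rep.transpose(-1, -2)) / math.sqrt(head_dim)
print(torch.allclose(attn, ref, atol = 1e-5))  # True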
@@ -383,11 +408,11 @@ def forward(self, hidden_states, cache, buffer, lora):
        key_states = self.k_proj.forward(hidden_states, lora)

        cuda_ext.exllama_ext.rope_(query_states, self.sin, self.cos, past_len, self.config.num_attention_heads, self.config.head_dim)
-       cuda_ext.exllama_ext.rope_(key_states, self.sin, self.cos, past_len, self.config.num_attention_heads, self.config.head_dim)
+       cuda_ext.exllama_ext.rope_(key_states, self.sin, self.cos, past_len, self.config.num_key_value_heads, self.config.head_dim)

        query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)
-       key_states = key_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)
-       value_states = self.v_proj.forward(hidden_states, lora).view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)
+       key_states = key_states.view(bsz, q_len, self.config.num_key_value_heads, self.config.head_dim).transpose(1, 2)
+       value_states = self.v_proj.forward(hidden_states, lora).view(bsz, q_len, self.config.num_key_value_heads, self.config.head_dim).transpose(1, 2)

        # Add keys and values to cache
@@ -401,6 +426,11 @@ def forward(self, hidden_states, cache, buffer, lora):
        key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len)
        value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len)

+       # Repeat K/V heads if n_kv_heads < n_heads
+
+       key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
+       value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)
+
        # Attention

        # -- HF Transformers regular attention, faster on shorter sequences, same VRAM usage
@@ -508,8 +538,8 @@ def __init__(self, model, batch_size = 1, max_seq_len = -1, copy_from = None):

            if copy_from is None:

-               p_key_states = torch.zeros(self.batch_size, self.config.num_attention_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
-               p_value_states = torch.zeros(self.batch_size, self.config.num_attention_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
+               p_key_states = torch.zeros(self.batch_size, self.config.num_key_value_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
+               p_value_states = torch.zeros(self.batch_size, self.config.num_key_value_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])

            else:
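Note: the cache tensors above are where GQA pays off in VRAM, since each layer now stores (batch, num_key_value_heads, max_seq_len, head_dim) in fp16. With the illustrative 70B-style numbers used earlier (80 layers, 8 K/V heads instead of 64, head_dim 128, a 4096-token cache, batch 1; all assumptions, not values from this diff), the K/V cache drops from roughly 10 GiB to 1.25 GiB:

# Rough K/V cache size, assuming 70B-style dimensions (not taken from this diff).
def kv_cache_bytes(batch, n_kv_heads, max_seq_len, head_dim, n_layers, bytes_per_el = 2):
    per_tensor = batch * n_kv_heads * max_seq_len * head_dim * bytes_per_el
    return per_tensor * 2 * n_layers  # keys + values, across all layers

mha = kv_cache_bytes(1, 64, 4096, 128, 80)  # num_attention_heads everywhere (old behaviour)
gqa = kv_cache_bytes(1, 8, 4096, 128, 80)   # num_key_value_heads (this change)
print(mha / 2**30, gqa / 2**30)             # 10.0 vs 1.25 GiB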
@@ -520,6 +550,13 @@ def __init__(self, model, batch_size = 1, max_seq_len = -1, copy_from = None):
            self.value_states.append(p_value_states)


+   def zero(self):
+
+       for i in range(self.config.num_hidden_layers):
+           self.key_states[i].zero_()
+           self.value_states[i].zero_()
+
+
    def clone(self):

        new = ExLlamaCache(self.model, batch_size = self.batch_size, max_seq_len = self.max_seq_len, copy_from = self)
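Note: the new zero() helper makes it possible to reuse a preallocated cache between generations instead of constructing a fresh ExLlamaCache each time. A hedged usage sketch; only ExLlamaCache and its zero() method come from this diff, while the surrounding loop, generate() and the sequence-length reset are assumptions about how a caller might use it:

# Hypothetical reuse pattern; only cache.zero() is defined by this diff.
cache = ExLlamaCache(model)            # allocate once, shaped by num_key_value_heads
for prompt in prompts:
    cache.zero()                       # clear K/V state in place instead of reallocating
    cache.current_seq_len = 0          # assumption: caller also resets its sequence position
    generate(model, cache, prompt)     # placeholder for whatever inference loop is in use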