Commit b3aea52

Support for GQA and Llama2-70b

1 parent e61d4d3 · commit b3aea52

File tree: 8 files changed, +94 -61 lines

README.md

Lines changed: 9 additions & 29 deletions
@@ -161,10 +161,12 @@ WikiText, so scores are not necessarily comparable to other Llama benchmarks.
 Since many seem to be interested in running 65B models, I can confirm that this works with two 24 GB GPUs. The
 following benchmarks are from a 4090 + 3090-Ti with `-gs 17.2,24`:

-| Model    | Size | groupsize | act | Seq. len.            | VRAM      | Prompt    | Best   | Worst  | Ppl  |
-|----------|------|-----------|-----|----------------------|-----------|-----------|--------|--------|------|
-| Llama    | 65B  | 128       | yes | 2,048 t              | 39,804 MB | 1,109 t/s | 20 t/s | 18 t/s | 4.20 |
-| Llama    | 65B  | 32        | yes | 2,048 t              | 43,424 MB | 1,037 t/s | 17 t/s | 16 t/s | 4.11 |
+| Model   | Size | groupsize | act | Seq. len.      | VRAM      | Prompt    | Best   | Worst   | Ppl   |
+|---------|------|-----------|-----|----------------|-----------|-----------|--------|---------|-------|
+| Llama   | 65B  | 128       | yes | 2,048 t        | 39,804 MB | 1,109 t/s | 20 t/s | 18 t/s  | 4.20  |
+| Llama   | 65B  | 32        | yes | 2,048 t        | 43,424 MB | 1,037 t/s | 17 t/s | 16 t/s  | 4.11  |
+| Llama-2 | 70B  | 128       | yes | 2,048 t        | 40,680 MB | 1,037 t/s | 17 t/s | 14 t/s  | 4.15  |
+| Llama-2 | 70B  | 32        | yes | 2,048 t        | 36,815 MB | 1,037 t/s | 15 t/s | 12 t/s  | 4.10  |


 ### Testing long sequences
@@ -192,28 +194,6 @@ confirmed to be working right now.

 ## Recent updates

-**2023-06-02**: Web UI is now in a fairly working state. Expect it to be a little scuffed in places. There will be a
-rewrite at some point to make the client-side code less seizure-inducing. It has multibot mode, chat rewind and editing
-features, sessions, and more. I'm going to build it out with support for instruct prompting and such, in time.
-
-**2023-06-04**: Refactored a whole bunch to move more of the work into the extension, setting up for more tuning
-options to come soon and eventually auto tuning. Also optimized a little, for about a 5% speedup.
-
-**2023-06-06**: Some minor optimizations. Also it should now compile the extension more easily and run more seamlessly
-on Windows.
-
-**2023-06-09**: Fused most of the self-attention step. More to come. Slight speedup already, but more importantly went
-from 69% actual CPU utilization to 37%. This should do a lot to address the bottleneck on CPUs with lower
-single-threaded performance.
-
-**2023-06-10**: Docker support now! And some minor optimizations. Cleaned up the project a bit.
-
-**2023-06-11**: Added some concurrency a couple of places. It's only beneficial on the 4090, on small models where the
-cores are somewhat underutilized and the L2 cache can keep up. For the 3090 it's detrimental to performance, so it's
-disabled by default. YMMV. Use `-cs` to try it out.
-
-**2023-06-17**: Fixed a nasty bug in the fused attention that was causing slightly incorrect cache states on 13B and
-33B models. You definitely want to update.
-
-**2023-06-18**: LoRA support now. Still needs a lot of testing and some optimization, and currently you can't stack
-multiple LoRAs during the same inference. There's also no support in the web UI yet.
+**2023-07-19**: Added support for grouped-query attention and Llama-2 70b. There's still a bit of optimization to do,
+since it slows down considerably on very long sequences despite GQA having the potential to be faster. Also could use
+some more thorough testing.
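
A note on the VRAM numbers above: the K/V cache in this commit is allocated over `num_key_value_heads` rather than `num_attention_heads` (see the `ExLlamaCache` change in `model.py` below), so with GQA the cache shrinks by the ratio of query heads to KV heads. A rough sketch of that saving, assuming the published Llama shapes (80 layers, head_dim 128, 64 query heads, and 8 KV heads for Llama-2 70B); these figures are assumptions, not read from this repo:

```python
# Rough fp16 K/V-cache size per sequence. The layer/head counts are assumptions taken
# from the published Llama configs, not from this repo.

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size = 1, bytes_per_el = 2):
    # 2x for keys and values, matching the p_key_states / p_value_states tensors in ExLlamaCache
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_el

mha_65b = kv_cache_bytes(n_layers = 80, n_kv_heads = 64, head_dim = 128, seq_len = 2048)  # Llama 65B, no GQA
gqa_70b = kv_cache_bytes(n_layers = 80, n_kv_heads = 8,  head_dim = 128, seq_len = 2048)  # Llama-2 70B, GQA

print(f"65B MHA cache: {mha_65b / 1024**2:.0f} MB")  # ~5120 MB
print(f"70B GQA cache: {gqa_70b / 1024**2:.0f} MB")  # ~640 MB
```

The table's VRAM column also includes weights and temporary buffers, which is why the difference there is smaller than the cache saving alone.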

exllama_ext/cuda_func/q4_attn.cu

Lines changed: 17 additions & 15 deletions
@@ -14,7 +14,7 @@ const int THREADS_X = 32;
 const int THREADS_Y = 1;
 const int THREADS_Z = 4;
 const int BLOCKSIZE_X = 2; // 2*half == 1*uint32_t
-const int BLOCKSIZE_Z = 4; // num_heads must be divisible by BLOCKSIZE_Z
+const int BLOCKSIZE_Z = 4; // num_heads must be divisible by BLOCKSIZE_Z TODO: Check that this is the case when Llama2-34b releases

 __global__ void update_cache_kernel
 (
@@ -23,21 +23,21 @@ __global__ void update_cache_kernel
     half* __restrict__ key_cache,
     half* __restrict__ value_cache,
     const int head_dim,
-    const int num_heads,
+    const int num_kv_heads,
     const int q_len,
     const int max_seq_len,
     const int past_len
 )
 {
-    //int state_shape[] = { num_heads, q_len, head_dim };
-    int state_stride[] = { head_dim, head_dim * num_heads, 1 };
-    int state_pos[] = { 0, 0, 0 };
+    //int state_shape[] = { num_kv_heads, q_len, head_dim };
+    int state_stride[] = { head_dim, head_dim * num_kv_heads, 1 };
+    int state_pos[] = { 0, 0, 0 };

-    //int cache_shape[] = { num_heads, max_seq_len, head_dim };
-    int cache_stride[] = { max_seq_len * head_dim, head_dim, 1 };
-    int cache_pos[] = { 0, past_len, 0 };
+    //int cache_shape[] = { num_kv_heads, max_seq_len, head_dim };
+    int cache_stride[] = { max_seq_len * head_dim, head_dim, 1 };
+    int cache_pos[] = { 0, past_len, 0 };

-    int size[] = { num_heads, q_len, head_dim };
+    int size[] = { num_kv_heads, q_len, head_dim };

     int x = (blockIdx.x * THREADS_X + threadIdx.x) * BLOCKSIZE_X;
     int y = blockIdx.y * THREADS_Y + threadIdx.y;
@@ -92,6 +92,7 @@ void q4_attn_cuda
     const int dim,
     const int head_dim,
     const int num_heads,
+    const int num_kv_heads,
     const int past_len,
     half* key_cache,
     half* value_cache,
@@ -117,10 +118,11 @@ void q4_attn_cuda
     (
         ((head_dim + THREADS_X - 1) / THREADS_X + BLOCKSIZE_X - 1) / BLOCKSIZE_X,
         q_len,
-        ((num_heads + THREADS_Z - 1) / THREADS_Z + BLOCKSIZE_Z - 1) / BLOCKSIZE_Z
+        ((num_kv_heads + THREADS_Z - 1) / THREADS_Z + BLOCKSIZE_Z - 1) / BLOCKSIZE_Z
     );

     int _rows_per_batch = q_len * num_heads;
+    int _rows_per_batch_kv = q_len * num_kv_heads;

     CudaBuffers* buffers = get_buffers(device_index);

@@ -158,11 +160,11 @@ void q4_attn_cuda
         // Positional embeddings q, k

         rope_cuda(tuningParams, query_states, sin, cos, bsz, _rows_per_batch, head_dim, num_heads, past_len);
-        rope_cuda(tuningParams, key_states, sin, cos, bsz, _rows_per_batch, head_dim, num_heads, past_len);
+        rope_cuda(tuningParams, key_states, sin, cos, bsz, _rows_per_batch_kv, head_dim, num_kv_heads, past_len);

         // Update cache tensors with projected k, v

-        update_cache_kernel<<<blocks, threads>>>(key_states, value_states, key_cache, value_cache, head_dim, num_heads, q_len, max_seq_len, past_len);
+        update_cache_kernel<<<blocks, threads>>>(key_states, value_states, key_cache, value_cache, head_dim, num_kv_heads, q_len, max_seq_len, past_len);
     }
     else
     {
@@ -178,20 +180,20 @@ void q4_attn_cuda
         // str_1: project q, positions q, sync

         q4_matmul_cuda(tuningParams, temp_x, q_len, q_proj, query_states, q_a ? true : false, str_1);
-        rope_cuda(tuningParams, query_states, sin, cos, bsz, _rows_per_batch, head_dim, num_heads, past_len, str_1);
+        rope_cuda(tuningParams, query_states, sin, cos, bsz, _rows_per_batch, head_dim, num_kv_heads, past_len, str_1);
         cudaEventRecord(sync_1, str_1);

         // str_2: project k, positions k, sync

         q4_matmul_cuda(tuningParams, temp_x, q_len, k_proj, key_states, k_a ? true : false, str_2);
-        rope_cuda(tuningParams, key_states, sin, cos, bsz, _rows_per_batch, head_dim, num_heads, past_len, str_2);
+        rope_cuda(tuningParams, key_states, sin, cos, bsz, _rows_per_batch_kv, head_dim, num_kv_heads, past_len, str_2);
         cudaEventRecord(sync_2, str_2);

         // str_3: project v, wait for str_2, copy (k,v) to cache, sync

         q4_matmul_cuda(tuningParams, temp_x, q_len, v_proj, value_states, v_a ? true : false, buffers->alt_stream_3);
         cudaStreamWaitEvent(str_3, sync_2, 0);
-        update_cache_kernel<<<blocks, threads, 0, str_3>>>(key_states, value_states, key_cache, value_cache, head_dim, num_heads, q_len, max_seq_len, past_len);
+        update_cache_kernel<<<blocks, threads, 0, str_3>>>(key_states, value_states, key_cache, value_cache, head_dim, num_kv_heads, q_len, max_seq_len, past_len);
         cudaEventRecord(sync_3, str_3);

         // default: wait for str_1 and str_3
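
Not part of the commit, but for orientation: a plain-PyTorch restatement of the copy `update_cache_kernel` performs after this change, using the strides and offsets from the kernel above and assuming a single batch item:

```python
import torch

def update_cache_reference(key_states, value_states, key_cache, value_cache, past_len):
    # key_states / value_states: (q_len, num_kv_heads * head_dim), as produced by k_proj / v_proj
    # key_cache / value_cache:   (num_kv_heads, max_seq_len, head_dim)
    num_kv_heads, max_seq_len, head_dim = key_cache.shape
    q_len = key_states.shape[0]

    # state_stride = { head_dim, head_dim * num_kv_heads, 1 }: the source is (q_len, num_kv_heads, head_dim)
    # in memory, read here as (num_kv_heads, q_len, head_dim)
    k_src = key_states.view(q_len, num_kv_heads, head_dim).permute(1, 0, 2)
    v_src = value_states.view(q_len, num_kv_heads, head_dim).permute(1, 0, 2)

    # cache_pos = { 0, past_len, 0 }: new tokens land right after the existing past_len entries
    key_cache[:, past_len:past_len + q_len, :] = k_src
    value_cache[:, past_len:past_len + q_len, :] = v_src

# Toy sizes: 2 KV heads, cache length 16, head_dim 4, 3 new tokens appended at position 5
kc = torch.zeros(2, 16, 4, dtype = torch.float16)
vc = torch.zeros(2, 16, 4, dtype = torch.float16)
ks = torch.randn(3, 2 * 4, dtype = torch.float16)
vs = torch.randn(3, 2 * 4, dtype = torch.float16)
update_cache_reference(ks, vs, kc, vc, past_len = 5)
```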

exllama_ext/cuda_func/q4_attn.cuh

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ void q4_attn_cuda
     const int dim,
     const int head_dim,
     const int num_heads,
+    const int num_kv_heads,
     const int past_len,
     half* key_cache,
     half* value_cache,

exllama_ext/exllama_ext.cpp

Lines changed: 2 additions & 0 deletions
@@ -437,6 +437,7 @@ void q4_attn
     int q_len,
     int past_len,
     int num_heads,
+    int num_kv_heads,
     int head_dim,
     torch::Tensor key_cache,
     torch::Tensor value_cache,
@@ -488,6 +489,7 @@ void q4_attn
         dim,
         head_dim,
         num_heads,
+        num_kv_heads,
         past_len,
         (half*) key_cache.data_ptr(),
         (half*) value_cache.data_ptr(),

model.py

Lines changed: 48 additions & 11 deletions
@@ -54,6 +54,13 @@ def __init__(self, model_config_path):
         self.rms_norm_eps = read_config["rms_norm_eps"]
         self.vocab_size = read_config["vocab_size"]

+        if "num_key_value_heads" in read_config:
+            self.num_key_value_heads = read_config["num_key_value_heads"]
+            self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+        else:
+            self.num_key_value_heads = self.num_attention_heads
+            self.num_key_value_groups = 1
+
         self.rotary_embedding_base = 10000  # Constant used for pretrained models, leave as is unless retraining
         self.head_dim = self.hidden_size // self.num_attention_heads

@@ -288,11 +295,23 @@ def __init__(self, config, tensors, key, sin, cos, index):
         self.index = index

         self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj")
-        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".k_proj")
-        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".v_proj")
+        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_key_value_heads * self.config.head_dim, False, tensors, key + ".k_proj")
+        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_key_value_heads * self.config.head_dim, False, tensors, key + ".v_proj")
         self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj")


+    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+
+        # TODO: This seems inefficient. It should be possible to broadcast in the attention matmul to avoid building
+        # temporary K/V tensors like this
+
+        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+        if n_rep == 1: return hidden_states
+
+        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
     def fused(self, hidden_states, cache, buffer, input_layernorm, lora):

         bsz, q_len, _ = hidden_states.size()
@@ -315,9 +334,9 @@ def fused(self, hidden_states, cache, buffer, input_layernorm, lora):

         # Project q, k, v, apply position embeddings to k and v, update cache

-        query_states = torch.empty((bsz, q_len, self.config.hidden_size), dtype = torch.float16, device = hidden_states.device)
-        key_states = torch.empty((bsz, q_len, self.config.hidden_size), dtype = torch.float16, device = hidden_states.device)
-        value_states = torch.empty((bsz, q_len, self.config.hidden_size), dtype = torch.float16, device = hidden_states.device)
+        query_states = torch.empty((bsz, q_len, self.config.num_attention_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
+        key_states = torch.empty((bsz, q_len, self.config.num_key_value_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
+        value_states = torch.empty((bsz, q_len, self.config.num_key_value_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)

         cuda_ext.exllama_ext.q4_attn(hidden_states,
                                      input_layernorm.weight,
@@ -333,6 +352,7 @@ def fused(self, hidden_states, cache, buffer, input_layernorm, lora):
                                      q_len,
                                      past_len,
                                      self.config.num_attention_heads,
+                                     self.config.num_key_value_heads,
                                      self.config.head_dim,
                                      cache.key_states[self.index],
                                      cache.value_states[self.index],
@@ -349,11 +369,16 @@ def fused(self, hidden_states, cache, buffer, input_layernorm, lora):
         key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len)
         value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len)

+        # Repeat K/V heads if num_key_value_heads < num_attention_heads
+
+        query_states.transpose_(1, 2)
+        key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
+        value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)
+
         # Attention
         # TODO: Figure out if we can use cublasHgemmStridedBatched() to do this matmul without reshaping. Torch uses
         # gemmStridedBatchedEx() internally, so it should be possible.

-        query_states.transpose_(1, 2)
         key_states.transpose_(2, 3)
         attn_weights = torch.matmul(query_states, key_states)
         attn_weights /= math.sqrt(self.config.head_dim)
@@ -383,11 +408,11 @@ def forward(self, hidden_states, cache, buffer, lora):
         key_states = self.k_proj.forward(hidden_states, lora)

         cuda_ext.exllama_ext.rope_(query_states, self.sin, self.cos, past_len, self.config.num_attention_heads, self.config.head_dim)
-        cuda_ext.exllama_ext.rope_(key_states, self.sin, self.cos, past_len, self.config.num_attention_heads, self.config.head_dim)
+        cuda_ext.exllama_ext.rope_(key_states, self.sin, self.cos, past_len, self.config.num_key_value_heads, self.config.head_dim)

         query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)
-        value_states = self.v_proj.forward(hidden_states, lora).view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.config.num_key_value_heads, self.config.head_dim).transpose(1, 2)
+        value_states = self.v_proj.forward(hidden_states, lora).view(bsz, q_len, self.config.num_key_value_heads, self.config.head_dim).transpose(1, 2)

         # Add keys and values to cache

@@ -401,6 +426,11 @@ def forward(self, hidden_states, cache, buffer, lora):
         key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len)
         value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len)

+        # Repeat K/V heads if num_key_value_heads < num_attention_heads
+
+        key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
+        value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)
+
         # Attention

         # -- HF Transformers regular attention, faster on shorter sequences, same VRAM usage
@@ -508,8 +538,8 @@ def __init__(self, model, batch_size = 1, max_seq_len = -1, copy_from = None):

         if copy_from is None:

-            p_key_states = torch.zeros(self.batch_size, self.config.num_attention_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
-            p_value_states = torch.zeros(self.batch_size, self.config.num_attention_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
+            p_key_states = torch.zeros(self.batch_size, self.config.num_key_value_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
+            p_value_states = torch.zeros(self.batch_size, self.config.num_key_value_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])

         else:

@@ -520,6 +550,13 @@ def __init__(self, model, batch_size = 1, max_seq_len = -1, copy_from = None):
         self.value_states.append(p_value_states)


+    def zero(self):
+
+        for i in range(self.config.num_hidden_layers):
+            self.key_states[i].zero_()
+            self.value_states[i].zero_()
+
+
     def clone(self):

         new = ExLlamaCache(self.model, batch_size = self.batch_size, max_seq_len = self.max_seq_len, copy_from = self)
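
The TODO on `repeat_kv` above suggests broadcasting inside the attention matmul instead of materializing expanded K/V tensors. A minimal sketch of that idea (illustrative only, not part of this commit), checking that it produces the same attention scores as the expand-and-reshape path; it assumes, as `repeat_kv` does, that each consecutive block of `n_rep` query heads shares one KV head:

```python
import torch

def attn_scores_repeat(q, k, n_rep):
    # q: (bsz, n_heads, q_len, head_dim), k: (bsz, n_kv_heads, kv_len, head_dim)
    bsz, n_kv, kv_len, hd = k.shape
    k_rep = k[:, :, None].expand(bsz, n_kv, n_rep, kv_len, hd).reshape(bsz, n_kv * n_rep, kv_len, hd)
    return torch.matmul(q, k_rep.transpose(2, 3))

def attn_scores_broadcast(q, k, n_rep):
    # Same scores without building k_rep: fold query heads into (n_kv_heads, n_rep) groups
    # and let K broadcast over the n_rep dimension.
    bsz, n_heads, q_len, hd = q.shape
    q_g = q.view(bsz, k.shape[1], n_rep, q_len, hd)
    scores = torch.matmul(q_g, k.unsqueeze(2).transpose(3, 4))
    return scores.view(bsz, n_heads, q_len, -1)

q = torch.randn(1, 8, 5, 16)   # 8 query heads
k = torch.randn(1, 2, 7, 16)   # 2 KV heads -> n_rep = 4
assert torch.allclose(attn_scores_repeat(q, k, 4), attn_scores_broadcast(q, k, 4), atol = 1e-5)
```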

perplexity.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 '''

 class Perplexity:
-    def __init__(self, method="default", model=None, cache=None, tokenizer=None):
+    def __init__(self, method="default", model = None, cache = None, tokenizer = None):
         # This needs to be loaded by calling .load()
         self.dataset_chunks = []

@@ -36,7 +36,7 @@ def _next_logits(self, input_ids, apply_lora, last_id_only = True):
         # n_logits = []
         # a = 0
         # while a < input_ids.shape[-1]:
-        #     b = min(input_ids.shape[-1], a + 2048)  # TODO: Should this be a config parameter?
+        #     b = min(input_ids.shape[-1], a + 2048)
         #     n_logits.append(self.model.forward(input_ids[:, a:b], self.cache, last_id_only, lora = apply_lora))
         #     a = b
         #

test_benchmark_inference.py

Lines changed: 5 additions & 0 deletions
@@ -129,6 +129,9 @@ def mem(name, total = False):
 torch.cuda.reset_peak_memory_stats("cuda")
 mem("Model")

+cache = ExLlamaCache(model)
+mem("Cache")
+
 # Load LoRA

 lora = None
@@ -230,8 +233,10 @@ def mem(name, total = False):

 begin()

+ppl.cache.zero()
 model.config.matmul_recons_thd = 1
 ppl.test(8, lora = lora, tag = " (reconstruct)")
+ppl.cache.zero()
 model.config.matmul_recons_thd = 0
 ppl.test(8, lora = lora, tag = " (quant, token)", ppl_token = True)

tokenizer.py

Lines changed: 10 additions & 4 deletions
@@ -8,11 +8,17 @@ def __init__(self, tokenizer_model_path):

         self.path = tokenizer_model_path
         self.tokenizer = SentencePieceProcessor(model_file = self.path)
+
+        self.unk_token = "<unk>"
+        self.bos_token = "<s>"
+        self.eos_token = "</s>"
+        self.unk_token_id = self.tokenizer.unk_id()
         self.eos_token_id = self.tokenizer.eos_id()
         self.bos_token_id = self.tokenizer.bos_id()
-        self.pad_token_id = 0
+        self.pad_token_id = 0  # self.tokenizer.pad_id()
         self.newline_token_id = 13

+
     # Encode string

@@ -21,22 +27,22 @@ def encode(self, text):

             # text is a list of strings

-            list_ids = self.tokenizer.Encode(text)
+            list_ids = self.tokenizer.EncodeAsIds(text)
             max_length = max([len(ids) for ids in list_ids])

             padded_ids = []
             for ids in list_ids:
                 padding = torch.full((max_length - len(ids),), self.pad_token_id)
                 sequence = torch.tensor(ids)
-                padded_ids.append(torch.cat((padding, sequence), dim = 0))
+                padded_ids.append(torch.cat((padding, sequence), dim = 0).long())

             return torch.stack(padded_ids, dim = 0)

         else:

             # text is a single string

-            ids = self.tokenizer.Encode(text)
+            ids = self.tokenizer.EncodeAsIds(text)
             return torch.tensor(ids).unsqueeze(0)

     def decode(self, ids):
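
For illustration, this is roughly what the new left-padding path in `encode()` produces for a list of strings. The token ids below are made up rather than the output of a real SentencePiece model; an actual call goes through `EncodeAsIds` and pads on the left with `pad_token_id = 0`:

```python
import torch

pad_token_id = 0
list_ids = [[4, 5, 6], [7, 8, 9, 10, 11]]   # hypothetical token ids for two strings

max_length = max(len(ids) for ids in list_ids)
padded_ids = []
for ids in list_ids:
    padding = torch.full((max_length - len(ids),), pad_token_id)
    sequence = torch.tensor(ids)
    padded_ids.append(torch.cat((padding, sequence), dim = 0).long())  # .long() keeps the batch integer-typed

batch = torch.stack(padded_ids, dim = 0)
print(batch)
# tensor([[ 0,  0,  4,  5,  6],
#         [ 7,  8,  9, 10, 11]])
```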
