
Commit 6952a46

llama : add cell_max heuristic for more efficient kv_cache
1 parent 9f42e75 commit 6952a46

File tree

2 files changed: 102 additions, 29 deletions

llama.cpp

Lines changed: 92 additions & 27 deletions
@@ -1023,6 +1023,9 @@ struct llama_kv_cache {
     uint32_t head = 0;
     uint32_t size = 0;
 
+    // largest index of an occupied cell (used for a basic optimization heuristic)
+    uint32_t cell_max = 0;
+
     std::vector<llama_kv_cell> cells;
 
     struct ggml_tensor * k = NULL;
@@ -1226,6 +1229,8 @@ static bool llama_kv_cache_init(
     cache.head = 0;
     cache.size = n_ctx;
 
+    cache.cell_max = 0;
+
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
@@ -1311,6 +1316,16 @@ static bool llama_kv_cache_find_slot(
     return true;
 }
 
+void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
+    cache.cell_max = 0;
+
+    for (uint32_t i = 0; i < cache.size; i++) {
+        if (cache.cells[i].pos >= 0) {
+            cache.cell_max = i + 1;
+        }
+    }
+}
+
 void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1) {
     cache.head = p0;
 
@@ -1321,6 +1336,8 @@ void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1)
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
+
+    llama_kv_cache_update_cell_max(cache);
 }
 
 //
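
For reference, a minimal standalone sketch of the same rescan heuristic, using a simplified stand-in cell type (the real llama_kv_cell also tracks which sequences occupy each cell via seq_id):

// Sketch only: simplified stand-ins for llama_kv_cell / llama_kv_cache.
#include <cstdint>
#include <vector>

struct kv_cell_sketch {
    int32_t pos = -1; // -1 marks an empty cell
};

struct kv_cache_sketch {
    uint32_t cell_max = 0;               // one past the largest occupied index
    std::vector<kv_cell_sketch> cells;
};

// Full rescan, as in llama_kv_cache_update_cell_max: cell_max falls back to 0
// when every cell is empty, otherwise it lands just past the last occupied cell.
static void update_cell_max_sketch(kv_cache_sketch & cache) {
    cache.cell_max = 0;
    for (uint32_t i = 0; i < cache.cells.size(); i++) {
        if (cache.cells[i].pos >= 0) {
            cache.cell_max = i + 1;
        }
    }
}

The rescan is O(n_ctx), but it only runs on llama_kv_cache_clear; the common per-eval update is the cheap std::max shown further down in llama_eval_internal.
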
@@ -2547,6 +2564,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = kv_self.cell_max + n_tokens;
 
     auto & buf_compute = lctx.buf_compute;
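
The point of n_kv is to cap the attention width at the occupied part of the cache plus the current batch, instead of the full context. A rough walkthrough with made-up numbers:

// Illustrative arithmetic only (hypothetical values):
//   n_ctx    = 4096  // allocated KV cells
//   cell_max = 100   // cells occupied so far
//   n_tokens = 8     // tokens in this batch
//   n_kv     = cell_max + n_tokens = 108
// KQ, KQ_mask and the K/V views below shrink from n_ctx to n_kv in their
// first dimension, so the attention matmuls touch ~38x fewer columns here.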

@@ -2621,17 +2639,27 @@ static struct ggml_cgraph * llm_build_llama(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
         memset(data, 0, ggml_nbytes(KQ_mask));
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_ctx; ++i) {
-                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
-                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+
+                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
+                for (int i = n_kv; i < n_ctx; ++i) {
+                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
+                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
                     }
                 }
             }
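
A standalone sketch of the same mask construction over n_kv cells, with plain buffers instead of ggml tensors and a single seq_id per cell (the real cells can belong to several sequences via has_seq_id):

#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

struct cell_sketch { int32_t pos = -1; int32_t seq = -1; };

// data must hold n_kv * n_tokens floats (mask for one head, broadcast later).
static void build_kq_mask_sketch(float * data, int n_kv, int n_tokens,
                                 const std::vector<cell_sketch> & cells,
                                 const int32_t * batch_pos,
                                 const int32_t * batch_seq) {
    memset(data, 0, sizeof(float) * n_kv * n_tokens);
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            // hide cells that belong to another sequence or sit "in the future"
            if (cells[i].seq != batch_seq[j] || cells[i].pos > batch_pos[j]) {
                data[j*n_kv + i] = -INFINITY;
            }
        }
    }
}

The trailing loop over [n_kv, n_ctx) in the diff is only a temporary consistency check: any occupied cell beyond n_kv would mean cell_max was computed too small.
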
@@ -2725,7 +2753,7 @@ static struct ggml_cgraph * llm_build_llama(
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_ctx, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -2738,7 +2766,7 @@ static struct ggml_cgraph * llm_build_llama(
             ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd_head)
-            // KQ_scaled shape [n_ctx, n_tokens, n_head, 1]
+            // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
             offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
@@ -2756,7 +2784,7 @@ static struct ggml_cgraph * llm_build_llama(
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_ctx, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -2901,6 +2929,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = kv_self.cell_max + n_tokens;
 
     auto & buf_compute = lctx.buf_compute;
 
@@ -2975,17 +3004,27 @@ static struct ggml_cgraph * llm_build_baichaun(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
         memset(data, 0, ggml_nbytes(KQ_mask));
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_ctx; ++i) {
-                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
-                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+
+                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
+                for (int i = n_kv; i < n_ctx; ++i) {
+                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
+                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
                     }
                 }
             }
@@ -3092,7 +3131,7 @@ static struct ggml_cgraph * llm_build_baichaun(
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_ctx, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3135,7 +3174,7 @@ static struct ggml_cgraph * llm_build_baichaun(
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_ctx, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -3272,6 +3311,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = kv_self.cell_max + n_tokens;
 
     auto & buf_compute = lctx.buf_compute;
 
@@ -3346,17 +3386,27 @@ static struct ggml_cgraph * llm_build_falcon(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
         memset(data, 0, ggml_nbytes(KQ_mask));
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_ctx; ++i) {
-                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
-                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+
+                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
+                for (int i = n_kv; i < n_ctx; ++i) {
+                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
+                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
                     }
                 }
             }
@@ -3479,7 +3529,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_ctx, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3504,7 +3554,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_ctx, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -3598,6 +3648,7 @@ static struct ggml_cgraph * llm_build_starcoder(
     const float norm_eps = hparams.f_norm_eps;
 
     const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = kv_self.cell_max + n_tokens;
 
     auto & buf_compute = lctx.buf_compute;
 
@@ -3664,17 +3715,27 @@ static struct ggml_cgraph * llm_build_starcoder(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
         memset(data, 0, ggml_nbytes(KQ_mask));
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_ctx; ++i) {
-                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
-                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+
+                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
+                for (int i = n_kv; i < n_ctx; ++i) {
+                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
+                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
                     }
                 }
             }
@@ -3727,7 +3788,7 @@ static struct ggml_cgraph * llm_build_starcoder(
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_ctx, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3753,7 +3814,7 @@ static struct ggml_cgraph * llm_build_starcoder(
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_ctx, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -3974,8 +4035,9 @@ static bool llama_eval_internal(
     ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
 #endif
 
-    // update the kv ring buffer head
-    lctx.kv_self.head += n_tokens;
+    // update the kv ring buffer
+    lctx.kv_self.head += n_tokens;
+    lctx.kv_self.cell_max = std::max(lctx.kv_self.cell_max, lctx.kv_self.head);
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
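
Between clears, cell_max therefore tracks head monotonically; only a clear can shrink it. A hedged walkthrough with illustrative numbers, assuming llama_kv_cache_clear empties the cells in [p0, p1):

// eval #1: n_tokens = 32 -> head = 32, cell_max = max(0, 32)  = 32
// eval #2: n_tokens = 1  -> head = 33, cell_max = max(32, 33) = 33
// llama_kv_clear(ctx, /*p0=*/16, /*p1=*/33):
//   cells [16, 33) are emptied, head is reset to 16, and the rescan in
//   llama_kv_cache_update_cell_max brings cell_max back down to 16
// next eval: n_tokens = 4 -> head = 20, cell_max = max(16, 20) = 20
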
@@ -7040,6 +7102,9 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         }
 
         ctx->kv_self.head = kv_ntok;
+        ctx->kv_self.size = kv_size;
+
+        ctx->kv_self.cell_max = kv_ntok;
     }
 
     const size_t nread = inp - src;
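
After restoring a saved state, the number of serialized tokens doubles as the new cell_max, since a freshly loaded cache occupies exactly the cells [0, kv_ntok). Illustrative values only:

// kv_ntok = 57 tokens were serialized into a cache of kv_size = 4096 cells
//   ctx->kv_self.head     = 57   // next free cell
//   ctx->kv_self.size     = 4096
//   ctx->kv_self.cell_max = 57   // the next graph will use n_kv = 57 + n_tokens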

llama.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,15 +316,19 @@ extern "C" {
             int n_threads);
 
     //
-    // KV cache API
+    // KV cache
     //
 
     // Returns the number of tokens in the KV cache
     LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
-            "avoid using this, it will be removed in the future");
+            "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 
     LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
 
+    //
+    // State / sessions
+    //
+
     // Returns the maximum size in bytes of the state (rng, logits, embedding
     // and kv_cache) - will often be smaller after compacting tokens
     LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
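
A hedged caller-side sketch of the reorganized KV cache section: n_past and n_discard are caller-maintained values, not part of the API, and the snippet only illustrates how llama_kv_clear interacts with the new cell_max bookkeeping:

// Fragment only: drop the most recent n_discard tokens (e.g. to regenerate
// from an earlier point), then keep decoding; the clear also rescans the
// cells and lowers cell_max accordingly.
llama_kv_clear(ctx, n_past - n_discard, n_past);
n_past -= n_discard;
// subsequent evals reuse cells starting at the new head (== n_past here)
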
@@ -342,6 +346,10 @@ extern "C" {
     LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
     LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
 
+    //
+    // Decoding
+    //
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
