Skip to content

Commit f015b26

Browse files
committed
llama : more robust cell_max heuristic + wip shift
1 parent 4d76d76 commit f015b26

File tree

3 files changed

+39
-52
lines changed

3 files changed

+39
-52
lines changed

examples/llama-bench/llama-bench.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,8 @@ int main(int argc, char ** argv) {
977977

978978
test t(inst, lmodel, ctx);
979979

980+
llama_kv_cache_keep_seq(ctx, -1);
981+
980982
// warmup run
981983
if (t.n_prompt > 0) {
982984
test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
@@ -986,6 +988,8 @@ int main(int argc, char ** argv) {
986988
}
987989

988990
for (int i = 0; i < params.reps; i++) {
991+
llama_kv_cache_keep_seq(ctx, -1);
992+
989993
uint64_t t_start = get_time_ns();
990994
if (t.n_prompt > 0) {
991995
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);

llama.cpp

Lines changed: 30 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,9 +1023,6 @@ struct llama_kv_cache {
10231023
uint32_t head = 0;
10241024
uint32_t size = 0;
10251025

1026-
// largest index of an occupied cell (used for a basic optimization heuristic)
1027-
uint32_t cell_max = 0;
1028-
10291026
std::vector<llama_kv_cell> cells;
10301027

10311028
struct ggml_tensor * k = NULL;
@@ -1229,8 +1226,6 @@ static bool llama_kv_cache_init(
12291226
cache.head = 0;
12301227
cache.size = n_ctx;
12311228

1232-
cache.cell_max = 0;
1233-
12341229
cache.cells.clear();
12351230
cache.cells.resize(n_ctx);
12361231

@@ -1316,15 +1311,16 @@ static bool llama_kv_cache_find_slot(
13161311
return true;
13171312
}
13181313

1319-
void llama_kv_cache_update(struct llama_kv_cache & cache) {
1320-
// compute new cell_max
1321-
cache.cell_max = 0;
1314+
int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
1315+
int32_t res = 0;
13221316

13231317
for (uint32_t i = 0; i < cache.size; i++) {
1324-
if (cache.cells[i].pos >= 0) {
1325-
cache.cell_max = i + 1;
1318+
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
1319+
res = i + 1;
13261320
}
13271321
}
1322+
1323+
return res;
13281324
}
13291325

13301326
void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
@@ -1335,8 +1331,6 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
13351331
cache.cells[i].pos = -1;
13361332
cache.cells[i].seq_id.clear();
13371333
}
1338-
1339-
llama_kv_cache_update(cache);
13401334
}
13411335

13421336
void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
@@ -1348,8 +1342,6 @@ void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
13481342
}
13491343
}
13501344
}
1351-
1352-
llama_kv_cache_update(cache);
13531345
}
13541346

13551347
void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
@@ -1359,8 +1351,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
13591351
cache.cells[i].seq_id.clear();
13601352
}
13611353
}
1354+
}
1355+
1356+
void llama_kv_cache_shift(
1357+
struct llama_context & ctx,
1358+
llama_seq_id seq_id,
1359+
llama_pos p0,
1360+
llama_pos p1,
1361+
llama_pos delta) {
1362+
auto & hparams = ctx.model.hparams;
1363+
auto & cache = ctx.kv_self;
13621364

1363-
llama_kv_cache_update(cache);
1365+
for (uint32_t i = 0; i < cache.size; ++i) {
1366+
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1367+
cache.cells[i].pos += delta;
1368+
}
1369+
}
13641370
}
13651371

13661372
//
@@ -2587,7 +2593,7 @@ static struct ggml_cgraph * llm_build_llama(
25872593
const int n_gpu_layers = model.n_gpu_layers;
25882594

25892595
const int32_t n_tokens = batch.n_tokens;
2590-
const int32_t n_kv = kv_self.cell_max + n_tokens;
2596+
const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
25912597

25922598
auto & buf_compute = lctx.buf_compute;
25932599

@@ -2678,13 +2684,6 @@ static struct ggml_cgraph * llm_build_llama(
26782684
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
26792685
}
26802686
}
2681-
2682-
// TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
2683-
for (int i = n_kv; i < n_ctx; ++i) {
2684-
if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
2685-
GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
2686-
}
2687-
}
26882687
}
26892688
}
26902689
}
@@ -2952,7 +2951,7 @@ static struct ggml_cgraph * llm_build_baichaun(
29522951
const int n_gpu_layers = model.n_gpu_layers;
29532952

29542953
const int32_t n_tokens = batch.n_tokens;
2955-
const int32_t n_kv = kv_self.cell_max + n_tokens;
2954+
const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
29562955

29572956
auto & buf_compute = lctx.buf_compute;
29582957

@@ -3043,13 +3042,6 @@ static struct ggml_cgraph * llm_build_baichaun(
30433042
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
30443043
}
30453044
}
3046-
3047-
// TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
3048-
for (int i = n_kv; i < n_ctx; ++i) {
3049-
if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
3050-
GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
3051-
}
3052-
}
30533045
}
30543046
}
30553047
}
@@ -3334,7 +3326,7 @@ static struct ggml_cgraph * llm_build_falcon(
33343326
const int n_gpu_layers = model.n_gpu_layers;
33353327

33363328
const int32_t n_tokens = batch.n_tokens;
3337-
const int32_t n_kv = kv_self.cell_max + n_tokens;
3329+
const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
33383330

33393331
auto & buf_compute = lctx.buf_compute;
33403332

@@ -3425,13 +3417,6 @@ static struct ggml_cgraph * llm_build_falcon(
34253417
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
34263418
}
34273419
}
3428-
3429-
// TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
3430-
for (int i = n_kv; i < n_ctx; ++i) {
3431-
if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
3432-
GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
3433-
}
3434-
}
34353420
}
34363421
}
34373422
}
@@ -3671,7 +3656,7 @@ static struct ggml_cgraph * llm_build_starcoder(
36713656
const float norm_eps = hparams.f_norm_eps;
36723657

36733658
const int32_t n_tokens = batch.n_tokens;
3674-
const int32_t n_kv = kv_self.cell_max + n_tokens;
3659+
const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
36753660

36763661
auto & buf_compute = lctx.buf_compute;
36773662

@@ -3754,13 +3739,6 @@ static struct ggml_cgraph * llm_build_starcoder(
37543739
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
37553740
}
37563741
}
3757-
3758-
// TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
3759-
for (int i = n_kv; i < n_ctx; ++i) {
3760-
if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
3761-
GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
3762-
}
3763-
}
37643742
}
37653743
}
37663744
}
@@ -4055,8 +4033,7 @@ static bool llama_eval_internal(
40554033
#endif
40564034

40574035
// update the kv ring buffer
4058-
lctx.kv_self.head += n_tokens;
4059-
lctx.kv_self.cell_max = std::max(lctx.kv_self.cell_max, lctx.kv_self.head);
4036+
lctx.kv_self.head += n_tokens;
40604037

40614038
#ifdef GGML_PERF
40624039
// print timing information per ggml operation (for debugging purposes)
@@ -6834,6 +6811,10 @@ void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
68346811
llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
68356812
}
68366813

6814+
void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6815+
llama_kv_cache_shift(*ctx, seq_id, p0, p1, delta);
6816+
}
6817+
68376818
// Returns the *maximum* size of the state
68386819
size_t llama_get_state_size(const struct llama_context * ctx) {
68396820
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
@@ -7130,8 +7111,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
71307111

71317112
ctx->kv_self.head = kv_ntok;
71327113
ctx->kv_self.size = kv_size;
7133-
7134-
ctx->kv_self.cell_max = kv_ntok;
71357114
}
71367115

71377116
const size_t nread = inp - src;

llama.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ extern "C" {
321321
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
322322
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
323323

324-
// Remove all tokens between cells [c0, c1)
324+
// Remove all token data of cells in [c0, c1)
325325
LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
326326

327327
// Removes all tokens that belong to the specified sequence
@@ -330,6 +330,10 @@ extern "C" {
330330
// Removes all tokens that do not belong to the specified sequence
331331
LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
332332

333+
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
334+
// If the KV cache is RoPEd, the KV data is updated accordingly
335+
LLAMA_API void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
336+
333337
//
334338
// State / sessions
335339
//

0 commit comments

Comments (0)