Commit 4d76d76

llama : extend llama_kv_cache API
1 parent 6952a46 commit 4d76d76

4 files changed (+84, -32 lines)

examples/embd-input/embd-input-lib.cpp

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
+        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
         if (llama_decode(ctx, batch, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;

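The only change in this example is the llama_batch aggregate initializer: the trailing false (the old clear_kv flag, removed from the struct in the llama.h hunk below) is dropped. For readability, an annotated restatement of the same initializer follows; the names of the fields before all_pos_0 are inferred from the llama_batch_get_one() initializer in the llama.cpp hunks and should be treated as assumptions:

// annotated sketch of the updated initializer (field names before all_pos_0 are assumptions)
llama_batch batch = {
    /*n_tokens   =*/ uint32_t(n_eval),
    /*token      =*/ nullptr,             // no token ids: this example feeds embeddings
    /*embd       =*/ (input + i*n_emb),   // float embeddings for this chunk
    /*pos        =*/ nullptr,             // positions derived from all_pos_0/all_pos_1
    /*seq_id     =*/ nullptr,             // sequence ids derived from all_seq_id
    /*all_pos_0  =*/ n_past,
    /*all_pos_1  =*/ 1,
    /*all_seq_id =*/ 0,
    // the old trailing /*clear_kv =*/ false is gone; clearing the cache is now an explicit API call
};
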
examples/perplexity/perplexity.cpp

Lines changed: 31 additions & 10 deletions
@@ -79,7 +79,9 @@ static void write_logfile(
 static std::vector<float> softmax(const std::vector<float>& logits) {
     std::vector<float> probs(logits.size());
     float max_logit = logits[0];
-    for (float v : logits) max_logit = std::max(max_logit, v);
+    for (float v : logits) {
+        max_logit = std::max(max_logit, v);
+    }
     double sum_exp = 0.0;
     for (size_t i = 0; i < logits.size(); i++) {
         // Subtract the maximum logit value from the current logit value for numerical stability
@@ -88,15 +90,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
         sum_exp += exp_logit;
         probs[i] = exp_logit;
     }
-    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    for (size_t i = 0; i < probs.size(); i++) {
+        probs[i] /= sum_exp;
+    }
     return probs;
 }
 
 static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
     float max_logit = logits[0];
-    for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
+    for (int i = 1; i < n_vocab; ++i) {
+        max_logit = std::max(max_logit, logits[i]);
+    }
     double sum_exp = 0.0;
-    for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
+    for (int i = 0; i < n_vocab; ++i) {
+        sum_exp += expf(logits[i] - max_logit);
+    }
     return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
 }
 
@@ -107,7 +115,8 @@ static void process_logits(
     std::mutex mutex;
     int counter = 0;
     auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
-        double local_nll = 0, local_nll2 = 0;
+        double local_nll = 0;
+        double local_nll2 = 0;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
             int i = counter++;
@@ -125,10 +134,13 @@ static void process_logits(
             prob_history[i] = results.prob;
         }
     };
-    for (auto & w : workers) w = std::thread(compute);
+    for (auto & w : workers) {
+        w = std::thread(compute);
+    }
     compute();
-    for (auto & w : workers) w.join();
-
+    for (auto & w : workers) {
+        w.join();
+    }
 }
 
 static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
@@ -151,8 +163,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         return {std::move(tokens), 0., {}, {}};
     }
 
-    std::vector<float> logit_history;
-    std::vector<float> prob_history;
+    std::vector<float> logit_history;
+    std::vector<float> prob_history;
 
     logit_history.resize(tokens.size());
     prob_history.resize(tokens.size());
@@ -194,6 +206,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
 
         const auto t_start = std::chrono::high_resolution_clock::now();
 
+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);
@@ -319,6 +334,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         const auto t_start = std::chrono::high_resolution_clock::now();
 
+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);
@@ -549,6 +567,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_embd.resize(32);
         }
 
+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
         if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);

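The llama_kv_cache_keep_seq(ctx, -1) calls added above rely on a small trick: keep_seq evicts every cell that does not carry the given sequence id, and since no cell is ever tagged with sequence -1, the call clears the whole cache before each independent evaluation. A minimal sketch of that reset-then-evaluate pattern, assuming the llama_eval() signature visible in the llama.cpp hunks below (the helper itself is hypothetical, not part of this diff):

#include <algorithm>

#include "llama.h"

// hypothetical helper: evaluate one chunk that must not see any earlier context
static bool eval_independent_chunk(llama_context * ctx, llama_token * tokens,
                                   int n_tokens, int n_batch, int n_threads) {
    // no cell ever belongs to sequence -1, so keeping only that sequence
    // evicts every cell, i.e. it clears the whole KV cache
    llama_kv_cache_keep_seq(ctx, -1);

    for (int i = 0; i < n_tokens; i += n_batch) {
        const int n_eval = std::min(n_batch, n_tokens - i);
        // llama_eval() returns non-zero on failure
        if (llama_eval(ctx, tokens + i, (uint32_t) n_eval, /*n_past =*/ i, n_threads)) {
            return false;
        }
    }
    return true;
}
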
llama.cpp

Lines changed: 44 additions & 18 deletions
@@ -1316,7 +1316,8 @@ static bool llama_kv_cache_find_slot(
     return true;
 }
 
-void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
+void llama_kv_cache_update(struct llama_kv_cache & cache) {
+    // compute new cell_max
     cache.cell_max = 0;
 
     for (uint32_t i = 0; i < cache.size; i++) {
@@ -1326,18 +1327,40 @@ void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
     }
 }
 
-void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1) {
-    cache.head = p0;
+void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
+    if (c0 < 0) c0 = 0;
+    if (c1 < 0) c1 = cache.size;
 
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = cache.size;
-
-    for (int32_t i = p0; i < p1; ++i) {
+    for (int32_t i = c0; i < c1; ++i) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
 
-    llama_kv_cache_update_cell_max(cache);
+    llama_kv_cache_update(cache);
+}
+
+void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id)) {
+            cache.cells[i].seq_id.erase(seq_id);
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+            }
+        }
+    }
+
+    llama_kv_cache_update(cache);
+}
+
+void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (!cache.cells[i].has_seq_id(seq_id)) {
+            cache.cells[i].pos = -1;
+            cache.cells[i].seq_id.clear();
+        }
+    }
+
+    llama_kv_cache_update(cache);
 }
 
 //
@@ -3968,10 +3991,6 @@ static bool llama_eval_internal(
         batch.seq_id = seq_id.data();
     }
 
-    if (batch.clear_kv) {
-        llama_kv_cache_clear(kv_self, 0, -1);
-    }
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return false;
     }
@@ -6803,8 +6822,16 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     return ctx->kv_self.head;
 }
 
-void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) {
-    llama_kv_cache_clear(ctx->kv_self, p0, p1);
+void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1) {
+    llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
+}
+
+void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_kv_cache_rm_seq(ctx->kv_self, seq_id);
+}
+
+void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
 }
 
 // Returns the *maximum* size of the state
@@ -7203,7 +7230,7 @@ int llama_eval(
                     uint32_t   n_tokens,
                          int   n_past,
                          int   n_threads) {
-    llama_kv_cache_clear(ctx->kv_self, n_past, -1);
+    llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
 
     if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@@ -7226,9 +7253,9 @@ int llama_eval_embd(
                     uint32_t   n_tokens,
                          int   n_past,
                          int   n_threads) {
-    llama_kv_cache_clear(ctx->kv_self, n_past, -1);
+    llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
 
-    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, };
+    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, };
 
     if (!llama_eval_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@@ -7259,7 +7286,6 @@ struct llama_batch llama_batch_get_one(
         /*all_pos_0  =*/ pos_0,
         /*all_pos_1  =*/ 1,
         /*all_seq_id =*/ seq_id,
-        /*clear_kv   =*/ pos_0 == 0,
     };
 }
 

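To make the cell bookkeeping above easier to follow, here is a self-contained toy model of the same logic. The toy types stand in for llama_kv_cache and its cells; they only track pos and the per-cell set of sequence ids, leaving out head, cell_max and the tensor data of the real cache.

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

using seq_id_t = int32_t;

// stand-in for a KV cache cell: pos == -1 means the cell is free
struct toy_cell {
    int32_t pos;
    std::set<seq_id_t> seq_id;

    bool has_seq_id(seq_id_t id) const { return seq_id.count(id) > 0; }
};

struct toy_cache {
    std::vector<toy_cell> cells;

    // counterpart of llama_kv_cache_rm_tokens: wipe a cell range outright
    void rm_tokens(int32_t c0, int32_t c1) {
        if (c0 < 0) c0 = 0;
        if (c1 < 0) c1 = (int32_t) cells.size();
        for (int32_t i = c0; i < c1; ++i) {
            cells[i].pos = -1;
            cells[i].seq_id.clear();
        }
    }

    // counterpart of llama_kv_cache_rm_seq: detach one sequence,
    // freeing a cell only when no other sequence still uses it
    void rm_seq(seq_id_t id) {
        for (auto & c : cells) {
            if (c.has_seq_id(id)) {
                c.seq_id.erase(id);
                if (c.seq_id.empty()) {
                    c.pos = -1;
                }
            }
        }
    }

    // counterpart of llama_kv_cache_keep_seq: free everything
    // that is not used by the given sequence
    void keep_seq(seq_id_t id) {
        for (auto & c : cells) {
            if (!c.has_seq_id(id)) {
                c.pos = -1;
                c.seq_id.clear();
            }
        }
    }
};

int main() {
    toy_cache cache;
    cache.cells = {
        { /*pos =*/ 0, { 0    } },  // cell 0: sequence 0 only
        { /*pos =*/ 1, { 0, 1 } },  // cell 1: prompt shared by sequences 0 and 1
        { /*pos =*/ 2, { 1    } },  // cell 2: sequence 1 only
        { /*pos =*/ 3, { 2    } },  // cell 3: sequence 2 only
    };

    cache.rm_seq(1);    // frees cell 2; cell 1 survives because sequence 0 still uses it
    cache.keep_seq(0);  // frees cell 3; cells 0 and 1 survive

    for (size_t i = 0; i < cache.cells.size(); ++i) {
        std::printf("cell %zu: pos = %d\n", i, cache.cells[i].pos);
    }
    return 0;
}
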
llama.h

Lines changed: 8 additions & 3 deletions
@@ -84,8 +84,6 @@ extern "C" {
         llama_pos    all_pos_0;  // used if pos == NULL
         llama_pos    all_pos_1;  // used if pos == NULL
         llama_seq_id all_seq_id; // used if seq_id == NULL
-
-        bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations
     } llama_seq;
 
     enum llama_log_level {
@@ -323,7 +321,14 @@ extern "C" {
     LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
             "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 
-    LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
+    // Remove all tokens between cells [c0, c1)
+    LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
+
+    // Removes all tokens that belong to the specified sequence
+    LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id);
+
+    // Removes all tokens that do not belong to the specified sequence
+    LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
     //
     // State / sessions

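Taken together, the new public entry points give callers three ways to prune the cache. A short, hypothetical usage sketch follows; the multi-sequence setup it assumes is illustrative and not part of this diff.

#include "llama.h"

// hypothetical pruning helper built on the API added in this commit
void prune_kv_cache(llama_context * ctx) {
    // drop everything decoded for a scratch sequence (id 1 here is illustrative);
    // cells shared with other sequences stay alive
    llama_kv_cache_rm_seq(ctx, 1);

    // keep only the main sequence (id 0) and evict every other cell
    llama_kv_cache_keep_seq(ctx, 0);

    // wipe a cell range outright; negative bounds default to the start/end,
    // so (-1, -1) clears the entire cache
    llama_kv_cache_rm_tokens(ctx, -1, -1);
}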