
Commit 9f42e75

llama : add new llama_decode() API that works with llama_batch
1 parent 58bb511

File tree

13 files changed: +146, -75 lines
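Every call site below follows the same pattern: llama_eval(ctx, tokens, n_tokens, n_past, n_threads) becomes llama_decode() over a llama_batch built with the new llama_batch_get_one() helper. A minimal before/after sketch of that pattern, with the llama_batch_get_one parameters (token pointer, token count, starting position pos_0, sequence id) inferred from the call sites in this diff:

// before this commit:
//   if (llama_eval(ctx, tokens.data(), n_tokens, n_past, params.n_threads)) { ... }

// after this commit: wrap the tokens in a batch that starts at position n_past
// and is assigned to sequence 0
llama_batch batch = llama_batch_get_one(tokens.data(), n_tokens, /*pos_0*/ n_past, /*seq_id*/ 0);
if (llama_decode(ctx, batch, params.n_threads)) {
    fprintf(stderr, "%s : failed to eval\n", __func__);
    return 1;
}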

common/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -780,7 +780,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
        LOG("warming up the model with an empty run\n");

        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-       llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
+       llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
        llama_reset_timings(lctx);
    }

examples/beam-search/beam-search.cpp

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ int main(int argc, char ** argv)

    int n_past = 0;

-   if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
+   if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads))
    {
        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
        return 1;

examples/embd-input/embd-input-lib.cpp

Lines changed: 3 additions & 2 deletions
@@ -79,7 +79,8 @@ bool eval_float(void * model, float * input, int N){
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
-       if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
+       llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
+       if (llama_decode(ctx, batch, params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return false;
        }
@@ -100,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
        if (n_eval > params.n_batch) {
            n_eval = params.n_batch;
        }
-       if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
+       if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return false;
        }
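eval_float is the one site that cannot use llama_batch_get_one, since it feeds raw embeddings rather than token ids, so it aggregate-initializes a llama_batch directly. An annotated reading of that initializer follows; the field names are an assumption based on how the values are used, not something this hunk confirms:

llama_batch batch = {
    uint32_t(n_eval),    // n_tokens: number of embedding vectors in this batch
    nullptr,             // token:  unused here, embeddings are passed instead of token ids
    (input + i*n_emb),   // embd:   pointer to n_eval * n_emb floats
    nullptr,             // pos:    null, so positions come from the scalars below
    nullptr,             // seq_id: null, so the sequence id comes from the scalar below
    n_past,              // first position of the batch (assumed all_pos_0)
    1,                   // position stride between consecutive entries (assumed all_pos_1)
    0,                   // sequence id for the whole batch (assumed all_seq_id)
    false,               // (assumed) whether to return logits for all entries
};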

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {

    while (!embd_inp.empty()) {
        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-       if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
+       if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }

examples/llama-bench/llama-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -891,15 +891,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
    int n_processed = 0;
    while (n_processed < n_prompt) {
        int n_tokens = std::min(n_prompt - n_processed, n_batch);
-       llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads);
+       llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
        n_processed += n_tokens;
    }
}

static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
    llama_token token = llama_token_bos(ctx);
    for (int i = 0; i < n_gen; i++) {
-       llama_eval(ctx, &token, 1, n_past + i, n_threads);
+       llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
    }
}

examples/main/main.cpp

Lines changed: 2 additions & 2 deletions
@@ -571,7 +571,7 @@ int main(int argc, char ** argv) {

            for (int i = 0; i < input_size; i += params.n_batch) {
                int n_eval = std::min(input_size - i, params.n_batch);
-               if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+               if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) {
                    LOG_TEE("%s : failed to eval\n", __func__);
                    return 1;
                }
@@ -588,7 +588,7 @@ int main(int argc, char ** argv) {

            LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));

-           if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+           if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) {
                LOG_TEE("%s : failed to eval\n", __func__);
                return 1;
            }

examples/perplexity/perplexity.cpp

Lines changed: 3 additions & 3 deletions
@@ -199,7 +199,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
        const int batch_size = std::min(end - batch_start, n_batch);

        //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-       if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+       if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
            //fprintf(stderr, "%s : failed to eval\n", __func__);
            return {tokens, -1, logit_history, prob_history};
        }
@@ -331,7 +331,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
            tokens[batch_start] = llama_token_bos(ctx);
        }

-       if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+       if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return {tokens, -1, logit_history, prob_history};
        }
@@ -409,7 +409,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
        size_t n_tokens = tokens.size() - i_chunk * n_batch;
        n_tokens = std::min(n_tokens, size_t(n_batch));
-       if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+       if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return {};
        }

examples/save-load-state/save-load-state.cpp

Lines changed: 8 additions & 8 deletions
@@ -34,11 +34,11 @@ int main(int argc, char ** argv) {
    auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);

    // init
-   auto model = llama_load_model_from_file(params.model.c_str(), lparams);
+   auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
    if (model == nullptr) {
        return 1;
    }
-   auto ctx = llama_new_context_with_model(model, lparams);
+   auto * ctx = llama_new_context_with_model(model, lparams);
    if (ctx == nullptr) {
        llama_free_model(model);
        return 1;
@@ -53,7 +53,7 @@ int main(int argc, char ** argv) {
    }

    // evaluate prompt
-   llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
+   llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);

    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
    n_past += n_prompt_tokens;
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
    printf("\n%s", params.prompt.c_str());

    for (auto i = 0; i < params.n_predict; i++) {
-       auto logits = llama_get_logits(ctx);
+       auto * logits = llama_get_logits(ctx);
        auto n_vocab = llama_n_vocab(ctx);
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
@@ -90,7 +90,7 @@ int main(int argc, char ** argv) {
        last_n_tokens_data.push_back(next_token);

        printf("%s", next_token_str.c_str());
-       if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
+       if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx);
            llama_free_model(model);
@@ -105,7 +105,7 @@ int main(int argc, char ** argv) {
    llama_free(ctx);

    // make new context
-   auto ctx2 = llama_new_context_with_model(model, lparams);
+   auto * ctx2 = llama_new_context_with_model(model, lparams);

    // Load state (rng, logits, embedding and kv_cache) from file
    {
@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {

    // second run
    for (auto i = 0; i < params.n_predict; i++) {
-       auto logits = llama_get_logits(ctx2);
+       auto * logits = llama_get_logits(ctx2);
        auto n_vocab = llama_n_vocab(ctx2);
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
@@ -150,7 +150,7 @@ int main(int argc, char ** argv) {
        last_n_tokens_data.push_back(next_token);

        printf("%s", next_token_str.c_str());
-       if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
+       if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
@@ -434,7 +434,7 @@ struct llama_server_context
            {
                n_eval = params.n_batch;
            }
-           if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
+           if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
            {
                LOG_ERROR("failed to eval", {
                    {"n_eval", n_eval},

examples/simple/simple.cpp

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ int main(int argc, char ** argv) {
    while (n_cur < n_gen) {
        // evaluate the transformer

-       if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
+       if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }
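Taken together, the hunks above imply the shape of a full generation loop under the new API: decode the prompt once, then read logits, pick a token, and decode it at the next position. A minimal sketch, assuming the llama_decode/llama_batch_get_one signatures used in this diff, with greedy argmax standing in for the samplers the examples actually use:

#include <algorithm>
#include <vector>
#include "llama.h"

// hypothetical helper: greedy generation with the post-commit API
static void generate(llama_context * ctx, std::vector<llama_token> tokens,
                     int n_predict, int n_threads) {
    // evaluate the whole prompt as one batch at positions [0, tokens.size()), sequence 0
    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), (int) tokens.size(), 0, 0), n_threads)) {
        return; // failed to eval
    }
    int n_past = (int) tokens.size();

    for (int i = 0; i < n_predict; i++) {
        const float * logits = llama_get_logits(ctx);
        const int n_vocab    = llama_n_vocab(ctx);
        // greedy argmax stands in for the repetition/temperature samplers above
        llama_token next = (llama_token) (std::max_element(logits, logits + n_vocab) - logits);
        // feed the sampled token back at the next position, still on sequence 0
        if (llama_decode(ctx, llama_batch_get_one(&next, 1, n_past, 0), n_threads)) {
            return; // failed to eval
        }
        n_past += 1;
    }
}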
