
Commit d29e769

llama : unified KV cache + batch inference API

1 parent fad5693

10 files changed: +315, -236 lines

common/common.cpp

Lines changed: 1 addition & 5 deletions

@@ -436,8 +436,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_mmap = false;
         } else if (arg == "--numa") {
             params.numa = true;
-        } else if (arg == "--export") {
-            params.export_cgraph = true;
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -685,7 +683,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" Not recommended since this is both slower and uses more VRAM.\n");
 #endif // GGML_USE_CUBLAS
 #endif
-    printf(" --export export the computation graph to 'llama.ggml'\n");
     printf(" --verbose-prompt print prompt before generation\n");
     fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
@@ -782,7 +779,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     {
         LOG("warming up the model with an empty run\n");

-        const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
+        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
         llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
         llama_reset_timings(lctx);
     }
@@ -1182,7 +1179,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
-    fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty);
     dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());

common/common.h

Lines changed: 0 additions & 1 deletion

@@ -111,7 +111,6 @@ struct gpt_params {
     bool use_mmap = true; // use mmap for faster loads
     bool use_mlock = false; // use mlock to keep model in memory
     bool numa = false; // attempt optimizations that help on some NUMA systems
-    bool export_cgraph = false; // export the computation graph
     bool verbose_prompt = false; // print prompt tokens before generation
 };


examples/beam-search/beam-search.cpp

Lines changed: 2 additions & 1 deletion

@@ -158,7 +158,8 @@ int main(int argc, char ** argv)
     }
     std::cout << std::flush;

-    int n_past = llama_get_kv_cache_token_count(ctx);
+    int n_past = 0;
+
     if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
     {
         fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
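
With the unified KV cache, the examples stop reading the current position back from the context via llama_get_kv_cache_token_count() and keep n_past on the caller side instead. A minimal sketch of that pattern, assuming the same llama_eval signature used in the diffs above (error handling trimmed):

    // prompt evaluation starts at position 0; the caller owns the position counter
    int n_past = 0;

    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads)) {
        fprintf(stderr, "%s : failed to eval prompt.\n", __func__);
        return 1;
    }

    // advance the caller-side position by the number of tokens just evaluated
    n_past += (int) tokens_list.size();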

examples/main/main.cpp

Lines changed: 0 additions & 9 deletions

@@ -198,15 +198,6 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

-    // export the cgraph and exit
-    if (params.export_cgraph) {
-        llama_eval_export(ctx, "llama.ggml");
-        llama_free(ctx);
-        llama_free_model(model);
-
-        return 0;
-    }
-
     std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;

examples/perplexity/perplexity.cpp

Lines changed: 1 addition & 1 deletion

@@ -400,7 +400,7 @@ results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
     return {tokens, ppl, logit_history, prob_history};
 }

-std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
+std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int> & tokens, int n_past, int n_batch,
                                              int n_vocab, int n_thread) {
     std::vector<float> result;
     result.reserve(tokens.size() * n_vocab);

examples/simple/simple.cpp

Lines changed: 4 additions & 2 deletions

@@ -73,10 +73,12 @@ int main(int argc, char ** argv) {

     const int n_gen = std::min(32, max_context_size);

-    while (llama_get_kv_cache_token_count(ctx) < n_gen) {
+    int n_cur = 0;
+
+    while (n_cur < n_gen) {
         // evaluate the transformer

-        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
+        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
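
The same caller-tracked position drives the generation loop in simple.cpp: n_cur starts at 0 and is advanced by the number of tokens handed to each llama_eval call rather than being read back from the KV cache. A hedged sketch of the loop shape, with sampling elided (n_gen and tokens_list as in the example above):

    int n_cur = 0;

    while (n_cur < n_gen) {
        // evaluate the pending tokens at the current position
        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }

        // the evaluated tokens are now in the KV cache; advance the position
        n_cur += (int) tokens_list.size();

        // ... sample the next token here and replace tokens_list with it ...
    }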

ggml.c

Lines changed: 2 additions & 4 deletions

@@ -12462,13 +12462,11 @@ static void ggml_compute_forward_alibi_f16(
         return;
     }

-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

-    assert(n_past >= 0);
-
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12483,7 +12481,7 @@ static void ggml_compute_forward_alibi_f16(
     //const int nb3 = src0->nb[3];

     GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+    //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
     GGML_ASSERT(n_head == ne2);

     // add alibi to src0 (KQ_scaled)
