Commit 5be6c80

llama : remove token functions with context args in favor of model (#3720)
* added `llama_model_token_*` variants to all the `llama_token_*` functions
* added `LLAMA_API`
* formatting

Co-authored-by: Georgi Gerganov <[email protected]>

* removed old `llama_token` functions
* changed 3 more functions to take in model:
  - `llama_token_get_text`
  - `llama_token_get_score`
  - `llama_token_get_type`
* added back docs
* fixed main.cpp
* changed token functions to use new model variants
* changed token functions to use new model variants

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 6336701 commit 5be6c80
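
Every file below follows the same pattern: the special-token getters (`llama_token_bos`, `llama_token_eos`, `llama_token_nl`, the infill tokens, and the vocabulary metadata getters) are now queried from the `llama_model` instead of a `llama_context`, and call sites that only hold a context go through `llama_get_model(ctx)`. A minimal sketch of the two call styles, with placeholder variable names, assuming the post-change declarations take `const struct llama_model *` as shown in the diffs below:

// Sketch only: illustrates the calling convention after this commit.
// "model" and "ctx" are placeholders for objects created via
// llama_load_model_from_file() / llama_new_context_with_model().
#include "llama.h"

static void show_special_tokens(const struct llama_model * model, const struct llama_context * ctx) {
    // preferred after this commit: ask the model directly
    const llama_token bos = llama_token_bos(model);
    const llama_token eos = llama_token_eos(model);

    // when only a context is in scope, recover the model first
    const llama_token nl  = llama_token_nl(llama_get_model(ctx));

    (void) bos; (void) eos; (void) nl; // silence unused-variable warnings in this sketch
}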

16 files changed: 80 additions, 78 deletions

common/common.cpp (4 additions, 4 deletions)

@@ -880,13 +880,13 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     }
 
     if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
     }
 
     {
         LOG("warming up the model with an empty run\n");
 
-        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
+        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_tokens_rm(lctx, -1, -1);
         llama_reset_timings(lctx);
@@ -941,7 +941,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
 }
 
 std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
-    const llama_token bos_id = llama_token_bos(ctx);
+    const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
 
     std::string piece;
     std::string result;
@@ -1186,7 +1186,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
 
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx));
+    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
     const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
     fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
 

common/sampling.cpp (2 additions, 2 deletions)

@@ -147,15 +147,15 @@ llama_token llama_sampling_sample(
 
     // apply penalties
     if (!prev.empty()) {
-        const float nl_logit = logits[llama_token_nl(ctx_main)];
+        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
 
         llama_sample_repetition_penalties(ctx_main, &cur_p,
                 prev.data() + prev.size() - penalty_last_n,
                 penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
 
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
+                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
                     cur_p.data[idx].logit = nl_logit;
                     break;
                 }

common/train.cpp (3 additions, 3 deletions)

@@ -236,8 +236,8 @@ int64_t get_example_targets_batch(
     int64_t used_samples = 0;
 
     ggml_set_f32(target_probs, 0.0f);
-    llama_token bos = llama_token_bos(lctx);
-    llama_token eos = llama_token_eos(lctx);
+    llama_token bos = llama_token_bos(llama_get_model(lctx));
+    llama_token eos = llama_token_eos(llama_get_model(lctx));
     // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
     for (int k=0; k<n_batch; ++k) {
         // printf("%s: batch %d\n", __func__, k);
@@ -924,7 +924,7 @@ size_t tokenize_file(
     for (llama_token token=0; token < n_vocab; ++token) {
         max_token_text_size = std::max(
             max_token_text_size,
-            strlen(llama_token_get_text(lctx, token)));
+            strlen(llama_token_get_text(llama_get_model(lctx), token)));
     }
 
     // upper bound of context byte length.
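
The three vocabulary metadata getters named in the commit message (`llama_token_get_text`, `llama_token_get_score`, `llama_token_get_type`) follow the same convention. A hedged sketch of walking the vocabulary directly from the model, assuming their first parameter is now `const struct llama_model *` as in the `tokenize_file` hunk above:

// Sketch only: dump basic vocabulary info using the model-based getters.
#include <cstdio>
#include "llama.h"

static void dump_vocab(const struct llama_model * model) {
    const int n_vocab = llama_n_vocab(model);
    for (llama_token id = 0; id < n_vocab; ++id) {
        printf("%6d  %-24s  score=%8.3f  type=%d\n",
                id,
                llama_token_get_text(model, id),
                llama_token_get_score(model, id),
                (int) llama_token_get_type(model, id));
    }
}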

examples/batched/batched.cpp (1 addition, 1 deletion)

@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
            //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
            // is it an end of stream? -> mark the stream as finished
-            if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
+            if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
                i_batch[i] = -1;
                LOG_TEE("\n");
                if (n_parallel > 1) {

examples/beam-search/beam-search.cpp (1 addition, 1 deletion)

@@ -47,7 +47,7 @@ struct beam_search_callback_data {
 // In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
 // For example, eob can be flagged due to maximum token length, stop words, etc.
 static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
-    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx));
 }
 
 // Function matching type llama_beam_search_callback_fn_t.

examples/infill/infill.cpp (15 additions, 15 deletions)

@@ -246,22 +246,22 @@ int main(int argc, char ** argv) {
    if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
        inp_sfx.erase(inp_sfx.begin());
    }
-    inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+    inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
    if (add_bos) {
-        inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
+        inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
    }
-    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
+    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
    embd_inp = inp_pfx;
    embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
-    embd_inp.push_back(llama_token_middle(ctx));
+    embd_inp.push_back(llama_token_middle(model));
 
    LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
    LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
 
    // Should not run without any tokens
    if (embd_inp.empty()) {
-        embd_inp.push_back(llama_token_bos(ctx));
+        embd_inp.push_back(llama_token_bos(model));
        LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
    }
 
@@ -577,10 +577,10 @@ int main(int argc, char ** argv) {
        if ((int) embd_inp.size() <= n_consumed) {
 
            // deal with eot token in infill mode
-            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
+            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
                if(is_interacting && !params.interactive_first) {
                    // print an eot token
-                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
+                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
                }
                fflush(stdout);
                printf("\n");
@@ -627,14 +627,14 @@ int main(int argc, char ** argv) {
                if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
                    inp_sfx.erase(inp_sfx.begin());
                }
-                inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+                inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
                if (add_bos) {
-                    inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
+                    inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
                }
-                inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
+                inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
                embd_inp = inp_pfx;
                embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
-                embd_inp.push_back(llama_token_middle(ctx));
+                embd_inp.push_back(llama_token_middle(model));
                embd.clear();
                embd_guidance.clear();
                n_remain = params.n_predict;
@@ -644,7 +644,7 @@ int main(int argc, char ** argv) {
                    is_interacting = false;
                }
                // deal with end of text token in interactive mode
-                else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
+                else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
                    LOG("found EOS token\n");
 
                    if (params.interactive) {
@@ -661,7 +661,7 @@ int main(int argc, char ** argv) {
 
            if (params.input_prefix_bos) {
                LOG("adding input prefix BOS token\n");
-                embd_inp.push_back(llama_token_bos(ctx));
+                embd_inp.push_back(llama_token_bos(model));
            }
 
            std::string buffer;
@@ -724,7 +724,7 @@ int main(int argc, char ** argv) {
        }
 
        // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !params.interactive) {
+        if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) {
            break;
        }
 
@@ -736,7 +736,7 @@ int main(int argc, char ** argv) {
        }
    }
    if (!params.interactive && n_remain <= 0) {
-        printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
+        printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
        fflush(stdout);
    }
 

examples/llama-bench/llama-bench.cpp (2 additions, 2 deletions)

@@ -933,7 +933,7 @@ struct sql_printer : public printer {
 };
 
 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
-    std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
+    std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
    int n_processed = 0;
 
    llama_set_n_threads(ctx, n_threads, n_threads);
@@ -946,7 +946,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
 }
 
 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-    llama_token token = llama_token_bos(ctx);
+    llama_token token = llama_token_bos(llama_get_model(ctx));
 
    llama_set_n_threads(ctx, n_threads, n_threads);
 

examples/llava/llava-utils.h (1 addition, 1 deletion)

@@ -137,7 +137,7 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
 inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
    int id = sample_id(ctx_llama, params);
    static std::string ret;
-    if (id == llama_token_eos(ctx_llama)) {
+    if (id == llama_token_eos(llama_get_model(ctx_llama))) {
        ret = "</s>";
    } else {
        ret = llama_token_to_piece(ctx_llama, id);

examples/main/main.cpp (4 additions, 4 deletions)

@@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
 
    // Should not run without any tokens
    if (embd_inp.empty()) {
-        embd_inp.push_back(llama_token_bos(ctx));
+        embd_inp.push_back(llama_token_bos(model));
        LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
    }
 
@@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
            }
 
            // deal with end of text token in interactive mode
-            if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
+            if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
                LOG("found EOS token\n");
 
                if (params.interactive) {
@@ -720,7 +720,7 @@ int main(int argc, char ** argv) {
 
            if (params.input_prefix_bos) {
                LOG("adding input prefix BOS token\n");
-                embd_inp.push_back(llama_token_bos(ctx));
+                embd_inp.push_back(llama_token_bos(model));
            }
 
            std::string buffer;
@@ -804,7 +804,7 @@ int main(int argc, char ** argv) {
        }
 
        // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
+        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) {
            LOG_TEE(" [end of text]\n");
            break;
        }

examples/parallel/parallel.cpp (1 addition, 1 deletion)

@@ -347,7 +347,7 @@ int main(int argc, char ** argv) {
            //        client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
 
            if (client.n_decoded > 2 &&
-                    (id == llama_token_eos(ctx) ||
+                    (id == llama_token_eos(model) ||
                    (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
                    client.response.find("User:") != std::string::npos ||
                    client.response.find('\n') != std::string::npos)) {
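
Across batched.cpp, main.cpp, and parallel.cpp the recurring edit is the end-of-generation check, which now compares the sampled token against `llama_token_eos(model)`. A minimal sketch of that stop condition, with hypothetical names (`new_token_id`, `n_cur`, `n_len`) borrowed from the batched example above:

// Sketch only: the sampling/decoding loop itself is assumed to live elsewhere.
#include "llama.h"

static bool should_stop(const struct llama_model * model, llama_token new_token_id, int n_cur, int n_len) {
    // stop when the model's EOS token is sampled or the token budget runs out
    return new_token_id == llama_token_eos(model) || n_cur >= n_len;
}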
