llama : support qwen3 rerank and embeddings #14029

Open
wants to merge 10 commits into base: master
9 changes: 6 additions & 3 deletions common/common.cpp
@@ -905,10 +905,13 @@ struct common_init_result common_init_from_params(common_params & params) {
ok = false;
}

bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;

if (!has_eos && !has_sep) {
if (has_rerank_prompt) {
// OK, do nothing
} else if (!has_eos && !has_sep) {
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
ok = false;
} else if (!has_eos) {
69 changes: 69 additions & 0 deletions convert_hf_to_gguf.py
@@ -809,6 +809,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
res = "qwen2"

if res is None:
logger.warning("\n")
@@ -3061,6 +3064,72 @@ def prepare_tensors(self):
class Qwen3Model(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN3

# extra logic for rerank models
token_false_id: int | None = None
token_true_id: int | None = None
sep_token_id: int = 0
is_tied_embeddings: bool = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# a bit hacky, but currently the only way to detect if this is a rerank model
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
readme_path = self.dir_model / "README.md"
readme_text = ""
if readme_path.exists():
with readme_path.open("r", encoding="utf-8") as f:
readme_text = f.read()
if "# Qwen3-Reranker" in readme_text:
self._find_rerank_config()

def _find_rerank_config(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
self.token_false_id = tokenizer.convert_tokens_to_ids("no")
self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
logger.info(f"gguf: token_false_id = {self.token_false_id}, token_true_id = {self.token_true_id}")
logger.info(f"gguf: sep_token_id = {self.sep_token_id}")
logger.info(f"gguf: is_tied_embeddings = {self.is_tied_embeddings}")

def set_gguf_parameters(self):
super().set_gguf_parameters()
is_rerank = self.token_false_id is not None and self.token_true_id is not None
if is_rerank:
self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
self.gguf_writer.add_classifier_output_labels(["yes", "no"])
self.gguf_writer.add_chat_template([{
"name": "rerank",
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
"<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}\n"
"<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
}])

def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
# extract "yes" and "no" tokens from the output lm_head tensor
assert self.token_false_id is not None and self.token_true_id is not None
false_row = data_torch[self.token_false_id]
true_row = data_torch[self.token_true_id]
return torch.stack([true_row, false_row], dim=0)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
is_rerank = self.token_false_id is not None and self.token_true_id is not None

if not name.startswith("model."):
name = "model." + name

if is_rerank:
if self.is_tied_embeddings and "embed_tokens" in name:
return [
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch)),
(self.map_tensor_name(name), data_torch),
]
if not self.is_tied_embeddings and "lm_head" in name:
# this is the lm_head tensor, we need to extract the cls_out tensor
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch))]

return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
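
For context on what the CLS_OUT extraction above reproduces: the upstream Qwen3 reranker is scored by taking the last-token logits, keeping only the "yes" and "no" token ids, applying log_softmax, and reading off p("yes"). A hedged reference sketch in plain transformers (the model id comes from the README check above; the prompt is assumed to already be the filled rerank template):

    # Rough reference scoring for Qwen3-Reranker in plain transformers; a sketch,
    # not part of this PR. The converted GGUF reproduces this with cls.output + log_softmax.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "Qwen/Qwen3-Reranker-0.6B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id).eval()

    token_true_id = tokenizer.convert_tokens_to_ids("yes")
    token_false_id = tokenizer.convert_tokens_to_ids("no")

    def rerank_score(prompt: str) -> float:
        # prompt: the "rerank" chat template with {query}/{document} already substituted
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            last_logits = model(**inputs).logits[0, -1]        # logits of the last token
        pair = torch.stack([last_logits[token_true_id],        # "yes" logit
                            last_logits[token_false_id]])      # "no" logit
        return torch.log_softmax(pair, dim=0)[0].exp().item()  # p("yes") as the relevance score
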
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -137,6 +137,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
]


3 changes: 2 additions & 1 deletion src/llama-arch.cpp
@@ -200,7 +200,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template." }, // FIXME: cannot add %s because it will be replaced by arch name
Collaborator:

Actually, you can use LLM_KV_TOKENIZER_CHAT_TEMPLATE with suffix:

llama.cpp/src/llama-arch.cpp, lines 1722 to 1725 (c02f53d):

if (suffix != nullptr) {
name += ".";
name += suffix;
}

Collaborator (Author):

I'm doing this but it doesn't work; maybe it's buggy somewhere else:

    const auto key = name
        ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);

Collaborator:

I was looking in the wrong place; this is where it's broken:

llama.cpp/src/llama-arch.cpp, lines 1709 to 1712 (e83ba3e):

std::string LLM_KV::operator()(llm_kv kv) const {
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
}

Collaborator:

Fixed in #14050
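
To make the thread above concrete, a small illustrative sketch (values are examples, not code from the repo): the key template's single %s is filled with the architecture name by the printf-style formatter, so the per-name suffix never reaches the key; the workaround in this diff appends the name manually, which #14050 later does properly inside LLM_KV::operator().

    # Illustration only; mirrors the printf-style substitution discussed above.
    ARCH = "qwen3"     # example architecture name
    NAME = "rerank"    # requested template name

    # Before: the single %s in "tokenizer.chat_template.%s" is consumed by the arch name,
    # so asking for the "rerank" template produced the wrong key and the lookup failed.
    broken_key = "tokenizer.chat_template.%s" % ARCH      # -> "tokenizer.chat_template.qwen3"

    # Workaround in this PR: drop the %s from the name table and append the suffix by hand.
    workaround_key = "tokenizer.chat_template." + NAME    # -> "tokenizer.chat_template.rerank"

    print(broken_key, workaround_key)
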

{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
@@ -629,6 +629,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_CLS_OUT, "cls.output" }, // rerank
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
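
To sanity-check a converted reranker GGUF for the names added above, a hedged gguf-py sketch (the file path is an example, and the exact key spellings are my reading of the conversion script and the chat-template workaround; treat them as assumptions):

    # Sketch: verify that the conversion wrote the rerank metadata and the cls.output tensor.
    from gguf import GGUFReader

    reader = GGUFReader("Qwen3-Reranker-0.6B-f16.gguf")  # example path

    for key in ("tokenizer.chat_template.rerank", "qwen3.pooling_type"):
        print(key, "present:", key in reader.fields)

    # the 2 x n_embd classification head extracted from lm_head ("yes"/"no" rows)
    print([t.name for t in reader.tensors if t.name.startswith("cls.")])  # expect ['cls.output.weight']
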
63 changes: 37 additions & 26 deletions src/llama-graph.cpp
@@ -167,9 +167,15 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
}

void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
if (!cparams.embeddings) {
return;
}

const bool is_last_tok = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
arch == LLM_ARCH_QWEN3; // qwen3 reranking & embedding models use last token

if (is_last_tok) {
// set output to the last token of each sequence
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
@@ -180,23 +186,33 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
uint32_t * data = (uint32_t *) cls->data;
memset(cls->data, 0, n_tokens * ggml_element_size(cls));

std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);

for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];

// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");

for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];

if (pos == 0) {
data[seq_id] = s*n_seq_tokens + i;
if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = s*n_seq_tokens + i;
}
}
}
}

if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}

} else {
// set output to first token of each sequence
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
@@ -207,30 +223,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
uint32_t * data = (uint32_t *) cls->data;
memset(cls->data, 0, n_tokens * ggml_element_size(cls));

std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);

for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];

// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");

for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];

if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = s*n_seq_tokens + i;
if (pos == 0) {
data[seq_id] = s*n_seq_tokens + i;
}
}
}

for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}
}
}

@@ -943,7 +949,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
}

ggml_tensor * llm_graph_context::build_inp_cls() const {
auto inp = std::make_unique<llm_graph_input_cls>(cparams);
auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

auto & cur = inp->cls;

@@ -1577,10 +1583,15 @@ void llm_graph_context::build_pooling(
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
}
} else if (cls_out) {
// Single layer classification head (direct projection)
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
GGML_ASSERT(cls_out_b != nullptr);
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
if (arch == LLM_ARCH_QWEN3) {
cur = ggml_mul_mat(ctx0, cls_out, inp);
cur = ggml_log(ctx0, ggml_soft_max(ctx0, cur)); // qwen3 uses log_softmax
} else {
// Single layer classification head (direct projection)
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
GGML_ASSERT(cls_out_b != nullptr);
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
}
} else {
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
}
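
A tiny numeric sketch of the new RANK branch for Qwen3 (the numbers are made up): cls.output holds the "yes" and "no" rows of lm_head, so multiplying it with the pooled last-token embedding gives two logits, and log_softmax turns them into log-probabilities; the first element, log p("yes"), is what ends up reported as the relevance score.

    # Made-up numbers; illustrates cur = log_softmax(cls_out @ inp) from build_pooling above.
    import numpy as np

    def log_softmax(x):
        x = x - x.max()
        return x - np.log(np.exp(x).sum())

    hidden = np.array([0.3, -1.2, 0.7, 0.05])        # pooled last-token embedding, n_embd = 4
    cls_out = np.array([[0.9, 0.1, -0.4, 0.2],       # row 0: "yes" row of lm_head
                        [-0.3, 0.8, 0.5, -0.1]])     # row 1: "no"  row of lm_head

    logits = cls_out @ hidden                        # two logits: [yes, no]
    print(log_softmax(logits)[0])                    # log p("yes"), the rank score
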
3 changes: 2 additions & 1 deletion src/llama-graph.h
@@ -177,13 +177,14 @@ class llm_graph_input_mean : public llm_graph_input_i {

class llm_graph_input_cls : public llm_graph_input_i {
public:
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : arch(arch), cparams(cparams) {}
virtual ~llm_graph_input_cls() = default;

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * cls; // I32 [n_batch]

const llm_arch arch;
const llama_cparams & cparams;
};

10 changes: 8 additions & 2 deletions src/llama-model.cpp
@@ -825,6 +825,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_QWEN3:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);

switch (hparams.n_layer) {
case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break;
case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break;
@@ -2468,6 +2470,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output rerank
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
@@ -7057,7 +7062,7 @@ struct llm_build_qwen3 : public llm_graph_context {
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}

if (il == n_layer - 1) {
if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -13788,7 +13793,8 @@ uint64_t llama_model_size(const llama_model * model) {
}

const char * llama_model_chat_template(const llama_model * model, const char * name) {
const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
const auto key = name
? LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + std::string(name)
Comment on lines +13796 to +13797
Collaborator:

Suggested change
const auto key = name
? LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + std::string(name)
const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)

I wonder how long this has been broken?

Collaborator:

Ah, it seems this has never worked; it's been broken since it was introduced in #11016.

Collaborator (Author):

It's not used by any of the examples, so we don't know whether it ever worked in the first place (it's probably used in downstream projects, but I don't know).

Collaborator:

I'll make a PR.

: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
const auto & it = model->gguf_kv.find(key);
if (it == model->gguf_kv.end()) {
11 changes: 4 additions & 7 deletions tools/server/server.cpp
@@ -4704,22 +4704,19 @@ int main(int argc, char ** argv) {
return;
}

llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];

// create and queue the task
json responses = json::array();
bool error = false;
std::unordered_set<int> task_ids;
{
std::vector<server_task> tasks;
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) {
auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
auto inputs = tokenize_rerank(ctx_server.model, query, documents);
tasks.reserve(documents.size());
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
task.prompt_tokens = server_tokens(inputs[i], ctx_server.mctx != nullptr);
tasks.push_back(std::move(task));
}

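
With the endpoint now building rerank inputs from the raw query and document strings, a hedged usage sketch (endpoint path, field names, and the --reranking flag are taken from the server README rather than this diff; host and port are the defaults):

    # Example request against a llama-server instance started with a Qwen3 reranker GGUF
    # and reranking enabled (e.g. --reranking); prints the usual "results" list with
    # an index and relevance score per document.
    import requests

    resp = requests.post(
        "http://localhost:8080/v1/rerank",
        json={
            "query": "what is panda?",
            "documents": [
                "hi",
                "The giant panda is a bear species endemic to China.",
            ],
        },
    )
    print(resp.json())
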
59 changes: 42 additions & 17 deletions tools/server/utils.hpp
@@ -260,23 +260,48 @@ static size_t validate_utf8(const std::string& text) {
// template utils
//

// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result;

// Get EOS token - use SEP token as fallback if EOS is not available
llama_token eos_token = llama_vocab_eos(vocab);
if (eos_token == LLAMA_TOKEN_NULL) {
eos_token = llama_vocab_sep(vocab);
}

result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_vocab_bos(vocab));
result.insert(result.end(), query.begin(), query.end());
result.push_back(eos_token);
result.push_back(llama_vocab_sep(vocab));
result.insert(result.end(), doc.begin(), doc.end());
result.push_back(eos_token);
// format and tokenize rerank task:
// - using SEP token: [BOS]query[EOS][SEP]doc[EOS]
// - using prompt: <rerank_prefix>query<rerank_suffix>doc
static std::vector<llama_tokens> tokenize_rerank(const struct llama_model * model, const std::string & query, const std::vector<std::string> & documents) {
const llama_vocab * vocab = llama_model_get_vocab(model);
std::vector<llama_tokens> result;

for (const auto & doc : documents) {
if (llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL) {
// Get EOS token - use SEP token as fallback if EOS is not available
llama_tokens tok;
llama_tokens tok_query = common_tokenize(vocab, query, false, false);
llama_tokens tok_doc = common_tokenize(vocab, doc, false, false);
llama_token eos_token = llama_vocab_eos(vocab);
if (eos_token == LLAMA_TOKEN_NULL) {
eos_token = llama_vocab_sep(vocab);
}

tok.reserve(doc.size() + query.size() + 4);
tok.push_back(llama_vocab_bos(vocab));
tok.insert(tok.end(), tok_query.begin(), tok_query.end());
tok.push_back(eos_token);
tok.push_back(llama_vocab_sep(vocab));
tok.insert(tok.end(), tok_doc.begin(), tok_doc.end());
tok.push_back(eos_token);

result.push_back(std::move(tok));
} else {
// using prompt template
const char * tmpl = llama_model_chat_template(model, "rerank");
if (tmpl == nullptr) {
throw std::runtime_error("model does not have rerank template");
}

std::string prompt = tmpl;
// TODO: may not be efficient to call string_replace_all twice
string_replace_all(prompt, "{query}", query);
string_replace_all(prompt, "{document}", doc);
llama_tokens tok = common_tokenize(vocab, prompt, true, false);
result.push_back(std::move(tok));
}
}

return result;
}
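
For the prompt-template path, this is what tokenize_rerank effectively builds for a Qwen3 reranker; a small sketch that mirrors the two string_replace_all calls, with the template text copied from the conversion script above (the query and document strings are examples):

    RERANK_TEMPLATE = (
        "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
        "and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
        "<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that "
        "answer the query\n<Query>: {query}\n<Document>: {document}\n"
        "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    )

    def build_rerank_prompts(query: str, documents: list[str]) -> list[str]:
        # same substitution as the two string_replace_all calls in tokenize_rerank
        return [
            RERANK_TEMPLATE.replace("{query}", query).replace("{document}", doc)
            for doc in documents
        ]

    prompts = build_rerank_prompts("what is panda?", ["The giant panda is a bear native to China."])
    print(prompts[0])
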