Skip to content

Commit a21ff6c

Browse files
committed
context : enable reranking with encode()
ggml-ci
1 parent 9770efa commit a21ff6c

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

src/llama-context.cpp

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -793,11 +793,18 @@ int llama_context::encode(llama_batch & inp_batch) {
793793
} break;
794794
case LLAMA_POOLING_TYPE_RANK:
795795
{
796-
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
797-
// wait for an encoder model that requires this pooling type in order to test it
798-
// https://github.com/ggerganov/llama.cpp/pull/9510
799-
GGML_ABORT("RANK pooling not implemented yet");
800-
}
796+
// extract the rerank score - a single float per sequence
797+
auto & embd_seq_out = embd_seq;
798+
799+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
800+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
801+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
802+
continue;
803+
}
804+
embd_seq_out[seq_id].resize(1);
805+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
806+
}
807+
} break;
801808
case LLAMA_POOLING_TYPE_UNSPECIFIED:
802809
{
803810
GGML_ABORT("unknown pooling type");

0 commit comments

Comments (0)