Skip to content

Commit b3ae978

Browse files
committed
server : do not normalize embeddings when there is no pooling
1 parent fd05a79 commit b3ae978

File tree

5 files changed

+15
-6
lines changed

5 files changed

+15
-6
lines changed

common/common.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
17801780
break;
17811781
case 0: // max absolute
17821782
for (int i = 0; i < n; i++) {
1783-
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
1783+
if (sum < std::abs(inp[i])) {
1784+
sum = std::abs(inp[i]);
1785+
}
17841786
}
17851787
sum /= 32760.0; // make an int16 range
17861788
break;

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
603603
// Embedding utils
604604
//
605605

606-
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
606+
// TODO: repace embd_norm with an enum
607+
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
607608

608609
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
609610

examples/gritlm/gritlm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
7575
}
7676

7777
std::vector<float> emb_norm(emb_unorm.size());
78-
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
78+
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
7979
result.push_back(emb_norm);
8080

8181
#ifdef GRIT_DEBUG

examples/retrieval/retrieval.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
107107
}
108108

109109
float * out = output + batch.seq_id[i][0] * n_embd;
110-
common_embd_normalize(embd, out, n_embd);
110+
common_embd_normalize(embd, out, n_embd, 2);
111111
}
112112
}
113113

examples/server/server.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,8 +2042,14 @@ struct server_context {
20422042
continue;
20432043
}
20442044

2045-
common_embd_normalize(embd, embd_res.data(), n_embd);
2046-
res->embedding.push_back(embd_res);
2045+
// normalize only when there is pooling
2046+
// TODO: configurable
2047+
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
2048+
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
2049+
res->embedding.push_back(embd_res);
2050+
} else {
2051+
res->embedding.push_back({ embd, embd + n_embd });
2052+
}
20472053
}
20482054

20492055
SLT_DBG(slot, "%s", "sending embeddings\n");

0 commit comments

Comments
 (0)