Skip to content

Commit 80259c9

Browse files
authored
feat(python): add cls and mean pooling (#402)
1 parent c6c5e45 commit 80259c9

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

backends/python/server/text_embeddings_server/models/default_model.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,14 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
4343
kwargs["position_ids"] = batch.position_ids
4444

4545
output = self.model(**kwargs)
46-
embedding = output[0][:, 0]
46+
47+
if self.pooling_mode == "cls":
48+
embedding = output[0][:, 0]
49+
elif self.pooling_mode == "mean":
50+
embedding = output[0].mean(dim=1)
51+
else:
52+
raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend")
53+
4754
cpu_results = embedding.view(-1).tolist()
4855

4956
return [

0 commit comments

Comments
 (0)