Skip to content

Commit 0b3c63e

Browse files
committed
First implementation of the ReRanker endpoint.
1 parent 2f08fc3 commit 0b3c63e

File tree

4 files changed

+54
-0
lines changed

4 files changed

+54
-0
lines changed

app/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,10 @@ class QueryMultipleBody(BaseModel):
4242
query: str
4343
file_ids: List[str]
4444
k: int = 4
45+
46+
47+
class QueryMultipleDocs(BaseModel):
48+
query: str
49+
docs: List[str]
50+
config: dict
51+
k: int = 4

app/routes/document_routes.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import aiofiles.os
77
from shutil import copyfileobj
88
from typing import List, Iterable
9+
from rerankers import Reranker, Document as ReRankDocument
910
from fastapi import (
1011
APIRouter,
1112
Request,
@@ -29,6 +30,7 @@
2930
QueryRequestBody,
3031
DocumentResponse,
3132
QueryMultipleBody,
33+
QueryMultipleDocs,
3234
)
3335
from app.services.vector_store.async_pg_vector import AsyncPgVector
3436
from app.utils.document_loader import (
@@ -648,3 +650,45 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody
648650
traceback.format_exc(),
649651
)
650652
raise HTTPException(status_code=500, detail=str(e))
653+
654+
655+
@router.post("/rerank")
656+
async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs):
657+
try:
658+
rk = Reranker(
659+
body.config.get("model_name", "flashrank"),
660+
model_type=body.config.get("model_type"),
661+
lang=body.config.get("lang"),
662+
api_provider=body.config.get("api_provider"),
663+
api_key=body.config.get("api_key"),
664+
)
665+
666+
667+
docs = []
668+
for i, d in enumerate(body.docs):
669+
if isinstance(d, str):
670+
docs.append(ReRankDocument(text=d, doc_id=i))
671+
else:
672+
docs.append(ReRankDocument(
673+
text=d.get("text", ""),
674+
doc_id=d.get("doc_id", i),
675+
metadata=d.get("metadata", {}) or {}
676+
))
677+
678+
top_k = body.k
679+
680+
results = rk.rank(query=body.query, docs=docs)
681+
items = results.top_k(top_k) if top_k else results
682+
683+
return [
684+
[getattr(r.document, "text", None), r.score]
685+
for r in items
686+
]
687+
except Exception as e:
688+
logger.error(
689+
"Error in reranking documents | Query: %s | Error: %s | Traceback: %s",
690+
body.query,
691+
str(e),
692+
traceback.format_exc(),
693+
)
694+
raise HTTPException(status_code=500, detail=str(e))

docker-compose.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ services:
1717
- DB_PORT=5432
1818
ports:
1919
- "8000:8000"
20+
runtime: nvidia
2021
volumes:
2122
- ./uploads:/app/uploads
23+
- ~/.cache/huggingface:/root/.cache/huggingface:rw
2224
depends_on:
2325
- db
2426
env_file:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@ python-magic==0.4.27
3636
python-pptx==1.0.2
3737
xlrd==2.0.2
3838
pydantic==2.9.2
39+
rerankers[transformers]==0.6.0

0 commit comments

Comments
 (0)