Skip to content

Commit aa5401f

Browse files
committed
Adding blindrag
1 parent fa4e166 commit aa5401f

File tree

5 files changed

+1259
-1115
lines changed

5 files changed

+1259
-1115
lines changed

nilai-api/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ dependencies = [
1616
"fastapi[standard]>=0.115.5",
1717
"gunicorn>=23.0.0",
1818
"nilai-common",
19+
"blindrag",
20+
"nilrag",
1921
"python-dotenv>=1.0.1",
2022
"sqlalchemy>=2.0.36",
2123
"uvicorn>=0.32.1",
@@ -40,3 +42,5 @@ build-backend = "hatchling.build"
4042

4143
[tool.uv.sources]
4244
nilai-common = { workspace = true }
45+
blindrag = { path = "./blindRAG" }
46+
nilrag = { path = "./nilrag" }
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import logging
2+
from typing import Union
3+
4+
import blindrag
5+
from fastapi import HTTPException, status
6+
from nilai_common import ChatRequest, Message
7+
from blindrag.rag_vault import RAGVault
8+
from sentence_transformers import SentenceTransformer
9+
10+
logger = logging.getLogger(__name__)
11+
12+
# Shared sentence-embedding model, loaded once at module import so every
# request reuses the same weights.
# FIXME: Use a GPU model and move to a separate container
embeddings_model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device="cpu",
)
17+
def generate_embeddings_huggingface(
    chunks_or_query: Union[str, list],
):
    """
    Generate embeddings for text using a HuggingFace sentence transformer model.

    Args:
        chunks_or_query (str or list): Text string(s) to generate embeddings for

    Returns:
        numpy.ndarray: Array of embeddings for the input text
    """
    # convert_to_tensor=False keeps the output as a plain numpy array
    # instead of a torch tensor.
    return embeddings_model.encode(chunks_or_query, convert_to_tensor=False)
33+
async def handle_blindrag(req: ChatRequest):
    """
    Process a client query through blindRAG before it is sent to the LLM.

    1. Get inputs from the request (nilDB node config and the user query).
    2. Execute blindRAG using the blindrag library.
    3. Append the top retrieved chunks to the system message.

    Args:
        req: Incoming chat request. ``req.blindrag`` must be a dict containing
            a "nodes" entry and may contain "num_chunks" (defaults to 2).
            Mutated in place: the retrieved context is appended to the system
            message (or a new system message is inserted at position 0).

    Raises:
        HTTPException: 400 when the blindrag config is missing/invalid, no
            user message is present, or the system message content is None;
            500 on any unexpected failure inside blindRAG.
    """
    try:
        logger.debug("Rag is starting.")

        # Step 1: Get inputs
        # Get nilDB instances
        if not req.blindrag or "nodes" not in req.blindrag:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="blindrag configuration is missing or invalid",
            )
        rag = await RAGVault.create_from_dict(req.blindrag)

        # Get user query: the first user-role message in the conversation.
        logger.debug("Extracting user query")
        query = next(
            (message.content for message in req.messages if message.role == "user"),
            None,
        )
        if query is None:
            # Use the status constant for consistency with the other raises
            # in this handler (same 400 value as before).
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No user query found",
            )

        # Get number of chunks to include (default 2).
        num_chunks = req.blindrag.get("num_chunks", 2)

        # Step 2: Execute blindRAG
        relevant_context = await rag.top_num_chunks_execute(query, num_chunks, False)

        # Step 3: Update system message
        for message in req.messages:
            if message.role == "system":
                if message.content is None:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="system message is empty",
                    )
                # Append the context to the system message
                message.content += relevant_context
                break
        else:
            # If no system message exists, add one
            req.messages.insert(0, Message(role="system", content=relevant_context))
        # Lazy %-formatting: the messages repr is only built when DEBUG is on.
        logger.debug("System message updated with relevant context:\n %s", req.messages)

    except HTTPException:
        # Client errors pass through untouched; bare `raise` preserves the
        # original traceback.
        raise
    except Exception as e:
        logger.error("An error occurred within blindRAG: %s", str(e))
        # Chain the cause so the original error is visible in server logs.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e

nilai-api/src/nilai_api/routers/private.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import AsyncGenerator, Optional, Union, List, Tuple
77
from nilai_api.attestation import get_attestation_report
88
from nilai_api.handlers.nilrag import handle_nilrag
9+
from nilai_api.handlers.blindrag import handle_blindrag
910

1011
from fastapi import APIRouter, Body, Depends, HTTPException, status, Request
1112
from fastapi.responses import StreamingResponse
@@ -193,6 +194,9 @@ async def chat_completion(
193194
if req.nilrag:
194195
await handle_nilrag(req)
195196

197+
if req.blindrag:
198+
await handle_blindrag(req)
199+
196200
if req.stream:
197201
client = AsyncOpenAI(base_url=model_url, api_key="<not-needed>")
198202

packages/nilai-common/src/nilai_common/api_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ChatRequest(BaseModel):
2424
stream: Optional[bool] = False
2525
tools: Optional[Iterable[ChatCompletionToolParam]] = None
2626
nilrag: Optional[dict] = {}
27+
blindrag: Optional[dict] = {}
2728

2829

2930
class SignedChatCompletion(ChatCompletion):

0 commit comments

Comments
 (0)