+import logging
+from typing import Union
+
+from blindrag.rag_vault import RAGVault
+from fastapi import HTTPException, status
+from nilai_common import ChatRequest, Message
+from sentence_transformers import SentenceTransformer
+
+logger = logging.getLogger(__name__)
+
+embeddings_model = SentenceTransformer(
+    "sentence-transformers/all-MiniLM-L6-v2", device="cpu"
+)  # FIXME: Use a GPU model and move to a separate container
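+# NOTE: the model is downloaded/loaded once at module import, so request
+# handlers do not pay the initialisation cost on first use.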
+
+
+def generate_embeddings_huggingface(
+    chunks_or_query: Union[str, list],
+):
+    """
+    Generate embeddings for text using a HuggingFace sentence transformer model.
+
+    Args:
+        chunks_or_query (str or list): Text string(s) to generate embeddings for
+
+    Returns:
+        numpy.ndarray: Array of embeddings for the input text
+    """
+    embeddings = embeddings_model.encode(chunks_or_query, convert_to_tensor=False)
+    return embeddings
+
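+# Illustrative usage (a sketch, not part of the module): all-MiniLM-L6-v2
+# produces 384-dimensional vectors, so a single string yields one vector and
+# a list of strings yields one vector per entry:
+#
+#   query_vec = generate_embeddings_huggingface("what is blindRAG?")
+#   chunk_vecs = generate_embeddings_huggingface(["chunk one", "chunk two"])
+#   assert query_vec.shape == (384,)
+#   assert chunk_vecs.shape == (2, 384)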
+
+async def handle_blindrag(req: ChatRequest):
+    """
+    Process a client query with blindRAG:
+    1. Extract the inputs from the request.
+    2. Execute blindRAG using the blindrag library.
+    3. Append the top results to the LLM query.
+    """
+    try:
+        logger.debug("blindRAG is starting.")
+
+        # Step 1: Get inputs
+        # Get nilDB instances
+        if not req.blindrag or "nodes" not in req.blindrag:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="blindrag configuration is missing or invalid",
+            )
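+        # Expected payload shape, roughly (the node schema is defined by the
+        # blindrag library and elided here):
+        #   req.blindrag = {"nodes": [...], "num_chunks": 2}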
+        rag = await RAGVault.create_from_dict(req.blindrag)
+
+        # Get user query
+        logger.debug("Extracting user query")
+        query = None
+        for message in req.messages:
+            if message.role == "user":
+                query = message.content
+                break
+
+        if query is None:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="No user query found",
+            )
+
+        # Get number of chunks to include
+        num_chunks = req.blindrag.get("num_chunks", 2)
+
+        # Step 2: Execute blindRAG
+        relevant_context = await rag.top_num_chunks_execute(query, num_chunks, False)
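+        # Note: top_num_chunks_execute is assumed to return the retrieved
+        # chunk text as a single string, since it is concatenated onto
+        # message.content below.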
+        # Step 3: Update system message
+        for message in req.messages:
+            if message.role == "system":
+                if message.content is None:
+                    raise HTTPException(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        detail="system message is empty",
+                    )
+                # Append the context to the system message
+                message.content += relevant_context
+                break
+        else:
+            # for/else: this branch runs only when the loop finishes without
+            # a break, i.e. no system message exists, so insert one
+            req.messages.insert(0, Message(role="system", content=relevant_context))
+        logger.debug("System message updated with relevant context:\n %s", req.messages)
+
+    except HTTPException:
+        raise
+
+    except Exception as e:
+        logger.error("An error occurred within blindRAG: %s", str(e))
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+        ) from e
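+
+# Minimal invocation sketch. The ChatRequest/Message field names are assumed
+# from their usage above rather than from nilai_common's actual definitions:
+#
+#   req = ChatRequest(
+#       messages=[Message(role="user", content="What is nilDB?")],
+#       blindrag={"nodes": [...], "num_chunks": 2},
+#   )
+#   await handle_blindrag(req)  # mutates req.messages in place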