Skip to content

Commit 58f9833

Browse files
pavanjavapavanmantha
andauthored
Langchain llmasjudge (#67)
* LLM as judge with LangChain * LLM as judge with LangChain, including Phoenix observability --------- Co-authored-by: pavanmantha <[email protected]>
1 parent c2a14c4 commit 58f9833

File tree

25 files changed

+730
-3
lines changed

25 files changed

+730
-3
lines changed

bootstraprag/cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def create(project_name, framework, template, observability):
4343

4444
elif framework == 'langchain':
4545
template_choices = [
46-
'simple-rag'
46+
'simple-rag',
47+
'rag-with-hyde',
48+
'llm-as-judge'
4749
]
4850
elif framework == 'standalone-qdrant':
4951
framework = 'qdrant'
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
OLLAMA_BASE_URL="http://localhost:11434"
2+
OLLAMA_LLM_MODEL="llama3.2:latest"
3+
EMBEDDING_MODEL="snowflake/snowflake-arctic-embed-s"
4+
5+
QDRANT_DB_URL="http://localhost:6333/"
6+
QDRANT_DB_KEY="th3s3cr3tk3y"
7+
COLLECTION_NAME="crag_langchain_collection"
8+
9+
LIT_SERVER_PORT=8000
10+
LIT_SERVER_WORKERS_PER_DEVICE=2
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Use the official Python image from the Docker Hub
2+
FROM python:3.9-slim
3+
4+
# Set the working directory in the container
5+
WORKDIR /app
6+
7+
# Copy the requirements file to the container
8+
COPY requirements.txt .
9+
10+
# Install the required dependencies
11+
RUN pip install --no-cache-dir -r requirements.txt
12+
13+
# Copy the current directory contents into the container at /app
14+
COPY . .
15+
16+
# Set environment variables (you can replace these with values from your .env file or other configs)
17+
# The Python code in this template reads QDRANT_DB_URL, so it must be set for the
# container to reach Qdrant on the host; QDRANT_URL is kept for anything that
# still relies on the old name.
ENV QDRANT_URL='http://host.docker.internal:6333' \
    QDRANT_DB_URL='http://host.docker.internal:6333' \
    OLLAMA_BASE_URL='http://host.docker.internal:11434'
19+
20+
# Expose port 8000 for external access
21+
EXPOSE 8000
22+
23+
# Command to run your application
24+
CMD ["python", "api_server.py"]

bootstraprag/templates/langchain/llm_as_judge/__init__.py

Whitespace-only changes.
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from abc import ABC
2+
from dotenv import load_dotenv, find_dotenv
3+
from llm_as_judge import LLMasJudge
4+
import litserve as ls
5+
import os
6+
7+
_ = load_dotenv(find_dotenv())
8+
9+
10+
class LLMasJudgeAPI(ls.LitAPI, ABC):
    """LitServe API that exposes the LLMasJudge operations over HTTP.

    A request is a JSON object with two keys: "operation" (one of
    SUPPORTED_OPERATIONS) and "query" (the user question). The response is
    ``{'response': <result of the selected operation>}``.
    """

    # Values accepted for the request's "operation" key.
    SUPPORTED_OPERATIONS = ('retrieval_grader', 'generate', 'hallucination_grader', 'answer_grader')

    def __init__(self):
        # Give the LitServe base class the chance to initialise its own state.
        super().__init__()
        self.llm_as_judge: LLMasJudge = None
        self.FILE_PATH = 'data/mlops.pdf'
        self.COLLECTION_NAME = os.environ.get('COLLECTION_NAME')
        self.QDRANT_URL = os.environ.get('QDRANT_DB_URL')
        self.QDRANT_API_KEY = os.environ.get('QDRANT_DB_KEY')
        # Operation of the request currently being processed (set by decode_request).
        self.operation_name: str = ''

    def setup(self, devices):
        """Build the LLMasJudge pipeline from the paths/credentials read in __init__."""
        self.llm_as_judge = LLMasJudge(
            file_path=self.FILE_PATH,
            collection_name=self.COLLECTION_NAME,
            qdrant_url=self.QDRANT_URL,
            qdrant_api_key=self.QDRANT_API_KEY
        )

    def decode_request(self, request, **kwargs):
        """Remember the requested operation and return the query text.

        Raises:
            KeyError: if the request lacks the "operation" or "query" key.
        """
        self.operation_name = request["operation"]
        return request["query"]

    def predict(self, query: str):
        """Run the operation selected by the last decoded request on ``query``.

        For the hallucination and answer graders an answer is generated first
        and then graded.

        Raises:
            ValueError: if the requested operation is not in SUPPORTED_OPERATIONS.
        """
        operation = self.operation_name
        if operation == 'retrieval_grader':
            return self.llm_as_judge.retrieval_grader(question=query)
        if operation == 'generate':
            return self.llm_as_judge.generate(question=query)
        if operation == 'hallucination_grader':
            generation = self.llm_as_judge.generate(question=query)
            return self.llm_as_judge.hallucination_grader(question=query, generation=generation)
        if operation == 'answer_grader':
            generation = self.llm_as_judge.generate(question=query)
            return self.llm_as_judge.answer_grader(question=query, generation=generation)
        # Previously an unknown operation fell through and produced a null response.
        raise ValueError(
            f"Unsupported operation {operation!r}; expected one of {', '.join(self.SUPPORTED_OPERATIONS)}"
        )

    def encode_response(self, output, **kwargs):
        """Wrap the operation result in the response envelope."""
        return {'response': output}
45+
46+
47+
if __name__ == '__main__':
    # Fall back to defaults when the variables are unset; int(None) would
    # otherwise raise a TypeError before the server starts. 8000 matches the
    # port exposed by the Dockerfile; a single worker is the conservative default.
    workers_per_device = int(os.environ.get('LIT_SERVER_WORKERS_PER_DEVICE', '1'))
    port = int(os.environ.get('LIT_SERVER_PORT', '8000'))

    api = LLMasJudgeAPI()
    server = ls.LitServer(lit_api=api, api_path='/v1/chat/completions',
                          workers_per_device=workers_per_device)
    server.run(port=port)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright The Lightning AI team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import requests
15+
16+
# api_server.py serves on '/v1/chat/completions' and reads the "operation" and
# "query" keys from the request body, so the client must send exactly that.
response = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={"operation": "generate", "query": "what are challenges of mlops?"},
)
print(f"Status: {response.status_code}\nResponse:\n {response.text}")
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
retrieval_grader_template = """You are a grader assessing relevance
2+
of a retrieved document to a user question. If the document contains any information or keywords related to the user
3+
question,grade it as relevant. This is a very lenient test - the document does not need to fully answer the question
4+
to be considered relevant. Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.
5+
Also provide a brief explanation for your decision.
6+
7+
Return your response as a JSON with two keys: 'score' (either 'yes' or 'no') and 'explanation'.
8+
9+
Here is the retrieved document:
10+
{document}
11+
12+
Here is the user question:
13+
{question}
14+
"""
15+
16+
hallucination_grading_template = """You are a grader assessing whether
17+
an answer is grounded in / supported by a set of facts. Give a binary score 'yes' or 'no' score to indicate
18+
whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a
19+
single key 'score' and no preamble or explanation.
20+
21+
Here are the facts:
22+
{documents}
23+
24+
Here is the answer:
25+
{generation}
26+
"""
27+
28+
answer_generating_template = """You are an assistant for question-answering tasks.
29+
Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
30+
Use three sentences maximum and keep the answer concise:
31+
Question: {question}
32+
Context: {context}
33+
Answer:
34+
"""
35+
36+
answer_grading_template = """You are a grader assessing whether an
37+
answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
38+
useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
39+
40+
Here is the answer:
41+
{generation}
42+
43+
Here is the question: {question}
44+
"""
616 KB
Binary file not shown.
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import os
2+
3+
from langchain_community.document_loaders import PyMuPDFLoader
4+
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
5+
from langchain_core.prompts import PromptTemplate
6+
from langchain_core.runnables.utils import Output
7+
from langchain_ollama import OllamaLLM, ChatOllama
8+
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
9+
from langchain_qdrant import QdrantVectorStore, RetrievalMode
10+
from langchain_text_splitters import RecursiveCharacterTextSplitter
11+
from qdrant_client import QdrantClient
12+
from dotenv import load_dotenv, find_dotenv
13+
from qdrant_client.http.models import VectorParams, Distance
14+
from typing import List, Any
15+
from custom_templates import (
16+
retrieval_grader_template,
17+
hallucination_grading_template,
18+
answer_generating_template,
19+
answer_grading_template
20+
)
21+
22+
23+
class LLMasJudge:
    """Retrieval-augmented answering over one PDF, plus LLM-based graders.

    On construction the PDF at ``file_path`` is loaded and split into chunks,
    and the Qdrant collection is prepared: created and populated when it does
    not exist yet, otherwise opened as it is.

    Environment variables read: OLLAMA_LLM_MODEL, OLLAMA_BASE_URL, EMBEDDING_MODEL.
    """

    def __init__(self, file_path: str, collection_name: str, qdrant_url: str, qdrant_api_key: str):
        """Create the models, the Qdrant client and the vector store.

        Args:
            file_path: path of the PDF to index.
            collection_name: Qdrant collection to create or reuse.
            qdrant_url: URL of the Qdrant server.
            qdrant_api_key: API key for the Qdrant server.
        """
        load_dotenv(find_dotenv())
        self.file_path = file_path
        self.collection_name = collection_name
        self.qdrant_url = qdrant_url
        self.qdrant_api_key = qdrant_api_key

        llm_model = os.environ.get("OLLAMA_LLM_MODEL")
        ollama_base_url = os.environ.get("OLLAMA_BASE_URL")
        embedding_model = os.environ.get("EMBEDDING_MODEL")

        # Free-text model, used for answer generation.
        self.model = OllamaLLM(model=llm_model, base_url=ollama_base_url)
        # FastEmbedEmbeddings selects the model through `model_name`; without it
        # EMBEDDING_MODEL would be ignored and the library default used instead.
        # NOTE(review): collections indexed before this fix may hold vectors from
        # the library default model - re-index if results look off.
        if embedding_model:
            self.embedding = FastEmbedEmbeddings(model_name=embedding_model)
        else:
            self.embedding = FastEmbedEmbeddings()
        self.client = QdrantClient(url=self.qdrant_url, api_key=self.qdrant_api_key)
        # JSON-constrained chat model, used by the graders. It must talk to the
        # same Ollama server as self.model, hence the explicit base_url.
        self.llm = ChatOllama(model=llm_model, base_url=ollama_base_url, format="json")
        self.vector_store: QdrantVectorStore = None
        self.documents = self.load_and_split_documents()
        self.setup_qdrant()

    def load_and_split_documents(self) -> List[Any]:
        """Load the PDF and split it into chunks of 200 characters with 20 overlap."""
        loader = PyMuPDFLoader(file_path=self.file_path)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)
        return loader.load_and_split(text_splitter=text_splitter)

    def setup_qdrant(self):
        """Open the collection, creating and populating it first when it is missing.

        Raises:
            Exception: whatever collection creation or document upload raised;
                it is printed and re-raised so the instance is never left
                without a usable vector store.
        """
        if self.client.collection_exists(collection_name=self.collection_name):
            self.vector_store = QdrantVectorStore.from_existing_collection(
                url=self.qdrant_url,
                api_key=self.qdrant_api_key,
                collection_name=self.collection_name,
                embedding=self.embedding,
                retrieval_mode=RetrievalMode.DENSE,
                vector_name="content"
            )
            return

        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config={
                    # NOTE(review): 384 must equal the embedding model's output
                    # size - confirm whenever EMBEDDING_MODEL changes.
                    "content": VectorParams(size=384, distance=Distance.COSINE)
                }
            )
            self.load_data_to_qdrant()
        except Exception as e:
            print(f"Exception: {str(e)}")
            # Swallowing this would leave self.vector_store as None and make
            # every later call fail with an unrelated AttributeError.
            raise

    def load_data_to_qdrant(self):
        """Embed the split documents into the collection and keep the vector store."""
        vector_store: QdrantVectorStore = QdrantVectorStore(client=self.client, collection_name=self.collection_name,
                                                            embedding=self.embedding, vector_name="content",
                                                            retrieval_mode=RetrievalMode.DENSE)
        vector_store.add_documents(
            documents=self.documents
        )
        self.vector_store = vector_store

    def retrieval_grader(self, question: str):
        """Grade the relevance of the top retrieved document for ``question``.

        Returns:
            The grader's parsed JSON output. When nothing is retrieved, a 'no'
            grade with an explanation is returned without calling the model.
        """
        docs = self.vector_store.as_retriever().invoke(question)
        if not docs:
            return {"score": "no", "explanation": "No documents were retrieved for the question."}

        prompt = PromptTemplate(
            template=retrieval_grader_template,
            input_variables=["question", "document"],
        )
        retrieval_grader = prompt | self.llm | JsonOutputParser()
        # Grade the best-ranked hit; indexing docs[1] skipped it and raised
        # IndexError whenever fewer than two documents came back.
        doc_txt = docs[0].page_content
        retrieval_grading_response = retrieval_grader.invoke({"question": question, "document": doc_txt})
        return retrieval_grading_response

    def generate(self, question: str) -> Output:
        """Answer ``question`` from the retrieved documents and return the text."""
        prompt = PromptTemplate(
            template=answer_generating_template,
            input_variables=["question", "context"]
        )

        # The answer is prose, so use the free-text model rather than the
        # JSON-constrained grader model.
        rag_chain = prompt | self.model | StrOutputParser()

        docs = self.vector_store.as_retriever().invoke(question)
        generation: Output = rag_chain.invoke({"context": docs, "question": question})
        return generation

    def hallucination_grader(self, question: str, generation):
        """Grade whether ``generation`` is supported by the documents retrieved for ``question``.

        Returns:
            The grader's parsed JSON output.
        """
        prompt = PromptTemplate(
            template=hallucination_grading_template,
            input_variables=["generation", "documents"],
        )
        docs = self.vector_store.as_retriever().invoke(question)
        hallucination_grader = prompt | self.llm | JsonOutputParser()
        hallucination_grading_response = hallucination_grader.invoke({"documents": docs, "generation": generation})
        return hallucination_grading_response

    def answer_grader(self, question: str, generation: str):
        """Grade whether ``generation`` is useful for resolving ``question``.

        Returns:
            The grader's parsed JSON output.
        """
        prompt = PromptTemplate(
            template=answer_grading_template,
            input_variables=["generation", "question"]
        )
        answer_grader = prompt | self.llm | JsonOutputParser()
        answer_grading_response = answer_grader.invoke({"question": question, "generation": generation})
        return answer_grading_response
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import os
2+
3+
from llm_as_judge import LLMasJudge
4+
from dotenv import load_dotenv, find_dotenv
5+
6+
load_dotenv(find_dotenv())
7+
8+
# Build the pipeline from the sample PDF and the Qdrant settings in the environment.
llm_as_judge = LLMasJudge(
    file_path='data/mlops.pdf',
    collection_name=os.environ.get("COLLECTION_NAME"),
    qdrant_url=os.environ.get("QDRANT_DB_URL"),
    qdrant_api_key=os.environ.get("QDRANT_DB_KEY")
)

q = "what are challenges of mlops?"
# Print every grader verdict; previously they were computed and thrown away,
# so the script showed only the generated answer.
print(llm_as_judge.retrieval_grader(question=q))
ans = llm_as_judge.generate(question=q)
print(ans)
print(llm_as_judge.hallucination_grader(question=q, generation=ans))
print(llm_as_judge.answer_grader(question=q, generation=ans))

0 commit comments

Comments
 (0)