-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from rajveer43/ml
Add GenAI Conversational Search
- Loading branch information
Showing
1 changed file
with
63 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import os

from dotenv import load_dotenv
import google.generativeai as genai
import streamlit as st
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
# Load environment variables from .env and configure the Gemini client.
# The original had a bare `os.getenv("GOOGLE_API_KEY")` whose result was
# discarded — a no-op — so it has been removed.
load_dotenv()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
def get_text_chunks(text):
    """Split raw document text into overlapping chunks for embedding.

    Args:
        text: The full document text as a single string.

    Returns:
        list[str]: Chunks of up to 10,000 characters with a
        1,000-character overlap between consecutive chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=10000,
        chunk_overlap=1000,
    )
    return splitter.split_text(text)
def get_vector_store(text_chunks):
    """Embed text chunks with Gemini embeddings and persist a FAISS index.

    Args:
        text_chunks: Iterable of text chunks to embed and index.

    Side effects:
        Writes the FAISS index to the local "faiss_index" directory;
        returns nothing.
    """
    embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    index = FAISS.from_texts(text_chunks, embedding=embedder)
    index.save_local("faiss_index")
def get_conversational_chain():
    """Build a "stuff"-type QA chain over Gemini Pro.

    The prompt instructs the model to answer strictly from the supplied
    context and to say the answer is unavailable rather than guess.

    Returns:
        A load_qa_chain chain expecting "input_documents" and "question".
    """
    prompt_template = """
    Answer the question as detailed as possible from the provided context, make sure to provide all the details, if the answer is not in
    provided context just say, "answer is not available in the context", don't provide the wrong answer\n\n
    Context:\n {context}?\n
    Question: \n{question}\n
    Answer:
    """

    # Low temperature keeps answers close to the retrieved context.
    llm = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)

    qa_prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"],
    )
    return load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)
def user_input(user_question):
    """Answer a user question from the locally persisted FAISS index.

    Loads the "faiss_index" store written by get_vector_store, retrieves
    the chunks most similar to the question, and runs them through the
    conversational QA chain. Prints the full response and writes the
    answer to the Streamlit UI.

    Bug fixed: the original referenced `st` (Streamlit) without the module
    ever being imported, so this function raised NameError at runtime;
    `import streamlit as st` is now part of the file's imports.

    Args:
        user_question: The question to answer, as a string.

    Returns:
        dict: The chain response; the answer text is under "output_text".
        (The original returned None; returning the response is
        backward compatible and lets callers reuse the answer.)
    """
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

    # NOTE(review): newer langchain releases require
    # allow_dangerous_deserialization=True here; only enable it for
    # indexes this application itself created — confirm the installed
    # langchain version before changing.
    new_db = FAISS.load_local("faiss_index", embeddings)
    docs = new_db.similarity_search(user_question)

    chain = get_conversational_chain()
    response = chain(
        {"input_documents": docs, "question": user_question},
        return_only_outputs=True,
    )

    print(response)
    st.write("Reply: ", response["output_text"])
    return response