-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathchains.py
47 lines (43 loc) · 1.72 KB
/
chains.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from langchain import OpenAI, LLMChain, PromptTemplate
from langchain.memory import ConversationBufferWindowMemory, ConversationBufferMemory
from langchain.agents import initialize_agent, Tool
from langchain.chat_models import ChatOpenAI
from langchain.utilities import GoogleSerperAPIWrapper
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
# get a chat LLM chain, following a prompt template
def get_chat_chain():
    """Build an LLMChain for chat that fills a prompt template with
    conversation history and the latest human input.

    Returns:
        LLMChain: chain over a temperature-0 OpenAI completion model with
        a sliding-window conversation memory.
    """
    # Read the prompt template; use a context manager so the file handle
    # is closed deterministically (the original `open(...).read()` leaked it).
    with open('template', 'r') as f:
        template = f.read()
    # The template must reference {history} and {human_input}.
    prompt = PromptTemplate(
        input_variables=["history", "human_input"],
        template=template,
    )
    # k=10 keeps only the last 10 exchanges in the prompt context.
    return LLMChain(
        llm=OpenAI(temperature=0),
        prompt=prompt,
        verbose=True,
        memory=ConversationBufferWindowMemory(k=10),
    )
# get a chat chain that uses Serper API to search using Google Search
def get_search_agent():
    """Create a conversational ReAct chat agent whose single tool is
    Google search via the Serper API.

    Returns:
        AgentExecutor: agent over ChatOpenAI with full-history buffer memory.
    """
    # Wrap the Serper search API as the agent's only tool.
    serper = GoogleSerperAPIWrapper()
    search_tool = Tool(
        name="Current Search",
        func=serper.run,
        description="search",
    )
    # Buffer memory returns prior turns as messages under "chat_history",
    # which the conversational agent's prompt expects.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    # Build and return the chat agent.
    return initialize_agent(
        tools=[search_tool],
        llm=ChatOpenAI(),
        agent="chat-conversational-react-description",
        verbose=True,
        memory=chat_memory,
    )
def get_qa_chain():
    """Build a retrieval-QA chain over a Chroma vector store persisted in
    the current directory.

    Returns:
        RetrievalQA: "stuff"-type QA chain using a temperature-0 chat model.
    """
    # Load the persisted Chroma store, embedding queries with OpenAI embeddings.
    store = Chroma(persist_directory='.', embedding_function=OpenAIEmbeddings())
    # "stuff" packs all retrieved documents into a single prompt.
    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(temperature=0),
        chain_type="stuff",
        retriever=store.as_retriever(),
    )