-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathform_helper.py
150 lines (122 loc) · 5.32 KB
/
form_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import sqlite3
from typing import List
import json
import streamlit as st
from dotenv import load_dotenv
from langchain.schema import BaseMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, MessageGraph
from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun
from langchain.globals import set_debug
from langgraph.prebuilt import ToolInvocation, ToolExecutor
from langchain_core.messages import ToolMessage
from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.memory import ConversationBufferWindowMemory
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_community.document_loaders.text import TextLoader
from langchain_community.vectorstores.faiss import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_core.documents import Document
from playground.documents_store import DocumentsStore
# Handle to the FAISS vector store; stays None until rag_initialize() builds it.
db = None

# Load variables from a local .env file *before* OpenAIEmbeddings is constructed,
# so credentials that live only in .env (e.g. OPENAI_API_KEY) are already in the
# process environment when the embeddings client is created at import time.
# load_dotenv() does not override variables that are already set, so the second
# call further down in this module stays harmless.
load_dotenv()

# Shared document store queried by document_tool. Documents are read from the
# "data_resources" directory (presumably resolved relative to the working
# directory - confirm in DocumentsStore).
document_store: DocumentsStore = DocumentsStore(embeddings=OpenAIEmbeddings(), documents_dir_path="data_resources")
@tool
def geheimzahl_tool():
    """Das ist dein Tool - nutze es nur für die Geheimzahl
    Dieses Tool nimmt keine Parameter, also übergib bitte keine Argumente."""
    # NOTE: the docstring above is deliberately left in German. The @tool
    # decorator uses it as the tool description shown to the model, so it is
    # runtime text rather than plain documentation. English gloss: "This is
    # your tool - use it only for the secret number. This tool takes no
    # parameters, so please do not pass any arguments."
    secret_reply = "Die Geheimzahl ist 123987"
    # Trace output so tool invocations are visible on the console.
    print("invoking geheimzahl tool")
    return secret_reply
@tool
def document_tool(
    message: Annotated[str, "Eine Zusammenfassung der Vorhaben des Nutzers"]
) -> Annotated[
    dict[str, List[str]|None],
    "Ein Dictionary mit Pfad zur Datei als Key und Wert ist eine Liste von Texten aus Dokumenten"
]:
    """Wenn nach einem Fest oder Veranstaltung gefragt wird, nutze dieses Tool. Teile dem Nutzer immer mit welche Dateien wichtig sind."""
    # NOTE: the docstring and the Annotated descriptions are left in German on
    # purpose - the @tool decorator hands them to the model, so they are
    # runtime text. English gloss:
    #   docstring: "Use this tool when asked about a festival or event. Always
    #               tell the user which files are important."
    #   message:   "A summary of what the user intends to do."
    #   return:    "A dictionary whose key is the file path and whose value is
    #               a list of texts from documents."
    #
    # Retrieval is delegated to the module-level document_store. An earlier
    # variant queried the FAISS index `db` via similarity_search_with_score
    # and returned the scored documents directly.
    matches = document_store.retrieve(message)
    return matches
def rag_initialize():
    """Build the module-level FAISS index ``db`` from the bundled sources.

    The index is first created from ./playground/state_of_the_union.txt,
    split into chunks of 1000 characters with no overlap. The three PDF files
    under ./data_resources are then loaded and added to the same index as
    returned by the loader, i.e. without going through the text splitter.

    Returns nothing; the result is stored in the global ``db``.
    """
    global db

    embedder = OpenAIEmbeddings()

    # Seed the index with the chunked plain-text document.
    speech_documents = TextLoader("./playground/state_of_the_union.txt").load()
    chunker = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    db = FAISS.from_documents(chunker.split_documents(speech_documents), embedder)

    # Collect the documents of every PDF first, then add them in one call.
    pdf_paths = [
        "./data_resources/Schankerlaubnis.pdf",
        "./data_resources/STK-Ehrenamtsleitfaden_2023_Online.pdf",
        "./data_resources/Unterlagen_Veranstaltungsorganisation_09.01.24.pdf",
    ]
    pdf_documents = []
    for pdf_path in pdf_paths:
        pdf_documents.extend(PyPDFLoader(file_path=pdf_path).load())

    db.add_documents(documents=pdf_documents)
# DuckDuckGo search runner; instantiated here but not part of `tools` below.
ddgs = DuckDuckGoSearchRun()
# Tools that are bound to the chat model and served by the tool executor.
tools = [geheimzahl_tool, document_tool]
# Load environment variables from a .env file into the process environment.
load_dotenv()
class AgentState(TypedDict):
    """State schema of the graph assembled in initialize_app."""

    # Conversation history. The operator.add annotation is the reducer
    # metadata for this key: node results (each node returns a list under
    # "messages") are concatenated onto the existing sequence rather than
    # replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
def initialize_app():
    """Build the chat graph once and cache it in ``st.session_state.app``.

    If ``st.session_state`` already holds an ``"app"`` entry the function does
    nothing. Otherwise it:

    * enables LangChain debug output,
    * binds the module-level ``tools`` to a ``ChatOpenAI`` model,
    * assembles a two-node ``StateGraph`` ("chatbot" -> optionally "search" ->
      "chatbot") routed by ``should_continue``,
    * compiles it with an in-memory SQLite checkpointer, and
    * pipes the graph into a formatter that extracts the last message's text.

    Returns nothing; the composed runnable is stored in the session state.
    """
    if "app" in st.session_state:
        # Already built for this session - nothing to do.
        return

    set_debug(True)
    # NOTE: rag_initialize() is currently disabled here, so this function
    # never builds the FAISS index `db`.
    llm = ChatOpenAI()
    # check_same_thread=False lets the connection be used from threads other
    # than the one that created it.
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    model_with_tools = llm.bind_tools(tools)
    memory = SqliteSaver(conn=conn)

    def agent(state):
        """Call the tool-aware model on the conversation so far."""
        print("invoking agent", state)
        messages = state["messages"]
        response = model_with_tools.invoke(messages)
        print(messages, response)
        return {"messages": [response]}

    # Named call_tools (not `tool`) so the imported @tool decorator is not
    # shadowed inside this function.
    def call_tools(state):
        """Execute every tool call requested by the last message."""
        print("invoking tool", state)
        last_message = state["messages"][-1]
        tool_messages = []
        for tool_call in last_message.additional_kwargs["tool_calls"]:
            function_call = tool_call["function"]
            raw_arguments = function_call["arguments"]
            action = ToolInvocation(
                tool=function_call["name"],
                # json.loads only accepts str/bytes, so an empty or missing
                # argument payload must fall back to {} *without* being
                # parsed (json.loads({}) raises TypeError).
                tool_input=json.loads(raw_arguments) if raw_arguments else {},
            )
            response = tool_executor.invoke(action)
            tool_messages.append(
                ToolMessage(
                    content=str(response),
                    tool_call_id=tool_call["id"],
                    name=function_call["name"],
                )
            )
        return {"messages": tool_messages}

    workflow = StateGraph(AgentState)
    workflow.add_node("chatbot", agent)
    workflow.set_entry_point("chatbot")
    workflow.add_node("search", call_tools)
    workflow.add_conditional_edges(
        "chatbot", should_continue, {"continue": "search", "end": END}
    )
    # After the tools ran, hand their results back to the model.
    workflow.add_edge("search", "chatbot")
    graph = workflow.compile(checkpointer=memory)

    def formatter(state: dict | None):
        """Reduce the final graph state to the text of its last message."""
        if state is None or not state["messages"]:
            return "No Messages"
        return state["messages"][-1].content

    app = graph | formatter
    st.session_state.app = app
# Module-level executor over `tools`; the tool node built inside
# initialize_app invokes it once per requested tool call.
tool_executor = ToolExecutor(tools)
def should_continue(state):
    """Pick the routing key for the step after the model has answered.

    Returns "continue" when the newest message carries a "tool_calls" entry
    in its additional_kwargs, and "end" otherwise.
    """
    newest = state["messages"][-1]
    wants_tools = "tool_calls" in newest.additional_kwargs
    return "continue" if wants_tools else "end"