Skip to content

Commit a5d5b6b

Browse files
authored
Add files via upload
1 parent 974031d commit a5d5b6b

File tree

4 files changed

+288
-0
lines changed

4 files changed

+288
-0
lines changed

configs.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
## variable value
faiss_key = False       # True once a FAISS index has been built for the session
vector_store = None     # default (empty) vector store handle


## model address
embedding_model_address = "" ## "shibing624/text2vec-base-chinese"

model.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
model deploy : FastChat
3+
run:
4+
1.python -m fastchat.serve.controller
5+
2.python -m fastchat.serve.model_worker --model-path ./chatglm2-6b --num-gpus 2 --host=0.0.0.0 --port=21002
6+
7+
calling interface : requests.post
8+
"""
9+
10+
import requests
11+
12+
13+
def get_response(text, url="http://*****:21002/worker_generate_stream"):
    """Send *text* to the FastChat model worker and return the streaming response.

    Parameters
    ----------
    text : str | list
        Prompt (or conversation history) forwarded verbatim as the worker's
        "prompt" field.
    url : str, optional
        Worker endpoint.  Parameterized so the host is no longer hard-coded
        inside the function body; the default preserves the original target.

    Returns
    -------
    requests.Response
        Opened with ``stream=True``; callers iterate ``iter_lines(...)`` over
        NUL-delimited JSON chunks (see web.py).
    """
    headers = {"Content-Type": "application/json"}
    pload = {
        "model": "chatglm2-6b",
        "prompt": text,
        "stop": "###",
        "max_new_tokens": 8000,
    }
    print("pload", pload)
    # stream=True lets the caller render tokens incrementally as they arrive.
    response = requests.post(url, headers=headers, json=pload, stream=True)
    return response

split.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
from typing import List
3+
4+
import re
5+
from langchain.document_loaders import UnstructuredFileLoader
6+
from langchain.text_splitter import CharacterTextSplitter
7+
8+
9+
class ChineseTextSplitter(CharacterTextSplitter):
    """Sentence-level splitter for Chinese (and mixed) text.

    Splits on Chinese/full-width/ASCII sentence terminators, keeping up to two
    trailing closing quotes attached to the sentence they end.
    """

    # Sentence boundary: a terminator optionally followed by closing quotes,
    # or a zero-width position before opening quotes / end of text.
    # Compiled once at class level instead of on every split_text() call.
    _SENT_SEP = re.compile(
        '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')

    def __init__(self, pdf: bool = False, **kwargs):
        """``pdf=True`` enables PDF-specific whitespace normalization first."""
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        """Split *text* into sentences; punctuation stays with its sentence."""
        if self.pdf:
            # Collapse runs of 3+ newlines, then map every whitespace char to
            # a single space.  Raw string fixes the original's invalid '\s'
            # escape (DeprecationWarning on modern Python, same behavior).
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r"\s", " ", text)
            # NOTE: the original also ran text.replace("\n\n", "") here — a
            # no-op after every newline was replaced by a space; dropped.
        sent_list: List[str] = []
        for ele in self._SENT_SEP.split(text):
            if self._SENT_SEP.match(ele) and sent_list:
                # Separator fragment: re-attach it to the preceding sentence.
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list
28+
29+
30+
31+
def load_file(filepath):
    """Load *filepath* with UnstructuredFileLoader and split it into documents.

    - ``.md``: loaded element-wise, no sentence splitting.
    - ``.pdf``: loaded whole, then split with PDF whitespace cleanup.
    - anything else: loaded element-wise and split sentence-wise.

    Extension matching is case-insensitive (generalizes the original, which
    only matched lowercase suffixes, so ``report.PDF`` fell into the generic
    branch).

    Returns ``(loader, docs)``.
    """
    print("filepath:", filepath)
    lowered = filepath.lower()
    if lowered.endswith(".md"):
        loader = UnstructuredFileLoader(filepath, mode="elements")
        docs = loader.load()
    elif lowered.endswith(".pdf"):
        loader = UnstructuredFileLoader(filepath)
        # Keyword arg used consistently with the else-branch call below.
        docs = loader.load_and_split(text_splitter=ChineseTextSplitter(pdf=True))
    else:
        loader = UnstructuredFileLoader(filepath, mode="elements")
        docs = loader.load_and_split(text_splitter=ChineseTextSplitter(pdf=False))
    return loader, docs

web.py

+210
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import json
2+
import time
3+
from tempfile import NamedTemporaryFile
4+
import os
5+
6+
7+
import streamlit as st
8+
from langchain.embeddings import HuggingFaceEmbeddings
9+
from langchain.vectorstores import FAISS
10+
11+
import configs
12+
from model import get_response
13+
from split import load_file
14+
15+
16+
# langchain embedding model used to vectorize uploaded documents for FAISS
embedding = HuggingFaceEmbeddings(model_name=configs.embedding_model_address)


# Basic page chrome; must run before other st.* calls render anything.
st.set_page_config(page_title="LLM-RAG-WEB")
st.title("LLM-RAG-WEB")
22+
23+
24+
25+
def clear_chat_history1():
    """Reset model-chat state when the "清空对话" button is clicked.

    BUG FIX: the original did ``del st.session_state.messages``, but no
    ``messages`` key is ever set in this file — the model-chat display list
    lives under ``messages1`` — so clicking the button raised an error.
    """
    del st.session_state.messages1
    st.session_state.history1 = [st.session_state.history1[0]]  # keep the initial record
    # NOTE(review): history1 is seeded with a Human/Assistant pair; keeping
    # only element [0] leaves an unpaired Human turn — confirm intent.
    # placeholder.empty()
29+
30+
def clear_chat_history2():
    """Reset file-chat state when the "清空对话" button is clicked.

    BUG FIX: the original did ``del st.session_state.messages``, a key that
    is never set — the file-chat display list lives under ``messages2``.
    """
    del st.session_state.messages2
    st.session_state.history2 = []
33+
34+
def init_chat_history1():
    """Show the model-chat greeting and replay any stored messages.

    Returns the mutable ``st.session_state.messages1`` list, creating it on
    first use.
    """
    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("您好,我是AI助手,很高兴为您服务🥰")

    if "messages1" not in st.session_state:
        st.session_state.messages1 = []
    else:
        for entry in st.session_state.messages1:
            icon = '🧑‍💻' if entry["role"] == "user" else '🤖'
            with st.chat_message(entry["role"], avatar=icon):
                st.markdown(entry["content"])

    return st.session_state.messages1
47+
48+
def init_chat_history2():
    """Show the file-chat greeting and replay any stored messages.

    Returns the mutable ``st.session_state.messages2`` list, creating it on
    first use.
    """
    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("您好,我是AI助手,很高兴为您服务🥰")

    if "messages2" not in st.session_state:
        st.session_state.messages2 = []
    else:
        for entry in st.session_state.messages2:
            icon = '🧑‍💻' if entry["role"] == "user" else '🤖'
            with st.chat_message(entry["role"], avatar=icon):
                st.markdown(entry["content"])

    return st.session_state.messages2
61+
62+
63+
# Seed the model-chat history with a fixed persona turn ("your nickname is
# Xiaojie") so the assistant keeps that identity across the conversation.
if 'history1' not in st.session_state:
    st.session_state.history1 = [["Human","你的昵称为小杰"],["Assistant","好的,小杰明白"]]

# File-chat history starts empty; it is rebuilt per question in main().
if 'history2' not in st.session_state:
    st.session_state.history2 = []

# NOTE(review): enter_pressed is initialized here but never read anywhere in
# this file — looks like dead state; confirm before removing.
if "enter_pressed" not in st.session_state:
    st.session_state.enter_pressed = False
74+
75+
76+
77+
def _stream_answer(history, sent_text, placeholder):
    """Stream the model's reply for *history*, rendering into *placeholder*.

    The worker echoes the submitted text at the start of each chunk's "text"
    field, so ``len(sent_text) + 1`` characters are sliced off before display.
    Returns the final reply text ("" if no chunk arrived).
    """
    start = time.time()
    results = get_response(history)
    response = ""
    for chunk in results.iter_lines(chunk_size=1024, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            response = json.loads(chunk.decode("utf-8"))["text"]
            placeholder.markdown(response[(len(sent_text) + 1):])
    cost = time.time() - start
    reply = response[(len(sent_text) + 1):]
    # Rough throughput figure; counts characters, not true model tokens.
    print(f"{len(reply) / cost}tokens/s")
    return reply


def main():
    """Streamlit entry point: plain model chat or file-grounded (RAG) chat."""
    # Lazily seed per-session state from module-level defaults in configs.
    if "vector_store" not in st.session_state:
        st.session_state.vector_store = configs.vector_store

    if "faiss_key" not in st.session_state:
        st.session_state.faiss_key = configs.faiss_key
        print("first faiss_key:", configs.faiss_key)

    # Sidebar: choose between model chat ("模型对话") and file chat ("文件对话").
    sidebar_selection = st.sidebar.selectbox("选择对话类型", ("模型对话", "文件对话"))

    if sidebar_selection == "模型对话":
        st.session_state.faiss_key = False
        messages1 = init_chat_history1()
        print("history1:", st.session_state.history1)
        if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
            with st.chat_message("user", avatar='🧑‍💻'):
                st.markdown(prompt)
            messages1.append({"role": "user", "content": prompt})
            print(f"[user] {prompt}", flush=True)
            with st.chat_message("assistant", avatar='🤖'):
                placeholder = st.empty()
                st.session_state.history1.append(["Human", prompt])
                st.session_state.history1.append(["Assistant", None])
                print("history1:", st.session_state.history1)
                reply = _stream_answer(st.session_state.history1, prompt, placeholder)
                st.session_state.history1[-1][1] = reply
            messages1.append({"role": "assistant", "content": reply})
            print(json.dumps(messages1, ensure_ascii=False), flush=True)

        st.button("清空对话", on_click=clear_chat_history1)

    elif sidebar_selection == "文件对话":
        uploaded_file = st.file_uploader("Choose a file")

        print("st.session_state.faiss_key:", st.session_state.faiss_key)
        if not st.session_state.faiss_key:
            # No index built yet: drop any stale file-chat display history.
            st.session_state.messages2 = []
        messages2 = init_chat_history2()

        if uploaded_file is not None:
            if not st.session_state.faiss_key:
                print("faiss_key1:", st.session_state.faiss_key)
                # Keep the original extension (e.g. .pdf) so load_file can
                # pick the right loader for the temp copy.
                temp_file = NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1])
                temp_file.write(uploaded_file.getvalue())
                file_path = temp_file.name
                with st.spinner('Reading file...'):
                    text_loader, texts = load_file(file_path)
                    st.success('Finished reading file.')
                temp_file.close()
                # Build and cache the per-session FAISS vector store.
                st.session_state.vector_store = FAISS.from_documents(texts, embedding)
                st.success('Finished save embedding.')
                st.session_state.faiss_key = True

            if st.session_state.faiss_key:
                print("faiss_key2:", st.session_state.faiss_key)

                if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
                    with st.chat_message("user", avatar='🧑‍💻'):
                        st.markdown(prompt)
                    messages2.append({"role": "user", "content": prompt})
                    print(f"[user] {prompt}", flush=True)
                    with st.chat_message("assistant", avatar='🤖'):
                        placeholder = st.empty()
                        # Ground the question in the single most similar chunk.
                        sim_result = st.session_state.vector_store.similarity_search(prompt)[0].page_content
                        new_prompt = f"""请根据下面单引号内信息简短回答:{prompt}? '{sim_result}' \n"""
                        # history2 is rebuilt from the persona seed each turn.
                        st.session_state.history2 = [["Human", "你的昵称为小杰"], ["Assistant", "好的,小杰明白"]]
                        st.session_state.history2.append(["Human", new_prompt])
                        st.session_state.history2.append(["Assistant", None])
                        print("history2:", st.session_state.history2)
                        reply = _stream_answer(st.session_state.history2, new_prompt, placeholder)
                        # BUG FIX: the original wrote the reply into
                        # history1[-1][1] here, corrupting the model-chat
                        # history instead of completing history2.
                        st.session_state.history2[-1][1] = reply
                    messages2.append({"role": "assistant", "content": reply})
                    print(json.dumps(messages2, ensure_ascii=False), flush=True)

            st.button("清空对话", on_click=clear_chat_history2)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)