Skip to content

Commit 054555d

Browse files
committed
v0.2 multi-round chat added
1 parent 7f667b9 commit 054555d

File tree

1 file changed

+81
-32
lines changed

1 file changed

+81
-32
lines changed

main.py

Lines changed: 81 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import redis
1414
import uvicorn
1515
from dotenv import load_dotenv
16-
from fastapi import FastAPI
16+
from fastapi import Depends, FastAPI
1717
from fastapi.middleware.cors import CORSMiddleware
1818

1919
# from fastapi.responses import HTMLResponse
@@ -39,6 +39,7 @@
3939
db_password = os.getenv("DB_PASSWORD")
4040

4141
GPT_MODEL = "gpt-4"
42+
GPT_TEMPERATURE = 0.3
4243

4344

4445
class DecimalEncoder(json.JSONEncoder):
@@ -56,6 +57,7 @@ class Conversation:
5657
"""
5758
存储对话信息
5859
"""
60+
5961
def __init__(self):
6062
self.conversation_history: List[Dict] = []
6163

@@ -64,6 +66,20 @@ def add_message(self, role, content):
6466
self.conversation_history.append(message)
6567

6668

69+
redis_pool = redis.ConnectionPool(host='localhost', port=6379, db=1)
70+
r = redis.Redis(connection_pool=redis_pool)
71+
72+
73+
def get_redis():
74+
"""获取redis db=1的连接,以待刷新对话缓存"""
75+
try:
76+
yield r
77+
except Exception as e:
78+
print("error: ", e)
79+
# finally:
80+
# r.close()
81+
82+
6783
app = FastAPI()
6884

6985
# 添加 CORS 中间件,允许跨域
@@ -76,6 +92,15 @@ def add_message(self, role, content):
7692
)
7793

7894

95+
@app.get("/refresh")
96+
def refresh_page(r: redis.Redis = Depends(get_redis)):
97+
"""刷新页面,清空redis缓存"""
98+
r.flushdb()
99+
# redis_pool = redis.ConnectionPool(host='localhost', port=6379, db=1)
100+
# r = redis.Redis(connection_pool=redis_pool)
101+
return {"message": "Redis db=1 cache cleared!"}
102+
103+
79104
@app.get("/")
80105
def read_root():
81106
return {"Hello": "World"}
@@ -101,7 +126,7 @@ async def request(val: List[dict[str, str]], call_function: bool) \
101126
"model": GPT_MODEL,
102127
"messages": val,
103128
"max_tokens": 3000,
104-
"temperature": 0.5,
129+
"temperature": GPT_TEMPERATURE,
105130
"top_p": 1,
106131
"n": 1,
107132
"stream": True,
@@ -126,7 +151,7 @@ async def request(val: List[dict[str, str]], call_function: bool) \
126151
# ]
127152
async with AsyncClient() as client:
128153
async with client.stream(
129-
"POST", url, headers=headers, json=params, timeout=60
154+
"POST", url, headers=headers, json=params, timeout=60
130155
) as response:
131156
async for line in response.aiter_lines():
132157
if not line.strip():
@@ -222,15 +247,13 @@ async def event_generator():
222247

223248
# 将执行结果存入redis,在本async方法中,无法直接返回执行结果,因为返回的是一个异步生成器
224249
# 因为对应前端方法,也不适合需要传参的asyncio的Future或者Queue
225-
redis_conn = redis.Redis(host='localhost', port=6379, db=0)
226250
if "Value" in sql_result or "GasFee" in sql_result or "GasPrice" \
227251
in sql_result:
228252
# 处理Decimal类型
229-
redis_conn.set('sql_executed_result',
230-
json.dumps(executed_result, cls=DecimalEncoder))
253+
r.set('sql_executed_result',
254+
json.dumps(executed_result, cls=DecimalEncoder))
231255
else:
232-
redis_conn.set('sql_executed_result', json.dumps(executed_result))
233-
redis_conn.close()
256+
r.set('sql_executed_result', json.dumps(executed_result))
234257

235258
return EventSourceResponse(event_generator())
236259

@@ -251,8 +274,9 @@ async def multi_chat_stream(input: Optional[str] = None) \
251274
question = input
252275

253276
base_prompt = f"""你是一个资深客服,你的工作是判断用户是否提出了表述清楚、细节明确的问题。
254-
如果用户问题不清楚,你需要逐步引导用户补充细节,直到用户提出了清晰的问题。
255-
如果用户问题清晰,你需要复述概括用户需求,然后调用查询数据库函数查询数据库,返回SQL语句。
277+
直到用户问出了清晰明确的问题为止,你需要一直逐步引导用户补充细节。
278+
在用户问出清晰明确的问题之前,不能以“您的问题已经足够清楚:”开头来回复。
279+
如果用户问题清晰,你需要复述概括用户需求,然后提醒用户可以直接点击查询按钮,也可以继续聊天。
256280
你的数据库是一个以太坊的交易数据表(mysql),表名是 `eth20230701`,具有以下字段:
257281
- `Index`:int(11)
258282
- `TxHash`:varchar(64)
@@ -274,7 +298,10 @@ async def multi_chat_stream(input: Optional[str] = None) \
274298
3.一个表述清楚、细节明确的问题,应该在时间上是具体的,在需求上也是可被理解转述的。如果不满足这两点,就需要引导用户补充细节。
275299
4.引导用户补充细节时,针对用户提问中模糊的部分进行引导提问,而不是直接要求用户提供完整的问题。
276300
5.在用户补充细节后或者问题本身足够清楚,你需要根据多轮对话内容,转述概括用户需求,
277-
然后询问用户是否用此信息查询数据库还是要继续对话增添信息。在这种情况下,你的回答必须以“您的问题已经足够清楚”开头。
301+
然后询问用户是否用此信息查询数据库还是要继续对话增添信息。在这种情况下,你的回答必须以“您的问题已经足够清楚:”开头。
302+
6.以“您的问题已经足够清楚:”开头回复意味着你认为用户的问题已经足够清楚,可以直接查询数据库了。
303+
7.在“您的问题已经足够清楚:”和“。”之间,转述概括用户问题。如果问题没有说完,不要出现句号。
304+
8.在句号后面,提醒用户可以直接用此信息查询数据库,也可以继续对话增添信息。
278305
"""
279306

280307
# 判断是不是第一次问
@@ -286,14 +313,28 @@ async def multi_chat_stream(input: Optional[str] = None) \
286313
# },
287314
# ],
288315
# }
289-
message = {"default": [{"role": "system", "content": base_prompt}]}
316+
result_tmp = []
317+
try:
318+
length = int(r.get("len"))
319+
except Exception as e:
320+
print("First round! Outputs: ", e)
321+
length = 0
322+
for i in range(length):
323+
# Get the JSON string from Redis and convert it back to a dictionary
324+
item = json.loads(r.get(f"conversation:{i}"))
325+
result_tmp.append(item)
326+
message = {"default": result_tmp}
327+
# print(message)
328+
if not message["default"]:
329+
message = {"default": [{"role": "system", "content": base_prompt}]}
290330
if question is not None:
291331
message["default"].append({"role": "user", "content": question})
292332
else:
293333
1 # todo
294334

335+
print(message)
336+
295337
chat_msg = defaultdict(str)
296-
print(chat_msg)
297338

298339
async def event_generator():
299340
"""生成事件,获取流信息"""
@@ -324,10 +365,13 @@ async def event_generator():
324365
# 判断函数调用
325366
# else:
326367
message["default"].append({"role": "assistant", "content": answer})
327-
redis_conn = redis.Redis(host='localhost', port=6379, db=1)
328-
for item in message["default"]:
329-
redis_conn.hset("rhash", item["role"], item["content"])
330-
redis_conn.close()
368+
# todo,格式
369+
r.set("len", len(message["default"]))
370+
for i, item in enumerate(message["default"]):
371+
# Convert the dictionary to a JSON string and store it in Redis
372+
r.set(f"conversation:{i}", json.dumps(item))
373+
# for item in message["default"]: # todo,可能重复存入之前对话信息
374+
# r.hset("rhash", item["role"], item["content"])
331375
if "您的问题已经足够清楚" in answer:
332376
# todo
333377
# 调用查询数据库函数
@@ -343,9 +387,11 @@ async def event_generator():
343387
# redis_conn = redis.Redis(host='localhost', port=6379, db=1)
344388
# redis_conn.hset("rhash", "assistant", executed_result)
345389
# redis_conn.close()
346-
pass
390+
print("清楚")
391+
r.set("isClear", "1")
347392
else:
348-
pass
393+
print("不清楚")
394+
r.set("isClear", "0")
349395

350396
# if executed_result_dict["code"] == 200:
351397
# executed_result = executed_result_dict["data"]
@@ -362,16 +408,21 @@ def get_multichat_result():
362408
"""
363409
前端通过获取多轮对话,从redis中返回
364410
"""
365-
redis_conn = redis.Redis(host='localhost', port=6379, db=1)
366-
result = redis_conn.hgetall('rhash')
367-
redis_conn.close()
368-
new_list = []
369-
for k, v in result.items():
370-
new_list.append({k.decode(): v.decode()})
371-
# todo,这里的decode是为了兼容redis的返回值,后面可以去掉
372-
final_dict = {"default": new_list}
411+
# result = r.hgetall('rhash')
412+
# new_list = []
413+
# for k, v in result.items():
414+
# new_list.append({k.decode(): v.decode()})
415+
# # todo,这里的decode是为了兼容redis的返回值,后面可以去掉
416+
# final_dict = {"default": new_list, "isClear": r.get("isClear")}
417+
result_tmp = []
418+
length = int(r.get("len"))
419+
for i in range(length):
420+
# Get the JSON string from Redis and convert it back to a dictionary
421+
item = json.loads(r.get(f"conversation:{i}"))
422+
result_tmp.append(item)
423+
result = {"default": result_tmp}
424+
result_tmp.append({"isClear": r.get("isClear").decode()})
373425
if result is not None:
374-
result = final_dict
375426
print(result)
376427
if isinstance(result, dict):
377428
# 正确结果
@@ -396,10 +447,8 @@ def get_executed_result():
396447
"""
397448
前端通过获取SQL执行结果,从redis中返回
398449
"""
399-
redis_conn = redis.Redis(host='localhost', port=6379, db=0)
400-
result = redis_conn.get('sql_executed_result')
401-
redis_conn.delete('sql_executed_result')
402-
redis_conn.close()
450+
result = r.get('sql_executed_result')
451+
r.delete('sql_executed_result')
403452
if result is not None:
404453
result = json.loads(result)
405454
print(result)

0 commit comments

Comments
 (0)