13
13
import redis
14
14
import uvicorn
15
15
from dotenv import load_dotenv
16
- from fastapi import FastAPI
16
+ from fastapi import Depends , FastAPI
17
17
from fastapi .middleware .cors import CORSMiddleware
18
18
19
19
# from fastapi.responses import HTMLResponse
39
39
db_password = os .getenv ("DB_PASSWORD" )
40
40
41
41
GPT_MODEL = "gpt-4"
42
+ GPT_TEMPERATURE = 0.3
42
43
43
44
44
45
class DecimalEncoder (json .JSONEncoder ):
@@ -56,6 +57,7 @@ class Conversation:
56
57
"""
57
58
存储对话信息
58
59
"""
60
+
59
61
def __init__ (self ):
60
62
self .conversation_history : List [Dict ] = []
61
63
@@ -64,6 +66,20 @@ def add_message(self, role, content):
64
66
self .conversation_history .append (message )
65
67
66
68
69
+ redis_pool = redis .ConnectionPool (host = 'localhost' , port = 6379 , db = 1 )
70
+ r = redis .Redis (connection_pool = redis_pool )
71
+
72
+
73
+ def get_redis ():
74
+ """获取redis db=1的连接,以待刷新对话缓存"""
75
+ try :
76
+ yield r
77
+ except Exception as e :
78
+ print ("error: " , e )
79
+ # finally:
80
+ # r.close()
81
+
82
+
67
83
app = FastAPI ()
68
84
69
85
# 添加 CORS 中间件,允许跨域
@@ -76,6 +92,15 @@ def add_message(self, role, content):
76
92
)
77
93
78
94
95
+ @app .get ("/refresh" )
96
+ def refresh_page (r : redis .Redis = Depends (get_redis )):
97
+ """刷新页面,清空redis缓存"""
98
+ r .flushdb ()
99
+ # redis_pool = redis.ConnectionPool(host='localhost', port=6379, db=1)
100
+ # r = redis.Redis(connection_pool=redis_pool)
101
+ return {"message" : "Redis db=1 cache cleared!" }
102
+
103
+
79
104
@app .get ("/" )
80
105
def read_root ():
81
106
return {"Hello" : "World" }
@@ -101,7 +126,7 @@ async def request(val: List[dict[str, str]], call_function: bool) \
101
126
"model" : GPT_MODEL ,
102
127
"messages" : val ,
103
128
"max_tokens" : 3000 ,
104
- "temperature" : 0.5 ,
129
+ "temperature" : GPT_TEMPERATURE ,
105
130
"top_p" : 1 ,
106
131
"n" : 1 ,
107
132
"stream" : True ,
@@ -126,7 +151,7 @@ async def request(val: List[dict[str, str]], call_function: bool) \
126
151
# ]
127
152
async with AsyncClient () as client :
128
153
async with client .stream (
129
- "POST" , url , headers = headers , json = params , timeout = 60
154
+ "POST" , url , headers = headers , json = params , timeout = 60
130
155
) as response :
131
156
async for line in response .aiter_lines ():
132
157
if not line .strip ():
@@ -222,15 +247,13 @@ async def event_generator():
222
247
223
248
# 将执行结果存入redis,在本async方法中,无法直接返回执行结果,因为返回的是一个异步生成器
224
249
# 因为对应前端方法,也不适合需要传参的asyncio的Future或者Queue
225
- redis_conn = redis .Redis (host = 'localhost' , port = 6379 , db = 0 )
226
250
if "Value" in sql_result or "GasFee" in sql_result or "GasPrice" \
227
251
in sql_result :
228
252
# 处理Decimal类型
229
- redis_conn .set ('sql_executed_result' ,
230
- json .dumps (executed_result , cls = DecimalEncoder ))
253
+ r .set ('sql_executed_result' ,
254
+ json .dumps (executed_result , cls = DecimalEncoder ))
231
255
else :
232
- redis_conn .set ('sql_executed_result' , json .dumps (executed_result ))
233
- redis_conn .close ()
256
+ r .set ('sql_executed_result' , json .dumps (executed_result ))
234
257
235
258
return EventSourceResponse (event_generator ())
236
259
@@ -251,8 +274,9 @@ async def multi_chat_stream(input: Optional[str] = None) \
251
274
question = input
252
275
253
276
base_prompt = f"""你是一个资深客服,你的工作是判断用户是否提出了表述清楚、细节明确的问题。
254
- 如果用户问题不清楚,你需要逐步引导用户补充细节,直到用户提出了清晰的问题。
255
- 如果用户问题清晰,你需要复述概括用户需求,然后调用查询数据库函数查询数据库,返回SQL语句。
277
+ 直到用户问出了清晰明确的问题为止,你需要一直逐步引导用户补充细节。
278
+ 在用户问出清晰明确的问题之前,不能以“您的问题已经足够清楚:”开头来回复。
279
+ 如果用户问题清晰,你需要复述概括用户需求,然后提醒用户可以直接点击查询按钮,也可以继续聊天。
256
280
你的数据库是一个以太坊的交易数据表(mysql),表名是 `eth20230701`,具有以下字段:
257
281
- `Index`:int(11)
258
282
- `TxHash`:varchar(64)
@@ -274,7 +298,10 @@ async def multi_chat_stream(input: Optional[str] = None) \
274
298
3.一个表述清楚、细节明确的问题,应该在时间上是具体的,在需求上也是可被理解转述的。如果不满足这两点,就需要引导用户补充细节。
275
299
4.引导用户补充细节时,针对用户提问中模糊的部分进行引导提问,而不是直接要求用户提供完整的问题。
276
300
5.在用户补充细节后或者问题本身足够清楚,你需要根据多轮对话内容,转述概括用户需求,
277
- 然后询问用户是否用此信息查询数据库还是要继续对话增添信息。在这种情况下,你的回答必须以“您的问题已经足够清楚”开头。
301
+ 然后询问用户是否用此信息查询数据库还是要继续对话增添信息。在这种情况下,你的回答必须以“您的问题已经足够清楚:”开头。
302
+ 6.以“您的问题已经足够清楚:”开头回复意味着你认为用户的问题已经足够清楚,可以直接查询数据库了。
303
+ 7.在“您的问题已经足够清楚:”和“。”之间,转述概括用户问题。如果问题没有说完,不要出现句号。
304
+ 8.在句号后面,提醒用户可以直接用此信息查询数据库,也可以继续对话增添信息。
278
305
"""
279
306
280
307
# 判断是不是第一次问
@@ -286,14 +313,28 @@ async def multi_chat_stream(input: Optional[str] = None) \
286
313
# },
287
314
# ],
288
315
# }
289
- message = {"default" : [{"role" : "system" , "content" : base_prompt }]}
316
+ result_tmp = []
317
+ try :
318
+ length = int (r .get ("len" ))
319
+ except Exception as e :
320
+ print ("First round! Outputs: " , e )
321
+ length = 0
322
+ for i in range (length ):
323
+ # Get the JSON string from Redis and convert it back to a dictionary
324
+ item = json .loads (r .get (f"conversation:{ i } " ))
325
+ result_tmp .append (item )
326
+ message = {"default" : result_tmp }
327
+ # print(message)
328
+ if not message ["default" ]:
329
+ message = {"default" : [{"role" : "system" , "content" : base_prompt }]}
290
330
if question is not None :
291
331
message ["default" ].append ({"role" : "user" , "content" : question })
292
332
else :
293
333
1 # todo
294
334
335
+ print (message )
336
+
295
337
chat_msg = defaultdict (str )
296
- print (chat_msg )
297
338
298
339
async def event_generator ():
299
340
"""生成事件,获取流信息"""
@@ -324,10 +365,13 @@ async def event_generator():
324
365
# 判断函数调用
325
366
# else:
326
367
message ["default" ].append ({"role" : "assistant" , "content" : answer })
327
- redis_conn = redis .Redis (host = 'localhost' , port = 6379 , db = 1 )
328
- for item in message ["default" ]:
329
- redis_conn .hset ("rhash" , item ["role" ], item ["content" ])
330
- redis_conn .close ()
368
+ # todo,格式
369
+ r .set ("len" , len (message ["default" ]))
370
+ for i , item in enumerate (message ["default" ]):
371
+ # Convert the dictionary to a JSON string and store it in Redis
372
+ r .set (f"conversation:{ i } " , json .dumps (item ))
373
+ # for item in message["default"]: # todo,可能重复存入之前对话信息
374
+ # r.hset("rhash", item["role"], item["content"])
331
375
if "您的问题已经足够清楚" in answer :
332
376
# todo
333
377
# 调用查询数据库函数
@@ -343,9 +387,11 @@ async def event_generator():
343
387
# redis_conn = redis.Redis(host='localhost', port=6379, db=1)
344
388
# redis_conn.hset("rhash", "assistant", executed_result)
345
389
# redis_conn.close()
346
- pass
390
+ print ("清楚" )
391
+ r .set ("isClear" , "1" )
347
392
else :
348
- pass
393
+ print ("不清楚" )
394
+ r .set ("isClear" , "0" )
349
395
350
396
# if executed_result_dict["code"] == 200:
351
397
# executed_result = executed_result_dict["data"]
@@ -362,16 +408,21 @@ def get_multichat_result():
362
408
"""
363
409
前端通过获取多轮对话,从redis中返回
364
410
"""
365
- redis_conn = redis .Redis (host = 'localhost' , port = 6379 , db = 1 )
366
- result = redis_conn .hgetall ('rhash' )
367
- redis_conn .close ()
368
- new_list = []
369
- for k , v in result .items ():
370
- new_list .append ({k .decode (): v .decode ()})
371
- # todo,这里的decode是为了兼容redis的返回值,后面可以去掉
372
- final_dict = {"default" : new_list }
411
+ # result = r.hgetall('rhash')
412
+ # new_list = []
413
+ # for k, v in result.items():
414
+ # new_list.append({k.decode(): v.decode()})
415
+ # # todo,这里的decode是为了兼容redis的返回值,后面可以去掉
416
+ # final_dict = {"default": new_list, "isClear": r.get("isClear")}
417
+ result_tmp = []
418
+ length = int (r .get ("len" ))
419
+ for i in range (length ):
420
+ # Get the JSON string from Redis and convert it back to a dictionary
421
+ item = json .loads (r .get (f"conversation:{ i } " ))
422
+ result_tmp .append (item )
423
+ result = {"default" : result_tmp }
424
+ result_tmp .append ({"isClear" : r .get ("isClear" ).decode ()})
373
425
if result is not None :
374
- result = final_dict
375
426
print (result )
376
427
if isinstance (result , dict ):
377
428
# 正确结果
@@ -396,10 +447,8 @@ def get_executed_result():
396
447
"""
397
448
前端通过获取SQL执行结果,从redis中返回
398
449
"""
399
- redis_conn = redis .Redis (host = 'localhost' , port = 6379 , db = 0 )
400
- result = redis_conn .get ('sql_executed_result' )
401
- redis_conn .delete ('sql_executed_result' )
402
- redis_conn .close ()
450
+ result = r .get ('sql_executed_result' )
451
+ r .delete ('sql_executed_result' )
403
452
if result is not None :
404
453
result = json .loads (result )
405
454
print (result )
0 commit comments