5
5
# import traceback
6
6
import json
7
7
import os
8
-
9
- # import pprint
10
8
from collections import defaultdict
11
- from typing import AsyncGenerator , Dict , List , Optional
9
+ from typing import AsyncGenerator , List , Optional
12
10
13
11
import redis
14
12
import uvicorn
25
23
# from starlette.middleware.sessions import SessionMiddleware
26
24
from starlette .responses import JSONResponse
27
25
28
- import dbtool
26
+ from utils import dbtool
29
27
30
28
# from setting import cretKey, uuid_str, dumps
31
29
@@ -53,17 +51,17 @@ class ChatInput(BaseModel):
53
51
input : str
54
52
55
53
56
- class Conversation :
57
- """
58
- 存储对话信息
59
- """
60
-
61
- def __init__ (self ):
62
- self .conversation_history : List [Dict ] = []
63
-
64
- def add_message (self , role , content ):
65
- message = {"role" : role , "content" : content }
66
- self .conversation_history .append (message )
54
+ # class Conversation:
55
+ # """
56
+ # 存储对话信息
57
+ # """
58
+ #
59
+ # def __init__(self):
60
+ # self.conversation_history: List[Dict] = []
61
+ #
62
+ # def add_message(self, role, content):
63
+ # message = {"role": role, "content": content}
64
+ # self.conversation_history.append(message)
67
65
68
66
69
67
redis_pool = redis .ConnectionPool (host = 'localhost' , port = 6379 , db = 1 )
@@ -103,12 +101,12 @@ def refresh_page(r: redis.Redis = Depends(get_redis)):
103
101
104
102
@app .get ("/" )
105
103
def read_root ():
106
- return {"Hello " : "World" }
104
+ return {"message " : "Hello World" }
107
105
108
106
109
- class Chatbot :
110
- def __init__ (self ):
111
- self .conversation = Conversation ()
107
+ # class Chatbot:
108
+ # def __init__(self):
109
+ # self.conversation = Conversation()
112
110
113
111
114
112
async def request (val : List [dict [str , str ]], call_function : bool ) \
@@ -221,7 +219,6 @@ async def chat_stream(input: Optional[str] = None) -> EventSourceResponse:
221
219
222
220
message = [{"role" : "user" , "content" : base_prompt }]
223
221
chat_msg = defaultdict (str )
224
- print (chat_msg )
225
222
226
223
async def event_generator ():
227
224
"""生成事件,获取流信息"""
@@ -305,23 +302,14 @@ async def multi_chat_stream(input: Optional[str] = None) \
305
302
9.在句号后面,提醒用户可以直接用此信息查询数据库,也可以继续对话增添或修改信息。
306
303
"""
307
304
308
- # 判断是不是第一次问
309
- # response = {
310
- # "default": [
311
- # {
312
- # "role": "system",
313
- # "content": system_prompt,
314
- # },
315
- # ],
316
- # }
317
305
result_tmp = []
318
306
try :
319
307
length = int (r .get ("len" ))
320
308
except Exception as e :
321
309
print ("First round! Outputs: " , e )
322
310
length = 0
323
311
for i in range (length ):
324
- # Get the JSON string from Redis and convert it back to a dictionary
312
+ # 将字符串转化为Json
325
313
item = json .loads (r .get (f"conversation:{ i } " ))
326
314
result_tmp .append (item )
327
315
message = {"default" : result_tmp }
@@ -330,8 +318,7 @@ async def multi_chat_stream(input: Optional[str] = None) \
330
318
message = {"default" : [{"role" : "system" , "content" : base_prompt }]}
331
319
if question is not None :
332
320
message ["default" ].append ({"role" : "user" , "content" : question })
333
- else :
334
- 1 # todo
321
+ # 输入为空的判定已经在前端实现
335
322
336
323
print (message )
337
324
@@ -345,55 +332,29 @@ async def event_generator():
345
332
chat_msg ["role" ] = delta .get ("role" )
346
333
if delta .get ("content" ):
347
334
chat_msg ["content" ] += delta .get ("content" )
348
- # if "function_call" in delta:
349
- # pprint.pprint(delta.get("function_call"))
350
- # todo
351
- # function_call_res += delta.get("function_call").
352
- # get("data")
353
- # redis_conn = redis.Redis(host='localhost', port=6379,
354
- # db=1)
355
- # redis_conn.delete('rhash')
356
- # redis_conn.close()
357
- # break
358
335
yield dict (id = None , event = None , data = json .dumps (chat_msg ))
359
336
except Exception as e :
360
337
print ("error: " , e )
361
338
362
339
answer = chat_msg ["content" ]
363
340
print (answer )
364
- # if function_call_res:
365
- # await chat_stream("222")
366
- # 判断函数调用
367
- # else:
368
341
message ["default" ].append ({"role" : "assistant" , "content" : answer })
369
342
# todo,格式
370
343
r .set ("len" , len (message ["default" ]))
371
344
for i , item in enumerate (message ["default" ]):
372
- # Convert the dictionary to a JSON string and store it in Redis
345
+ # Json转为字符串,然后才能放进redis
373
346
r .set (f"conversation:{ i } " , json .dumps (item ))
374
347
# for item in message["default"]: # todo,可能重复存入之前对话信息
375
348
# r.hset("rhash", item["role"], item["content"])
376
349
if "您的问题已经足够清楚" in answer :
377
- # todo
378
- # 调用查询数据库函数
379
- # executed_result_dict = get_sql_execute_result(sql_result)
380
- # print(executed_result_dict)
381
- # if executed_result_dict["code"] == 200:
382
- # executed_result = executed_result_dict["data"]
383
- # draw_charts(executed_result)
384
- # else:
385
- # executed_result = executed_result_dict
386
- # message["default"].append({"role": "assistant", "content":
387
- # executed_result})
388
- # redis_conn = redis.Redis(host='localhost', port=6379, db=1)
389
- # redis_conn.hset("rhash", "assistant", executed_result)
390
- # redis_conn.close()
350
+ # 是否调用查询数据库函数,选择权交还给前端用户
391
351
print ("清楚" )
392
352
r .set ("isClear" , "1" )
393
353
else :
394
354
print ("不清楚" )
395
355
r .set ("isClear" , "0" )
396
356
357
+ # todo,画图
397
358
# if executed_result_dict["code"] == 200:
398
359
# executed_result = executed_result_dict["data"]
399
360
#
@@ -409,16 +370,11 @@ def get_multichat_result():
409
370
"""
410
371
前端通过获取多轮对话,从redis中返回
411
372
"""
412
- # result = r.hgetall('rhash')
413
- # new_list = []
414
- # for k, v in result.items():
415
- # new_list.append({k.decode(): v.decode()})
416
- # # todo,这里的decode是为了兼容redis的返回值,后面可以去掉
417
- # final_dict = {"default": new_list, "isClear": r.get("isClear")}
373
+ # 不要用r.hgetall("rhash")。哈希表此处不适合存储对话信息
418
374
result_tmp = []
419
375
length = int (r .get ("len" ))
420
376
for i in range (length ):
421
- # Get the JSON string from Redis and convert it back to a dictionary
377
+ # 从字符串转回Json
422
378
item = json .loads (r .get (f"conversation:{ i } " ))
423
379
result_tmp .append (item )
424
380
result = {"default" : result_tmp }
0 commit comments