Skip to content

Commit 6b353da

Browse files
committed
v0.2.0.2 bug fix
1 parent 2ee7a6f commit 6b353da

File tree

6 files changed

+41
-68
lines changed

6 files changed

+41
-68
lines changed

main.py renamed to src/main.py

Lines changed: 24 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
# import traceback
66
import json
77
import os
8-
9-
# import pprint
108
from collections import defaultdict
11-
from typing import AsyncGenerator, Dict, List, Optional
9+
from typing import AsyncGenerator, List, Optional
1210

1311
import redis
1412
import uvicorn
@@ -25,7 +23,7 @@
2523
# from starlette.middleware.sessions import SessionMiddleware
2624
from starlette.responses import JSONResponse
2725

28-
import dbtool
26+
from utils import dbtool
2927

3028
# from setting import cretKey, uuid_str, dumps
3129

@@ -53,17 +51,17 @@ class ChatInput(BaseModel):
5351
input: str
5452

5553

56-
class Conversation:
57-
"""
58-
存储对话信息
59-
"""
60-
61-
def __init__(self):
62-
self.conversation_history: List[Dict] = []
63-
64-
def add_message(self, role, content):
65-
message = {"role": role, "content": content}
66-
self.conversation_history.append(message)
54+
# class Conversation:
55+
# """
56+
# 存储对话信息
57+
# """
58+
#
59+
# def __init__(self):
60+
# self.conversation_history: List[Dict] = []
61+
#
62+
# def add_message(self, role, content):
63+
# message = {"role": role, "content": content}
64+
# self.conversation_history.append(message)
6765

6866

6967
redis_pool = redis.ConnectionPool(host='localhost', port=6379, db=1)
@@ -103,12 +101,12 @@ def refresh_page(r: redis.Redis = Depends(get_redis)):
103101

104102
@app.get("/")
105103
def read_root():
106-
return {"Hello": "World"}
104+
return {"message": "Hello World"}
107105

108106

109-
class Chatbot:
110-
def __init__(self):
111-
self.conversation = Conversation()
107+
# class Chatbot:
108+
# def __init__(self):
109+
# self.conversation = Conversation()
112110

113111

114112
async def request(val: List[dict[str, str]], call_function: bool) \
@@ -221,7 +219,6 @@ async def chat_stream(input: Optional[str] = None) -> EventSourceResponse:
221219

222220
message = [{"role": "user", "content": base_prompt}]
223221
chat_msg = defaultdict(str)
224-
print(chat_msg)
225222

226223
async def event_generator():
227224
"""生成事件,获取流信息"""
@@ -305,23 +302,14 @@ async def multi_chat_stream(input: Optional[str] = None) \
305302
9.在句号后面,提醒用户可以直接用此信息查询数据库,也可以继续对话增添或修改信息。
306303
"""
307304

308-
# 判断是不是第一次问
309-
# response = {
310-
# "default": [
311-
# {
312-
# "role": "system",
313-
# "content": system_prompt,
314-
# },
315-
# ],
316-
# }
317305
result_tmp = []
318306
try:
319307
length = int(r.get("len"))
320308
except Exception as e:
321309
print("First round! Outputs: ", e)
322310
length = 0
323311
for i in range(length):
324-
# Get the JSON string from Redis and convert it back to a dictionary
312+
# 将字符串转化为Json
325313
item = json.loads(r.get(f"conversation:{i}"))
326314
result_tmp.append(item)
327315
message = {"default": result_tmp}
@@ -330,8 +318,7 @@ async def multi_chat_stream(input: Optional[str] = None) \
330318
message = {"default": [{"role": "system", "content": base_prompt}]}
331319
if question is not None:
332320
message["default"].append({"role": "user", "content": question})
333-
else:
334-
1 # todo
321+
# 输入为空的判定已经在前端实现
335322

336323
print(message)
337324

@@ -345,55 +332,29 @@ async def event_generator():
345332
chat_msg["role"] = delta.get("role")
346333
if delta.get("content"):
347334
chat_msg["content"] += delta.get("content")
348-
# if "function_call" in delta:
349-
# pprint.pprint(delta.get("function_call"))
350-
# todo
351-
# function_call_res += delta.get("function_call").
352-
# get("data")
353-
# redis_conn = redis.Redis(host='localhost', port=6379,
354-
# db=1)
355-
# redis_conn.delete('rhash')
356-
# redis_conn.close()
357-
# break
358335
yield dict(id=None, event=None, data=json.dumps(chat_msg))
359336
except Exception as e:
360337
print("error: ", e)
361338

362339
answer = chat_msg["content"]
363340
print(answer)
364-
# if function_call_res:
365-
# await chat_stream("222")
366-
# 判断函数调用
367-
# else:
368341
message["default"].append({"role": "assistant", "content": answer})
369342
# todo,格式
370343
r.set("len", len(message["default"]))
371344
for i, item in enumerate(message["default"]):
372-
# Convert the dictionary to a JSON string and store it in Redis
345+
# Json转为字符串,然后才能放进redis
373346
r.set(f"conversation:{i}", json.dumps(item))
374347
# for item in message["default"]: # todo,可能重复存入之前对话信息
375348
# r.hset("rhash", item["role"], item["content"])
376349
if "您的问题已经足够清楚" in answer:
377-
# todo
378-
# 调用查询数据库函数
379-
# executed_result_dict = get_sql_execute_result(sql_result)
380-
# print(executed_result_dict)
381-
# if executed_result_dict["code"] == 200:
382-
# executed_result = executed_result_dict["data"]
383-
# draw_charts(executed_result)
384-
# else:
385-
# executed_result = executed_result_dict
386-
# message["default"].append({"role": "assistant", "content":
387-
# executed_result})
388-
# redis_conn = redis.Redis(host='localhost', port=6379, db=1)
389-
# redis_conn.hset("rhash", "assistant", executed_result)
390-
# redis_conn.close()
350+
# 是否调用查询数据库函数,选择权交还给前端用户
391351
print("清楚")
392352
r.set("isClear", "1")
393353
else:
394354
print("不清楚")
395355
r.set("isClear", "0")
396356

357+
# todo,画图
397358
# if executed_result_dict["code"] == 200:
398359
# executed_result = executed_result_dict["data"]
399360
#
@@ -409,16 +370,11 @@ def get_multichat_result():
409370
"""
410371
前端通过获取多轮对话,从redis中返回
411372
"""
412-
# result = r.hgetall('rhash')
413-
# new_list = []
414-
# for k, v in result.items():
415-
# new_list.append({k.decode(): v.decode()})
416-
# # todo,这里的decode是为了兼容redis的返回值,后面可以去掉
417-
# final_dict = {"default": new_list, "isClear": r.get("isClear")}
373+
# 不要用r.hgetall("rhash")。哈希表此处不适合存储对话信息
418374
result_tmp = []
419375
length = int(r.get("len"))
420376
for i in range(length):
421-
# Get the JSON string from Redis and convert it back to a dictionary
377+
# 从字符串转回Json
422378
item = json.loads(r.get(f"conversation:{i}"))
423379
result_tmp.append(item)
424380
result = {"default": result_tmp}

test/test_main.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import os
2+
import sys
3+
4+
from fastapi.testclient import TestClient
5+
6+
from src.main import app # 导入你的FastAPI实例
7+
8+
print(f"Current working directory: {os.getcwd()}")
9+
print(f"sys.path: {sys.path}")
10+
11+
client = TestClient(app)
12+
13+
14+
def test_read_main():
15+
response = client.get("/")
16+
assert response.status_code == 200
17+
assert response.json() == {"message": "Hello World"}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)