# api.py
from typing import Union

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from langchain.prompts import load_prompt
from langchain.memory import ConversationBufferMemory

from model import GPTQModel
from config import HUMAN_PREFIX, AI_PREFIX, GENERATE_PARAMS

app = FastAPI()

# Allow cross-origin requests so browser front ends can reach the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Load the quantized model once at startup.
from config import AUTO_TYPE, MODEL_PARAMS

gptq = GPTQModel(AUTO_TYPE, **MODEL_PARAMS)


class GenerateParams(BaseModel):
    # Request body: the prompt text plus optional generation parameters;
    # defaults to config.GENERATE_PARAMS.
    prompt: str
    params: Union[dict, None] = GENERATE_PARAMS
@app.post("/generate/")
async def generate(item: GenerateParams):
return gptq(item.prompt, streaming=False, **item.params)
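
# A minimal client sketch for /generate/ (assumptions: the server listens on
# 127.0.0.1:8000 and the "params" keys are HF-style generation arguments, as
# in config.GENERATE_PARAMS). Run from another process:
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8000/generate/",
#       json={"prompt": "Hello", "params": {"max_length": 256}},
#   )
#   print(resp.json())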
@app.post("/streaming_generate/")
async def streaming_generate(item: GenerateParams):
print(item)
return EventSourceResponse(
gptq(item.prompt, streaming=True, **item.params), media_type="text/event-stream"
)
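
# A consumer sketch for the SSE stream (assumption: each chunk yielded by the
# model arrives as a "data: ..." line, which is how EventSourceResponse frames
# plain strings). Uses httpx's streaming API:
#
#   import httpx
#   with httpx.stream(
#       "POST",
#       "http://127.0.0.1:8000/streaming_generate/",
#       json={"prompt": "Hello", "params": {"max_length": 256}},
#       timeout=None,
#   ) as resp:
#       for line in resp.iter_lines():
#           if line.startswith("data:"):
#               print(line[5:].strip())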
@app.post("/chat/")
async def chat(history: list[list[str]]):
# 构建prompt
memory = ConversationBufferMemory(human_prefix=HUMAN_PREFIX, ai_prefix=AI_PREFIX)
for human_text, ai_text in history[-10:-1]:
memory.save_context({'input':human_text}, {'output':ai_text})
history_text = memory.buffer
template = load_prompt("prompts/conversation.json", )
prompt = template.format(human_prefix=HUMAN_PREFIX, ai_prefix=AI_PREFIX, history=history_text, input=history[-1][0])
# 构建生成参数
params = dict(min_length=0, max_length=2048, num_beams=10, temperature=0.1, top_p=0.75, top_k=40)
return EventSourceResponse(
gptq(prompt, streaming=True, **params), media_type="text/event-stream"
)
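
# The expected shape of "history" (an assumption inferred from the loop above):
# a list of [user, assistant] pairs where the final pair carries the new user
# message and a placeholder reply, e.g.
#
#   [
#       ["Hi", "Hello! How can I help?"],
#       ["What is GPTQ?", ""],
#   ]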
@app.post("/embed/")
async def embed(item: GenerateParams):
return gptq.embed(item.prompt)
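

# Local entry point, a sketch rather than part of the original file: the host
# and port are assumptions, and in practice the app would typically be started
# with `uvicorn api:app` instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)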