Skip to content

Commit c9d42c6

Browse files
authored
ENH: add Gradio ChatInterface chatbot to example (#208)
1 parent 1fe3bb2 commit c9d42c6

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed

examples/gradio_chatinterface.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from typing import Dict, List
2+
3+
import gradio as gr
4+
5+
from xinference.client import Client
6+
7+
def flatten(matrix: List[List[str]]) -> List[str]:
    """Flatten a list of [user, assistant] message pairs into one flat list.

    Gradio's ChatInterface supplies history as pairs; downstream code wants
    a single alternating user/assistant sequence.
    """
    return [item for row in matrix for item in row]


def to_chat(lst: List[str]) -> List[Dict[str, str]]:
    """Convert a flat, alternating message list into chat-history dicts.

    Even indices are treated as user turns, odd indices as assistant turns,
    matching the order produced by `flatten`.
    """
    return [
        {
            "role": "assistant" if i % 2 == 1 else "user",
            "content": content,
        }
        for i, content in enumerate(lst)
    ]


if __name__ == "__main__":
    import argparse
    import textwrap

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent(
            """\
            instructions to run:
                1. Install Xinference and Llama-cpp-python
                2. Run 'xinference --host "localhost" --port 9997' in terminal
                3. Run this python file in new terminal window

                e.g. (feel free to copy)
                python gradio_chatinterface.py \\
                --endpoint http://localhost:9997 \\
                --model_name vicuna-v1.3 \\
                --model_size_in_billions 7 \\
                --model_format ggmlv3 \\
                --quantization q2_K

                If you decide to change the port number in step 2,
                please also change the endpoint in the arguments
            """
        ),
    )

    parser.add_argument(
        "--endpoint", type=str, required=True, help="Xinference endpoint, required"
    )
    parser.add_argument(
        "--model_name", type=str, required=True, help="Name of the model, required"
    )
    parser.add_argument(
        "--model_size_in_billions",
        type=int,
        required=False,
        help="Size of the model in billions",
    )
    parser.add_argument(
        "--model_format",
        type=str,
        required=False,
        help="Format of the model",
    )
    parser.add_argument(
        "--quantization", type=str, required=False, help="Quantization of the model"
    )

    args = parser.parse_args()

    endpoint = args.endpoint
    model_name = args.model_name
    model_size_in_billions = args.model_size_in_billions
    model_format = args.model_format
    quantization = args.quantization

    print(f"Xinference endpoint: {endpoint}")
    print(f"Model Name: {model_name}")
    print(f"Model Size (in billions): {model_size_in_billions}")
    print(f"Model Format: {model_format}")
    print(f"Quantization: {quantization}")

    # Launch the requested model on the Xinference server and get a handle
    # for issuing chat requests against it.
    client = Client(endpoint)
    model_uid = client.launch_model(
        model_name,
        model_size_in_billions=model_size_in_billions,
        model_format=model_format,
        quantization=quantization,
        n_ctx=2048,
    )
    model = client.get_model(model_uid)

    def generate_wrapper(message: str, history: List[List[str]]) -> str:
        """Bridge Gradio's ChatInterface callback to the Xinference chat API.

        `history` arrives as [[user, assistant], ...] pairs; flatten it and
        convert to the dict-based history Xinference expects, then return
        the assistant's reply text from the non-streaming response.
        """
        output = model.chat(
            prompt=message,
            chat_history=to_chat(flatten(history)),
            generate_config={"max_tokens": 512, "stream": False},
        )
        return output["choices"][0]["message"]["content"]

    demo = gr.ChatInterface(
        fn=generate_wrapper,
        examples=[
            "Show me a two sentence horror story with a plot twist",
            "Generate a Haiku poem using trigonometry as the central theme",
            "Write three sentences of scholarly description regarding a supernatural beast",
            "Prove there does not exist a largest integer",
        ],
        title="Xinference Chat Bot",
    )
    demo.launch()

0 commit comments

Comments
 (0)