Skip to content

Commit 0423e4c

Browse files
authored
Add support outlines
Signed-off-by: GitHub <[email protected]>
1 parent 991ecce commit 0423e4c

File tree

9 files changed

+649
-0
lines changed

9 files changed

+649
-0
lines changed

.vscode/launch.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@
2828
"LIBRARY_PATH": "${workspaceFolder}/go-llama:${workspaceFolder}/go-stable-diffusion/:${workspaceFolder}/gpt4all/gpt4all-bindings/golang/:${workspaceFolder}/go-gpt2:${workspaceFolder}/go-rwkv:${workspaceFolder}/whisper.cpp:${workspaceFolder}/go-bert:${workspaceFolder}/bloomz",
2929
"DEBUG": "true"
3030
}
31+
},
32+
{
33+
"name":"Launch outlines",
34+
"type": "python",
35+
"request": "launch",
36+
"program": "${workspaceFolder}/backend/python/backend_outlines/backend_outlines.py",
37+
"console": "integratedTerminal",
38+
"justMyCode": true,
39+
"env": {}
3140
}
3241
]
3342
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
.PONY: outlines
2+
outlines:
3+
@echo "Creating virtual environment..."
4+
@conda env create --name outlines --file outlines.yml
5+
@echo "Virtual environment created."
6+
7+
.PONY: run
8+
run:
9+
@echo "Running outlines..."
10+
bash run.sh
11+
@echo "outlines run."
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Creating a separate environment for the outlines project
2+
3+
```
4+
make outlines
5+
```
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""
2+
This is the extra gRPC server for outlines of LocalAI
3+
"""
4+
from concurrent import futures
5+
import argparse
6+
import os
7+
import signal
8+
import sys
9+
import time
10+
11+
import backend_pb2
12+
import backend_pb2_grpc
13+
14+
import grpc
15+
16+
import outlines.text.generate as generate
17+
import outlines.models as models
18+
19+
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
20+
21+
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
22+
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
23+
24+
# Implement the BackendServicer class with the service methods
25+
class BackendServicer(backend_pb2_grpc.BackendServicer):
26+
"""
27+
BackendServicer is the class that implements the gRPC service
28+
"""
29+
def Health(self, request, context):
30+
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
31+
32+
def LoadModel(self, request, context):
33+
try:
34+
# model should be name of the model, e.g. gpt2
35+
if request.Model == "":
36+
return backend_pb2.Result(success=False, message="Model name is empty")
37+
# It includes cache of the model, we do not need to add cache here.
38+
self.model = models.transformers(request.Model)
39+
except Exception as err:
40+
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
41+
return backend_pb2.Result(message="Model loaded successfully", success=True)
42+
43+
def Predict(self, request, context):
44+
try:
45+
output=generate.continuation(self.model, stop=[str(request.StopPrompts)])(str(request.Prompt))
46+
except Exception as err:
47+
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
48+
return backend_pb2.Result(message=bytes(output, encoding='utf-8'))
49+
50+
def serve(address):
51+
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
52+
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
53+
server.add_insecure_port(address)
54+
server.start()
55+
print("Server started. Listening on: " + address, file=sys.stderr)
56+
57+
# Define the signal handler function
58+
def signal_handler(sig, frame):
59+
print("Received termination signal. Shutting down...")
60+
server.stop(0)
61+
sys.exit(0)
62+
63+
# Set the signal handlers for SIGINT and SIGTERM
64+
signal.signal(signal.SIGINT, signal_handler)
65+
signal.signal(signal.SIGTERM, signal_handler)
66+
67+
try:
68+
while True:
69+
time.sleep(_ONE_DAY_IN_SECONDS)
70+
except KeyboardInterrupt:
71+
server.stop(0)
72+
73+
if __name__ == "__main__":
74+
parser = argparse.ArgumentParser(description="Run the gRPC server.")
75+
parser.add_argument(
76+
"--addr", default="localhost:50051", help="The address to bind the server to."
77+
)
78+
args = parser.parse_args()
79+
80+
serve(args.addr)

backend/python/backend_outlines/backend_pb2.py

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)