Skip to content

feat(CTransformers): add support to CTransformers #1248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions extra/grpc/c_transformers/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Create a dedicated conda environment for the ctransformers gRPC backend.
#
# BUG FIX: ".PONY" was a typo for ".PHONY" (the target never declared phony).
# BUG FIX: every recipe line runs in its own shell, so a standalone
# ". activate" / ". deactivate" line never affected the pip commands that
# followed it -- the activate + installs are chained with && in one shell.
#
# NOTE(review): installing ctransformers from jllllll's cuBLAS wheel index
# appends "+cu117" to the version string, which reportedly breaks creating a
# model from file; --prefer-binary makes pip prefer the plain PyPI wheel.
.PHONY: ctransformers
ctransformers:
	@echo "Creating virtual environment..."
	@conda create -n ctransformers python=3.11 -y
	@echo "Virtual environment created."
	@echo "Installing dependencies..."
	@. activate ctransformers && \
		pip install grpcio==1.59.0 protobuf==4.24.4 && \
		echo "Installing ctransformers..." && \
		pip install ctransformers==0.2.27 --prefer-binary --extra-index-url=https://jllllll.github.io/ctransformers-cuBLAS-wheels/AVX2/cu117
5 changes: 5 additions & 0 deletions extra/grpc/c_transformers/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Creating a separate environment for ctransformers project

```bash
make ctransformers
```
61 changes: 61 additions & 0 deletions extra/grpc/c_transformers/backend_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

363 changes: 363 additions & 0 deletions extra/grpc/c_transformers/backend_pb2_grpc.py

Large diffs are not rendered by default.

108 changes: 108 additions & 0 deletions extra/grpc/c_transformers/c_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
This is the extra gRPC server of LocalAI
"""

from __future__ import annotations
from typing import List
from concurrent import futures
import time
import argparse
import signal
import sys
import os

import grpc
import backend_pb2
import backend_pb2_grpc

from ctransformers import AutoModelForCausalLM, AutoConfig, Config

# Adapted from https://github.com/marella/ctransformers/tree/main#supported-models
# License: MIT
# Adapted by AIsuko
class ModelType:
    """String identifiers for the model architectures ctransformers supports.

    Each attribute value is the ``model_type`` string expected by
    ``AutoModelForCausalLM.from_pretrained``; the attribute name lists the
    model families that share that architecture.
    """

    GPT = "gpt2"
    GPT_J_GPT4_ALL_J = "gptj"
    GPT_NEOX_STABLE_LM = "gpt_neox"
    FALCON = "falcon"
    LLaMA_LLaMA2 = "llama"
    MPT = "mpt"
    STAR_CODER_CHAT = "gpt_bigcode"
    DOLLY_V2 = "dolly-v2"
    REPLIT = "replit"

# Sleep interval for the main thread's keep-alive loop in serve(); the gRPC
# server itself runs on its worker threads.
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    Implements the LocalAI backend gRPC service on top of ctransformers.

    LoadModel stores the loaded model on ``self.model``; Predict and
    TokenizeString assume LoadModel succeeded beforehand.
    """

    def Health(self, request, context):
        """Liveness probe: always replies OK."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """Load a model file with ctransformers.

        Validates that ``request.Model`` exists on disk and that
        ``request.ModelType`` is one of the identifiers declared on
        ModelType, then loads it via AutoModelForCausalLM.
        """
        try:
            model_path = request.Model
            if not os.path.exists(model_path):
                return backend_pb2.Result(success=False, message=f"Model path {model_path} does not exist")
            model_type = request.ModelType
            # Compare only against the declared constants: ModelType.__dict__
            # also contains dunder entries (e.g. __module__ -> module name)
            # whose values must not be accepted as model types.
            supported_types = {
                value for name, value in ModelType.__dict__.items()
                if not name.startswith("_")
            }
            if model_type not in supported_types:
                return backend_pb2.Result(success=False, message=f"Model type {model_type} not supported")
            llm = AutoModelForCausalLM.from_pretrained(model_file=model_path, model_type=model_type)
            self.model = llm
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Predict(self, request, context):
        """Run one non-streaming generation for ``request.prompt``."""
        try:
            generated_text = self.model(request.prompt)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        # BUG FIX: the original called bytes(generated_text) with no encoding
        # (a TypeError for str) and passed encoding="utf-8" as a Result field
        # instead of to bytes(); encode the text itself.
        return backend_pb2.Result(message=bytes(generated_text, encoding="utf-8"))

    def PredictStream(self, request, context):
        """Streaming prediction is not implemented; defer to the base stub."""
        return super().PredictStream(request, context)

    def TokenizeString(self, request, context):
        """Tokenize ``request.prompt`` and return the token ids and count."""
        try:
            tokens: List[int] = self.model.tokenize(request.prompt, add_bos_token=False)
            l = len(tokens)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.TokenizationResponse(length=l, tokens=tokens)

def serve(address):
    """Start the backend gRPC server on *address* and block until terminated.

    Installs SIGINT/SIGTERM handlers that stop the server and exit, then
    parks the main thread in a sleep loop while worker threads serve RPCs.
    """
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), grpc_server)
    grpc_server.add_insecure_port(address)
    grpc_server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        # Graceful-enough shutdown: stop accepting RPCs immediately and exit.
        print("Received termination signal. Shutting down...")
        grpc_server.stop(0)
        sys.exit(0)

    # Handle both Ctrl-C and external termination requests the same way.
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        grpc_server.stop(0)

if __name__ == "__main__":
    # Parse the bind address from the command line and run the server.
    arg_parser = argparse.ArgumentParser(description="Run the gRPC server.")
    arg_parser.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    cli_args = arg_parser.parse_args()
    serve(cli_args.addr)
14 changes: 14 additions & 0 deletions extra/grpc/c_transformers/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash

##
## A bash script wrapper that runs the ctransformers server with conda

export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate ctransformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

# BUG FIX: quote the script path and forward arguments as "$@" so paths and
# arguments containing spaces are not word-split; exec replaces this wrapper
# so signals (SIGTERM from the supervisor) reach the python server directly.
exec python "$DIR/c_transformers.py" "$@"
1 change: 1 addition & 0 deletions extra/grpc/huggingface/backend_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ def __init__(self, channel):
self.PredictStream = channel.unary_stream(
'/backend.Backend/PredictStream',
request_serializer=backend__pb2.PredictOptions.SerializeToString,

response_deserializer=backend__pb2.Reply.FromString,
)
self.Embedding = channel.unary_unary(
7 changes: 0 additions & 7 deletions extra/requirements.txt

This file was deleted.