Skip to content

Commit

Permalink
FEAT: RESTful API (xorbitsai#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiayini1119 authored Jul 6, 2023
1 parent 8e28844 commit 8997c1e
Show file tree
Hide file tree
Showing 4 changed files with 636 additions and 131 deletions.
104 changes: 99 additions & 5 deletions plexar/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,22 @@

import asyncio
import uuid
from typing import List, Optional, Tuple
from typing import Iterator, List, Optional, Tuple, Union

import requests
import xoscar as xo

from .core.model import ModelActor
from .core.service import SupervisorActor
from .isolation import Isolation
from .model import ModelSpec
from .model.llm.types import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
Completion,
CompletionChunk,
)


class Client:
Expand All @@ -44,7 +52,7 @@ def launch_model(
model_size_in_billions: Optional[int] = None,
model_format: Optional[str] = None,
quantization: Optional[str] = None,
**kwargs
**kwargs,
) -> str:
model_uid = self.gen_model_uid()

Expand All @@ -54,7 +62,7 @@ def launch_model(
model_size_in_billions=model_size_in_billions,
model_format=model_format,
quantization=quantization,
**kwargs
**kwargs,
)
self._isolation.call(coro)

Expand All @@ -73,6 +81,92 @@ def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
return self._isolation.call(coro)


class RESTfulClient:
    """Synchronous client for the model supervisor's RESTful HTTP API.

    Wraps the ``/v1/models``, ``/v1/completions`` and
    ``/v1/chat/completions`` endpoints exposed by the REST server
    rooted at ``base_url``.
    """

    def __init__(self, base_url: str):
        """
        Parameters
        ----------
        base_url : str
            Root URL of the RESTful server, e.g. ``"http://127.0.0.1:9997"``.
        """
        self.base_url = base_url

    @classmethod
    def gen_model_uid(cls) -> str:
        """Generate a time-based (uuid1) unique id for a new model."""
        return str(uuid.uuid1())

    @staticmethod
    def _raise_for_status(response, action: str):
        # Fail fast with a descriptive message instead of letting a later
        # ``response.json()`` or key lookup raise something misleading.
        if response.status_code != 200:
            raise RuntimeError(
                f"Error {action} (HTTP {response.status_code}): {response.text}"
            )

    def list_models(self) -> List[str]:
        """Return the list of models reported by the server."""
        url = f"{self.base_url}/v1/models"

        response = requests.get(url)
        self._raise_for_status(response, "listing models")
        return response.json()

    def launch_model(
        self,
        model_name: str,
        model_size_in_billions: Optional[int] = None,
        model_format: Optional[str] = None,
        quantization: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Ask the server to launch a model and return its generated uid.

        Parameters
        ----------
        model_name : str
            Name of the model to launch.
        model_size_in_billions : Optional[int]
            Model size selector, in billions of parameters.
        model_format : Optional[str]
            On-disk format of the model weights.
        quantization : Optional[str]
            Quantization variant to load.
        **kwargs
            Extra, server-interpreted launch options.

        Returns
        -------
        str
            The uid of the launched model.

        Raises
        ------
        RuntimeError
            If the server responds with a non-200 status.
        """
        url = f"{self.base_url}/v1/models"

        # The uid is generated client-side and echoed back by the server.
        model_uid = self.gen_model_uid()
        payload = {
            "model_uid": model_uid,
            "model_name": model_name,
            "model_size_in_billions": model_size_in_billions,
            "model_format": model_format,
            "quantization": quantization,
            "kwargs": kwargs,
        }
        response = requests.post(url, json=payload)
        self._raise_for_status(response, f"launching model {model_name!r}")
        return response.json()["model_uid"]

    def terminate_model(self, model_uid: str):
        """Terminate the model identified by ``model_uid``.

        Raises
        ------
        RuntimeError
            If the server responds with a non-200 status.
        """
        url = f"{self.base_url}/v1/models/{model_uid}"

        response = requests.delete(url)
        self._raise_for_status(response, f"terminating model {model_uid!r}")

    def generate(
        self, model_uid: str, prompt: str, **kwargs
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Request a text completion for ``prompt`` from the given model.

        ``**kwargs`` are forwarded verbatim in the request body
        (e.g. sampling parameters understood by the server).
        """
        url = f"{self.base_url}/v1/completions"

        request_body = {"model": model_uid, "prompt": prompt, **kwargs}
        response = requests.post(url, json=request_body)
        self._raise_for_status(response, "generating completion")
        return response.json()

    def chat(
        self,
        model_uid: str,
        prompt: str,
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
        **kwargs,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Request a chat completion for ``prompt`` from the given model.

        Parameters
        ----------
        model_uid : str
            Uid of the target model.
        prompt : str
            The new user message.
        system_prompt : Optional[str]
            If given, overrides (or becomes) the leading system message.
        chat_history : Optional[List[ChatCompletionMessage]]
            Prior messages; the caller's list is NOT mutated.
        **kwargs
            Extra options forwarded verbatim in the request body.
        """
        url = f"{self.base_url}/v1/chat/completions"

        # Work on shallow copies so the caller's history (and its first
        # message) is not mutated by the edits below.
        if chat_history:
            chat_history = [dict(message) for message in chat_history]
        else:
            chat_history = []

        if chat_history and chat_history[0]["role"] == "system":
            if system_prompt is not None:
                chat_history[0]["content"] = system_prompt
        else:
            if system_prompt is not None:
                chat_history.insert(
                    0, ChatCompletionMessage(role="system", content=system_prompt)
                )

        chat_history.append(ChatCompletionMessage(role="user", content=prompt))
        request_body = {"model": model_uid, "messages": chat_history, **kwargs}
        response = requests.post(url, json=request_body)
        self._raise_for_status(response, "generating chat completion")
        return response.json()


class AsyncClient:
def __init__(self, supervisor_address: str):
self._supervisor_address = supervisor_address
Expand All @@ -96,7 +190,7 @@ async def launch_model(
model_size_in_billions: Optional[int] = None,
model_format: Optional[str] = None,
quantization: Optional[str] = None,
**kwargs
**kwargs,
) -> str:
model_uid = self.gen_model_uid()

Expand All @@ -107,7 +201,7 @@ async def launch_model(
model_size_in_billions=model_size_in_billions,
model_format=model_format,
quantization=quantization,
**kwargs
**kwargs,
)
return model_uid

Expand Down
Loading

0 comments on commit 8997c1e

Please sign in to comment.