Skip to content

Commit

Permalink
FEAT: stream generation (xorbitsai#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
UranusSeven authored Jun 26, 2023
1 parent 38c1a4e commit 143c90f
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 52 deletions.
60 changes: 60 additions & 0 deletions plexar/actor/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Generic, Iterator, TypeVar

import xoscar as xo


class IteratorActor(xo.Actor):
    """Actor that owns an iterator and hands out its items one call at a time.

    Each call to :meth:`next` advances the wrapped iterator by one step.
    Exhaustion is reported by raising a plain ``Exception`` whose message is
    the string ``"StopIteration"``; ``IteratorWrapper`` in this module matches
    on exactly that message, so the text must not be altered.
    """

    def __init__(self, it: Iterator[Any]):
        """Store ``it`` as the iterator this actor will serve."""
        super().__init__()
        self._iter = it

    def __iter__(self):
        """Return an ``IteratorWrapper`` addressing this actor.

        Note that ``IteratorWrapper`` only implements the *async* iteration
        protocol (``__aiter__`` / ``__anext__``).
        """
        return IteratorWrapper(
            iter_actor_addr=self.address, iter_actor_uid=self.uid
        )

    def next(self):
        """Return the next item of the wrapped iterator.

        Raises:
            Exception: with the message ``"StopIteration"`` once the wrapped
                iterator is exhausted.
        """
        try:
            item = self._iter.__next__()
        except StopIteration:
            # Signal exhaustion with the string sentinel that
            # ``IteratorWrapper.__anext__`` compares against.
            raise Exception("StopIteration")
        else:
            return item


T = TypeVar("T")


class IteratorWrapper(Generic[T]):
    """Async iterator that pulls items from a remote ``IteratorActor``.

    The wrapper only stores the actor's address and uid; the actor reference
    is resolved on the first ``__anext__`` call. When the actor reports
    exhaustion (an exception whose message is ``"StopIteration"``), the actor
    is destroyed and ``StopAsyncIteration`` is raised. Once exhausted, every
    further ``__anext__`` call raises ``StopAsyncIteration`` again without
    contacting the (now destroyed) actor.
    """

    def __init__(self, iter_actor_addr: str, iter_actor_uid: str):
        """Remember where the backing ``IteratorActor`` lives.

        Args:
            iter_actor_addr: address of the pool hosting the actor.
            iter_actor_uid: uid of the actor.
        """
        self._iter_actor_addr = iter_actor_addr
        self._iter_actor_uid = iter_actor_uid
        # Resolved lazily in ``__anext__``.
        self._iter_actor_ref = None
        # Set once the remote iterator has reported exhaustion.
        self._exhausted = False

    def __aiter__(self):
        """Return ``self``; the wrapper is its own async iterator."""
        return self

    async def __anext__(self) -> T:
        """Fetch the next item from the remote iterator.

        Raises:
            StopAsyncIteration: when the remote iterator is exhausted, and on
                every call made after that.
            Exception: any other error raised by the remote ``next`` call is
                propagated unchanged.
        """
        if self._exhausted:
            # The backing actor has already been destroyed; do not call it.
            raise StopAsyncIteration

        if self._iter_actor_ref is None:
            self._iter_actor_ref = await xo.actor_ref(
                address=self._iter_actor_addr, uid=self._iter_actor_uid
            )

        try:
            return await self._iter_actor_ref.next()
        except Exception as e:
            # ``IteratorActor.next`` signals exhaustion with this exact
            # message; anything else is a genuine failure.
            if str(e) != "StopIteration":
                raise

        # Mark exhaustion before tearing the actor down so that a failing
        # ``destroy_actor`` cannot lead to another call on a dead actor.
        self._exhausted = True
        await xo.destroy_actor(self._iter_actor_ref)
        raise StopAsyncIteration
40 changes: 33 additions & 7 deletions plexar/actor/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import TYPE_CHECKING, Any, Dict, Iterator

import xoscar as xo

from ..model.llm.core import Model
from .common import IteratorActor, IteratorWrapper

if TYPE_CHECKING:
from ..model.llm.core import Model


class ModelManagerActor(xo.Actor):
Expand All @@ -31,15 +33,39 @@ def get_model(self, model_uid: str):

class ModelActor(xo.Actor):
    """Actor that hosts a model and exposes its ``generate`` / ``chat`` calls.

    Iterable results returned by the model are not sent back directly.
    Instead an ``IteratorActor`` is created next to this actor and an
    ``IteratorWrapper`` addressing it is returned, so the caller can consume
    the result with ``async for``.
    """

    @classmethod
    def gen_uid(cls, model: "Model"):
        """Return the actor uid derived from the class of ``model``."""
        return f"{model.__class__}-model-actor"

    def __init__(self, model: "Model"):
        """Keep a reference to ``model``; it is loaded in ``__post_create__``."""
        super().__init__()
        self._model = model

    async def __post_create__(self):
        """Post-create hook: load the model."""
        self._model.load()

    async def _create_iterator_actor(self, it: Iterator) -> IteratorWrapper:
        """Create an ``IteratorActor`` serving ``it`` and return its wrapper.

        The actor is created at this actor's own address, with a uid derived
        from ``id(it)``.
        """
        uid = str(id(it))
        await xo.create_actor(IteratorActor, address=self.address, uid=uid, it=it)
        return IteratorWrapper(iter_actor_addr=self.address, iter_actor_uid=uid)

    async def _wrap_iterator(self, ret: Any):
        """Return ``ret`` itself, or an ``IteratorWrapper`` if it is a stream.

        ``str``, ``bytes`` and ``dict`` values define ``__iter__`` but are
        complete results rather than streams; wrapping them would make the
        caller iterate over characters or keys, so they are returned as-is.
        """
        if isinstance(ret, (str, bytes, dict)) or not hasattr(ret, "__iter__"):
            return ret
        return await self._create_iterator_actor(iter(ret))

    async def _call_model(self, method_name: str, prompt: str, *args, **kwargs):
        """Call ``self._model.<method_name>`` and wrap an iterable result.

        Raises:
            AttributeError: with ``method_name`` as its message when the model
                does not provide that method.
        """
        try:
            method = getattr(self._model, method_name)
        except AttributeError:
            raise AttributeError(method_name) from None
        # Await here so callers receive the final value, not a coroutine.
        return await self._wrap_iterator(method(prompt, *args, **kwargs))

    async def generate(self, prompt: str, *args, **kwargs):
        """Run the model's ``generate`` on ``prompt``.

        Extra positional and keyword arguments are forwarded to the model.
        Raises ``AttributeError("generate")`` if the model has no such method.
        """
        return await self._call_model("generate", prompt, *args, **kwargs)

    async def chat(self, prompt: str, *args, **kwargs):
        """Run the model's ``chat`` on ``prompt``.

        Extra positional and keyword arguments are forwarded to the model.
        Raises ``AttributeError("chat")`` if the model has no such method.
        """
        return await self._call_model("chat", prompt, *args, **kwargs)
17 changes: 12 additions & 5 deletions plexar/deploy/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def model_list():
@click.option("--path", "-p")
def model_launch(path):
import asyncio
import textwrap
import sys

import xoscar as xo

Expand All @@ -48,7 +48,7 @@ def model_launch(path):
async def _run():
await xo.create_actor_pool(address="localhost:9999", n_process=1)

vu = VicunaUncensoredGgml(model_path=path, llamacpp_model_config={})
vu = VicunaUncensoredGgml(model_path=path)
vu_ref = await xo.create_actor(
ModelActor, address="localhost:9999", uid="vu", model=vu
)
Expand All @@ -58,9 +58,16 @@ async def _run():
if i == "exit":
break

completion = await vu_ref.chat(i)
text = "\n".join(textwrap.wrap(completion["text"], width=80))
print(f"Assistant:\n{text}")
print(f"Assistant:")
length = 0
async for chunk in await vu_ref.chat(i):
sys.stdout.write(chunk["text"])
sys.stdout.flush()
length += len(chunk["text"])
if length >= 80:
print()
length = 0
print()

loop = asyncio.get_event_loop()
loop.run_until_complete(_run())
Expand Down
Loading

0 comments on commit 143c90f

Please sign in to comment.