FEAT: Support stable diffusion (#484)
codingl2k1 authored Oct 13, 2023
1 parent ab14aa9 commit 1fbcb80
Showing 16 changed files with 515 additions and 7 deletions.
24 changes: 18 additions & 6 deletions .github/workflows/python.yaml
@@ -53,6 +53,7 @@ jobs:
needs: lint
env:
CONDA_ENV: test
SELF_HOST_PYTHON: /root/miniconda3/envs/inference_test/bin/python
defaults:
run:
shell: bash -l {0}
@@ -67,6 +68,8 @@ jobs:
- { os: macos-latest, python-version: 3.10 }
- { os: windows-latest, python-version: 3.9 }
- { os: windows-latest, python-version: 3.10 }
include:
- { os: self-hosted, module: gpu, python-version: 3.9}

steps:
- name: Check out code
@@ -77,13 +80,15 @@

- name: Set up conda ${{ matrix.python-version }}
uses: conda-incubator/setup-miniconda@v2
if: ${{ matrix.module != 'gpu' }}
with:
python-version: ${{ matrix.python-version }}
activate-environment: ${{ env.CONDA_ENV }}

- name: Install dependencies
env:
MODULE: ${{ matrix.module }}
if: ${{ matrix.module != 'gpu' }}
run: |
pip install llama-cpp-python>=0.2.0
pip install transformers
@@ -96,17 +101,24 @@
pip install sentence-transformers
pip install s3fs
pip install modelscope
pip install diffusers
pip install -e ".[dev]"
working-directory: .

- name: Test with pytest
env:
MODULE: ${{ matrix.module }}
run: |
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/client/tests/test_client.py
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py xinference
if [ "$MODULE" == "gpu" ]; then
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_stable_diffusion.py
else
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/client/tests/test_client.py
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py xinference
fi
working-directory: .
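
In effect, the new include entry adds a self-hosted GPU job: its conda setup and dependency-install steps are skipped (their if condition excludes module == gpu), and its test step runs only xinference/model/image/tests/test_stable_diffusion.py through the pre-provisioned interpreter at SELF_HOST_PYTHON, while the regular matrix jobs now ignore that file.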
3 changes: 3 additions & 0 deletions setup.cfg
@@ -76,6 +76,7 @@ all =
tiktoken
sentence-transformers
vllm
diffusers
ggml =
llama-cpp-python>=0.2.0
ctransformers
@@ -93,6 +94,8 @@ vllm =
vllm
embedding =
sentence-transformers
image =
diffusers
doc =
ipython>=6.5.0
sphinx>=3.0.0,<5.0.0
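With this extra in place, the diffusers dependency for the image models can presumably be installed via pip install "xinference[image]"; the extra only declares the package, it does not download any model weights.
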
35 changes: 35 additions & 0 deletions xinference/client/oscar/actor_client.py
@@ -32,6 +32,7 @@
Completion,
CompletionChunk,
Embedding,
ImageList,
LlamaCppGenerateConfig,
PytorchGenerateConfig,
)
@@ -180,6 +181,38 @@ def chat(
return self._isolation.call(coro)


class ImageModelHandle(ModelHandle):
def text_to_image(
self,
prompt: str,
n: int = 1,
size: str = "1024*1024",
response_format: str = "url",
**kwargs,
) -> "ImageList":
"""
Creates an image by the input text.
Parameters
----------
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
n (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. Must be between 1 and 10.
size (`str`, *optional*, defaults to `1024*1024`):
The width*height in pixels of the generated image. Must be one of 256x256, 512x512, or 1024x1024.
response_format (`str`, *optional*, defaults to `url`):
The format in which the generated images are returned. Must be one of url or b64_json.
Returns
-------
ImageList
A list of image objects.
"""

coro = self._model_ref.text_to_image(prompt, n, size, response_format, **kwargs)
return self._isolation.call(coro)


class ActorClient:
def __init__(self, endpoint: str):
restful_client = Client(endpoint)
@@ -379,6 +412,8 @@ def get_model(self, model_uid: str) -> "ModelHandle":
raise ValueError(f"Unrecognized model ability: {desc['model_ability']}")
elif desc["model_type"] == "embedding":
return EmbeddingModelHandle(model_ref, self._isolation)
elif desc["model_type"] == "image":
return ImageModelHandle(model_ref, self._isolation)
else:
raise ValueError(f"Unknown model type:{desc['model_type']}")

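For orientation, a minimal sketch of how the new handle might be exercised through the actor-based client. The endpoint, the model name stable-diffusion-v1.5, and the launch_model arguments are illustrative assumptions, not something this diff pins down:

    from xinference.client.oscar.actor_client import ActorClient

    # Hypothetical endpoint and model name; adjust to your deployment.
    client = ActorClient("http://127.0.0.1:9997")
    model_uid = client.launch_model(model_name="stable-diffusion-v1.5", model_type="image")

    # For image models, get_model now returns an ImageModelHandle (see get_model above).
    model = client.get_model(model_uid)
    image_list = model.text_to_image("an astronaut riding a horse", n=2, size="512*512")
    print(image_list["data"][0]["url"])  # assumes an OpenAI-style ImageList payload

Note that size uses the width*height string format from the signature above (e.g. "512*512").
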
47 changes: 47 additions & 0 deletions xinference/client/restful/restful_client.py
@@ -28,6 +28,7 @@
Completion,
CompletionChunk,
Embedding,
ImageList,
LlamaCppGenerateConfig,
PytorchGenerateConfig,
)
@@ -78,6 +79,50 @@ def create_embedding(self, input: Union[str, List[str]]) -> "Embedding":
return response_data


class RESTfulImageModelHandle(RESTfulModelHandle):
def text_to_image(
self,
prompt: str,
n: int = 1,
size: str = "1024*1024",
response_format: str = "url",
) -> "ImageList":
"""
Creates an image by the input text.
Parameters
----------
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
n (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. Must be between 1 and 10.
size (`str`, *optional*, defaults to `1024*1024`):
The width*height in pixels of the generated image. Must be one of 256x256, 512x512, or 1024x1024.
response_format (`str`, *optional*, defaults to `url`):
The format in which the generated images are returned. Must be one of url or b64_json.
Returns
-------
ImageList
A list of image objects.
"""
url = f"{self._base_url}/v1/images/generations"
request_body = {
"model": self._model_uid,
"prompt": prompt,
"n": n,
"size": size,
"response_format": response_format,
}
response = requests.post(url, json=request_body)
if response.status_code != 200:
raise RuntimeError(
f"Failed to create the images, detail: {response.json()['detail']}"
)

response_data = response.json()
return response_data


class RESTfulGenerateModelHandle(RESTfulEmbeddingModelHandle):
def generate(
self,
@@ -458,6 +503,8 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
raise ValueError(f"Unrecognized model ability: {desc['model_ability']}")
elif desc["model_type"] == "embedding":
return RESTfulEmbeddingModelHandle(model_uid, self.base_url)
elif desc["model_type"] == "image":
return RESTfulImageModelHandle(model_uid, self.base_url)
else:
raise ValueError(f"Unknown model type:{desc['model_type']}")

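The RESTful flavor mirrors this; a rough sketch, again with an assumed model name and launch call (Client is the RESTful client that ActorClient wraps above):

    from xinference.client.restful.restful_client import Client

    client = Client("http://127.0.0.1:9997")
    # Hypothetical model name; launch arguments may differ in practice.
    model_uid = client.launch_model(model_name="stable-diffusion-v1.5", model_type="image")

    # get_model returns a RESTfulImageModelHandle for model_type "image".
    model = client.get_model(model_uid)
    result = model.text_to_image("a watercolor lighthouse at dusk")
    for image in result["data"]:
        print(image["url"])  # or image["b64_json"] when response_format="b64_json"
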
1 change: 1 addition & 0 deletions xinference/constants.py
@@ -19,6 +19,7 @@
XINFERENCE_CACHE_DIR = os.path.join(XINFERENCE_HOME, "cache")
XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")
XINFERENCE_LOG_DIR = os.path.join(XINFERENCE_HOME, "logs")
XINFERENCE_IMAGE_DIR = os.path.join(XINFERENCE_HOME, "image")

XINFERENCE_DEFAULT_LOCAL_HOST = "127.0.0.1"
XINFERENCE_DEFAULT_DISTRIBUTED_HOST = "0.0.0.0"
21 changes: 21 additions & 0 deletions xinference/core/model.py
@@ -194,6 +194,27 @@ async def _wrapper():

return await self._call_wrapper(_wrapper)

async def text_to_image(
self,
prompt: str,
n: int = 1,
size: str = "1024*1024",
response_format: str = "url",
*args,
**kwargs,
):
if not hasattr(self._model, "text_to_image"):
raise AttributeError(
f"Model {self._model.model_spec} is not for creating image."
)

async def _wrapper():
return getattr(self._model, "text_to_image")(
prompt, n, size, response_format, *args, **kwargs
)

return await self._call_wrapper(_wrapper)

async def next(
self, generator_uid: str
) -> Union["ChatCompletionChunk", "CompletionChunk"]:
40 changes: 39 additions & 1 deletion xinference/core/restful_api.py
@@ -32,7 +32,7 @@
from typing_extensions import NotRequired, TypedDict
from uvicorn import Config, Server

from ..types import ChatCompletion, Completion, Embedding
from ..types import ChatCompletion, Completion, Embedding, ImageList
from .supervisor import SupervisorActor

logger = logging.getLogger(__name__)
@@ -170,6 +170,15 @@ class Config:
}


class TextToImageRequest(BaseModel):
model: str
prompt: Union[str, List[str]] = Field(description="The input to embed.")
n: Optional[int] = 1
response_format: Optional[str] = "url"
size: Optional[str] = "1024*1024"
user: Optional[str] = None


class ChatCompletionRequestMessage(TypedDict):
role: Literal["assistant", "user", "system"]
content: str
@@ -267,6 +276,12 @@ def serve(self):
methods=["POST"],
response_model=Embedding,
)
self._router.add_api_route(
"/v1/images/generations",
self.create_images,
methods=["POST"],
response_model=ImageList,
)
self._router.add_api_route(
"/v1/chat/completions",
self.create_chat_completion,
@@ -553,6 +568,29 @@ async def create_embedding(self, request: CreateEmbeddingRequest):
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def create_images(self, request: TextToImageRequest):
model_uid = request.model
try:
model = await self._supervisor_ref.get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

try:
image_list = await model.text_to_image(
request.prompt, request.n, request.size, request.response_format
)
return image_list
except RuntimeError as re:
logger.error(re, exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def create_chat_completion(
self,
request: Request,
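Because the route follows the OpenAI images API shape, it can also be called directly over HTTP; a hedged requests sketch with a placeholder host and model uid, assuming ImageList carries a data list of url/b64_json entries:

    import requests

    payload = {
        "model": "my-image-model-uid",  # placeholder uid of a launched image model
        "prompt": "a bowl of ramen, studio lighting",
        "n": 1,
        "size": "1024*1024",
        "response_format": "b64_json",
    }
    resp = requests.post("http://127.0.0.1:9997/v1/images/generations", json=payload)
    resp.raise_for_status()
    print(resp.json()["data"][0]["b64_json"][:60])
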
3 changes: 3 additions & 0 deletions xinference/model/core.py
@@ -35,6 +35,7 @@ def create_model_instance(
**kwargs,
) -> Tuple[Any, ModelDescription]:
from .embedding.core import create_embedding_model_instance
from .image.core import create_image_model_instance
from .llm.core import create_llm_model_instance

if model_type == "LLM":
@@ -51,5 +52,7 @@
# embedding model doesn't accept trust_remote_code
kwargs.pop("trust_remote_code", None)
return create_embedding_model_instance(model_uid, model_name, **kwargs)
elif model_type == "image":
return create_image_model_instance(model_uid, model_name, **kwargs)
else:
raise ValueError(f"Unsupported model type: {model_type}.")
26 changes: 26 additions & 0 deletions xinference/model/image/__init__.py
@@ -0,0 +1,26 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import json
import os

from .core import ImageModelFamilyV1

_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
BUILTIN_IMAGE_MODELS = dict(
(spec["model_name"], ImageModelFamilyV1(**spec))
for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
)
del _model_spec_json
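
A quick way to inspect which image model families this registers (the concrete names live in model_spec.json, which is not part of the diff shown here):

    from xinference.model.image import BUILTIN_IMAGE_MODELS

    # Keys are the model_name fields parsed from model_spec.json.
    print(sorted(BUILTIN_IMAGE_MODELS))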
(Diffs for the remaining changed files are not shown here.)
