From 948b99acae66f0c807d0bfbe706df98378e679ad Mon Sep 17 00:00:00 2001 From: Chengjie Li <109656400+ChengjieLi28@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:53:38 +0800 Subject: [PATCH] FEAT: Dynamic batching for the state-of-the-art FLUX.1 `text_to_image` interface (#2380) --- .github/workflows/python.yaml | 1 + .../user_guide/continuous_batching.po | 113 ++-- doc/source/user_guide/continuous_batching.rst | 38 +- xinference/constants.py | 4 + xinference/core/model.py | 101 +++- xinference/core/scheduler.py | 9 +- xinference/core/utils.py | 9 + xinference/model/image/scheduler/__init__.py | 13 + xinference/model/image/scheduler/flux.py | 533 ++++++++++++++++++ .../model/image/stable_diffusion/core.py | 37 +- xinference/model/image/utils.py | 42 +- 11 files changed, 800 insertions(+), 100 deletions(-) create mode 100644 xinference/model/image/scheduler/__init__.py create mode 100644 xinference/model/image/scheduler/flux.py diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 82308d8ed6..3fedc5eea5 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -174,6 +174,7 @@ jobs: ${{ env.SELF_HOST_PYTHON }} -m pip install -U "ormsgpack" ${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y opencc ${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y "faster_whisper" + ${{ env.SELF_HOST_PYTHON }} -m pip install -U accelerate ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \ -W ignore::PendingDeprecationWarning \ --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_stable_diffusion.py && \ diff --git a/doc/source/locale/zh_CN/LC_MESSAGES/user_guide/continuous_batching.po b/doc/source/locale/zh_CN/LC_MESSAGES/user_guide/continuous_batching.po index 4da5c72b63..4505a7fa2a 100644 --- a/doc/source/locale/zh_CN/LC_MESSAGES/user_guide/continuous_batching.po +++ b/doc/source/locale/zh_CN/LC_MESSAGES/user_guide/continuous_batching.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: Xinference \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-09-06 14:26+0800\n" +"POT-Creation-Date: 2024-10-17 18:49+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -18,8 +18,8 @@ msgstr "" "Generated-By: Babel 2.11.0\n" #: ../../source/user_guide/continuous_batching.rst:5 -msgid "Continuous Batching (experimental)" -msgstr "连续批处理(实验性质)" +msgid "Continuous Batching" +msgstr "连续批处理" #: ../../source/user_guide/continuous_batching.rst:7 msgid "" @@ -35,11 +35,15 @@ msgstr "" msgid "Usage" msgstr "使用方式" -#: ../../source/user_guide/continuous_batching.rst:12 +#: ../../source/user_guide/continuous_batching.rst:14 +msgid "LLM" +msgstr "大语言模型" + +#: ../../source/user_guide/continuous_batching.rst:15 msgid "Currently, this feature can be enabled under the following conditions:" msgstr "当前,此功能在满足以下条件时开启:" -#: ../../source/user_guide/continuous_batching.rst:14 +#: ../../source/user_guide/continuous_batching.rst:17 msgid "" "First, set the environment variable " "``XINFERENCE_TRANSFORMERS_ENABLE_BATCHING`` to ``1`` when starting " @@ -48,13 +52,22 @@ msgstr "" "首先,启动 Xinference 时需要将环境变量 ``XINFERENCE_TRANSFORMERS_ENABLE_" "BATCHING`` 置为 ``1`` 。" -#: ../../source/user_guide/continuous_batching.rst:21 +#: ../../source/user_guide/continuous_batching.rst:25 +msgid "" +"Since ``v0.16.0``, this feature is turned on by default and is no longer " +"required to set the ``XINFERENCE_TRANSFORMERS_ENABLE_BATCHING`` " +"environment variable. 
This environment variable has been removed." +msgstr "" +"自 ``v0.16.0`` 开始,此功能默认开启,不再需要设置 ``XINFERENCE_TRANSFORMERS_ENABLE_BATCHING`` 环境变量," +"且该环境变量已被移除。" + +#: ../../source/user_guide/continuous_batching.rst:30 msgid "" "Then, ensure that the ``transformers`` engine is selected when launching " "the model. For example:" msgstr "然后,启动 LLM 模型时选择 ``transformers`` 推理引擎。例如:" -#: ../../source/user_guide/continuous_batching.rst:57 +#: ../../source/user_guide/continuous_batching.rst:66 msgid "" "Once this feature is enabled, all requests for LLMs will be managed by " "continuous batching, and the average throughput of requests made to a " @@ -64,54 +77,92 @@ msgstr "" "一旦此功能开启,LLM 模型的所有接口将被此功能接管。所有接口的使用方式没有" "任何变化。" -#: ../../source/user_guide/continuous_batching.rst:63 +#: ../../source/user_guide/continuous_batching.rst:71 +msgid "Image Model" +msgstr "图像模型" + +#: ../../source/user_guide/continuous_batching.rst:72 +msgid "" +"Currently, for image models, only the ``text_to_image`` interface is " +"supported for ``FLUX.1`` series models." +msgstr "" +"当前只有 ``FLUX.1`` 系列模型的 ``text_to_image`` (文生图)接口支持此功能。" + +#: ../../source/user_guide/continuous_batching.rst:74 +msgid "" +"Enabling this feature requires setting the environment variable " +"``XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE``, which indicates the ``size`` " +"of the generated images." +msgstr "" +"图像模型开启此功能需要在启动 xinference 时指定 ``XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE`` 环境变量," +"表示生成图片的大小。" + +#: ../../source/user_guide/continuous_batching.rst:76 +msgid "For example, starting xinference like this:" +msgstr "" +"例如,像这样启动 xinference:" + +#: ../../source/user_guide/continuous_batching.rst:83 +msgid "" +"Then just use the ``text_to_image`` interface as before, and nothing else" +" needs to be changed." +msgstr "" +"接下来正常使用 ``text_to_image`` 接口即可,其他什么都不需要改变。" + +#: ../../source/user_guide/continuous_batching.rst:86 msgid "Abort your request" msgstr "中止请求" -#: ../../source/user_guide/continuous_batching.rst:64 +#: ../../source/user_guide/continuous_batching.rst:87 msgid "In this mode, you can abort requests that are in the process of inference." msgstr "此功能中,你可以优雅地中止正在推理中的请求。" -#: ../../source/user_guide/continuous_batching.rst:66 +#: ../../source/user_guide/continuous_batching.rst:89 msgid "First, add ``request_id`` option in ``generate_config``. For example:" msgstr "首先,在推理请求的 ``generate_config`` 中指定 ``request_id`` 选项。例如:" -#: ../../source/user_guide/continuous_batching.rst:75 +#: ../../source/user_guide/continuous_batching.rst:98 msgid "" "Then, abort the request using the ``request_id`` you have set. For " "example:" msgstr "接着,带着你指定的 ``request_id`` 去中止该请求。例如:" -#: ../../source/user_guide/continuous_batching.rst:83 +#: ../../source/user_guide/continuous_batching.rst:106 msgid "" "Note that if your request has already finished, aborting the request will" -" be a no-op." +" be a no-op. Image models also support this feature." msgstr "注意,如果你的请求已经结束,那么此操作将什么都不做。" -#: ../../source/user_guide/continuous_batching.rst:86 +#: ../../source/user_guide/continuous_batching.rst:110 msgid "Note" msgstr "注意事项" -#: ../../source/user_guide/continuous_batching.rst:88 +#: ../../source/user_guide/continuous_batching.rst:112 msgid "" -"Currently, this feature only supports the ``generate``, ``chat`` and " -"``vision`` tasks for ``LLM`` models. The ``tool call`` tasks are not " -"supported." +"Currently, for ``LLM`` models, this feature only supports the " +"``generate``, ``chat``, ``tool call`` and ``vision`` tasks." 
msgstr "" -"当前,此功能仅支持 LLM 模型的 ``generate``, ``chat`` 和 ``vision`` (多" -"模态) 功能。``tool call`` (工具调用)暂时不支持。" +"当前,此功能仅支持 LLM 模型的 ``generate``, ``chat``, ``tool call`` (工具调用)和 ``vision`` (多" +"模态) 功能。" -#: ../../source/user_guide/continuous_batching.rst:90 +#: ../../source/user_guide/continuous_batching.rst:114 +msgid "" +"Currently, for ``image`` models, this feature only supports the " +"``text_to_image`` tasks. Only ``FLUX.1`` series models are supported." +msgstr "" +"当前,对于图像模型,仅支持 `FLUX.1`` 系列模型的 ``text_to_image`` (文生图)功能。" + +#: ../../source/user_guide/continuous_batching.rst:116 msgid "" "For ``vision`` tasks, currently only ``qwen-vl-chat``, ``cogvlm2``, " "``glm-4v`` and ``MiniCPM-V-2.6`` (only for image tasks) models are " "supported. More models will be supported in the future. Please let us " "know your requirements." msgstr "" -"对于多模态任务,当前支持 ``qwen-vl-chat`` ,``cogvlm2``, ``glm-4v`` 和 ``MiniCPM-V-2.6`` (仅对于图像任务)" -"模型。未来将加入更多模型,敬请期待。" +"对于多模态任务,当前支持 ``qwen-vl-chat`` ,``cogvlm2``, ``glm-4v`` 和 `" +"`MiniCPM-V-2.6`` (仅对于图像任务)模型。未来将加入更多模型,敬请期待。" -#: ../../source/user_guide/continuous_batching.rst:92 +#: ../../source/user_guide/continuous_batching.rst:118 msgid "" "If using GPU inference, this method will consume more GPU memory. Please " "be cautious when increasing the number of concurrent requests to the same" @@ -123,17 +174,3 @@ msgstr "" "请求量。``launch_model`` 接口提供可选参数 ``max_num_seqs`` 用于调整并发度" ",默认值为 ``16`` 。" -#: ../../source/user_guide/continuous_batching.rst:95 -msgid "" -"This feature is still in the experimental stage, and we welcome your " -"active feedback on any issues." -msgstr "此功能仍处于实验阶段,欢迎反馈任何问题。" - -#: ../../source/user_guide/continuous_batching.rst:97 -msgid "" -"After a period of testing, this method will remain enabled by default, " -"and the original inference method will be deprecated." -msgstr "" -"一段时间的测试之后,此功能将代替原来的 transformers 推理逻辑成为默认行为" -"。原来的推理逻辑将被摒弃。" - diff --git a/doc/source/user_guide/continuous_batching.rst b/doc/source/user_guide/continuous_batching.rst index 1237e7e235..e720288c57 100644 --- a/doc/source/user_guide/continuous_batching.rst +++ b/doc/source/user_guide/continuous_batching.rst @@ -1,14 +1,17 @@ .. _user_guide_continuous_batching: -================================== -Continuous Batching (experimental) -================================== +=================== +Continuous Batching +=================== Continuous batching, as a means to improve throughput during model serving, has already been implemented in inference engines like ``VLLM``. Xinference aims to provide this optimization capability when using the transformers engine as well. Usage ===== + +LLM +--- Currently, this feature can be enabled under the following conditions: * First, set the environment variable ``XINFERENCE_TRANSFORMERS_ENABLE_BATCHING`` to ``1`` when starting xinference. For example: @@ -18,6 +21,12 @@ Currently, this feature can be enabled under the following conditions: XINFERENCE_TRANSFORMERS_ENABLE_BATCHING=1 xinference-local --log-level debug +.. note:: + Since ``v0.16.0``, this feature is turned on by default and + is no longer required to set the ``XINFERENCE_TRANSFORMERS_ENABLE_BATCHING`` environment variable. + This environment variable has been removed. + + * Then, ensure that the ``transformers`` engine is selected when launching the model. For example: .. tabs:: @@ -58,6 +67,20 @@ Once this feature is enabled, all requests for LLMs will be managed by continuou and the average throughput of requests made to a single model will increase. 
The usage of the LLM interface remains exactly the same as before, with no differences.

+Image Model
+-----------
+Currently, for image models, only the ``text_to_image`` interface is supported for ``FLUX.1`` series models.
+
+Enabling this feature requires setting the environment variable ``XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE``, which indicates the ``size`` of the generated images.
+
+For example, starting xinference like this:
+
+.. code-block::
+
+    XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE=1024*1024 xinference-local --log-level debug
+
+
+Then just use the ``text_to_image`` interface as before, and nothing else needs to be changed.

Abort your request
==================
@@ -81,17 +104,16 @@ In this mode, you can abort requests that are in the process of inference.
    client.abort_request("<model_uid>", "<request_id>")

Note that if your request has already finished, aborting the request will be a no-op.
+Image models also support this feature.

Note
====
-* Currently, this feature only supports the ``generate``, ``chat`` and ``vision`` tasks for ``LLM`` models. The ``tool call`` tasks are not supported.
+* Currently, for ``LLM`` models, this feature only supports the ``generate``, ``chat``, ``tool call`` and ``vision`` tasks.
+
+* Currently, for ``image`` models, this feature only supports the ``text_to_image`` tasks. Only ``FLUX.1`` series models are supported.

* For ``vision`` tasks, currently only ``qwen-vl-chat``, ``cogvlm2``, ``glm-4v`` and ``MiniCPM-V-2.6`` (only for image tasks) models are supported. More models will be supported in the future. Please let us know your requirements.

* If using GPU inference, this method will consume more GPU memory. Please be cautious when increasing the number of concurrent requests to the same model. The ``launch_model`` interface provides the ``max_num_seqs`` parameter to adjust the concurrency level, with a default value of ``16``.
-
-* This feature is still in the experimental stage, and we welcome your active feedback on any issues.
-
-* After a period of testing, this method will remain enabled by default, and the original inference method will be deprecated.
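
For clarity, here is a minimal client-side sketch of how the batched ``text_to_image`` interface is expected to be used. It assumes a ``FLUX.1`` model has already been launched (``<image_model_uid>`` is a placeholder) and that the RESTful Python client exposes ``text_to_image(prompt, n, size, response_format, ...)``. The key constraint, enforced by ``handle_image_batching_request`` later in this patch, is that the requested ``size`` must equal ``XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE``; otherwise the request is rejected.

.. code-block:: python

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    # Placeholder uid of a FLUX.1 model launched on this server.
    model = client.get_model("<image_model_uid>")

    # The size must match XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE (e.g. "1024*1024"),
    # because the scheduler only batches requests of one fixed resolution.
    result = model.text_to_image(
        prompt="an astronaut riding a horse on the moon",
        n=1,
        size="1024*1024",
        response_format="b64_json",
    )
    print(result["data"][0]["b64_json"][:32])
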
diff --git a/xinference/constants.py b/xinference/constants.py index ae75f2c9f0..66e9983a93 100644 --- a/xinference/constants.py +++ b/xinference/constants.py @@ -28,6 +28,7 @@ XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK" XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS" XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS = "XINFERENCE_DOWNLOAD_MAX_ATTEMPTS" +XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE = "XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE" def get_xinference_home() -> str: @@ -82,3 +83,6 @@ def get_xinference_home() -> str: XINFERENCE_DOWNLOAD_MAX_ATTEMPTS = int( os.environ.get(XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS, 3) ) +XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE = os.environ.get( + XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None +) diff --git a/xinference/core/model.py b/xinference/core/model.py index 170a7b5edc..206adc25d9 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -41,6 +41,8 @@ import sse_starlette.sse import xoscar as xo +from ..constants import XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE + if TYPE_CHECKING: from .progress_tracker import ProgressTrackerActor from .worker import WorkerActor @@ -72,6 +74,8 @@ class _OutOfMemoryError(Exception): "MiniCPM-V-2.6", ] +XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"] + def request_limit(fn): """ @@ -151,6 +155,16 @@ async def __pre_destroy__(self): f"Destroy scheduler actor failed, address: {self.address}, error: {e}" ) + if self.allow_batching_for_text_to_image(): + try: + assert self._text_to_image_scheduler_ref is not None + await xo.destroy_actor(self._text_to_image_scheduler_ref) + del self._text_to_image_scheduler_ref + except Exception as e: + logger.debug( + f"Destroy text_to_image scheduler actor failed, address: {self.address}, error: {e}" + ) + if hasattr(self._model, "stop") and callable(self._model.stop): self._model.stop() @@ -218,6 +232,7 @@ def __init__( self._loop: Optional[asyncio.AbstractEventLoop] = None self._scheduler_ref = None + self._text_to_image_scheduler_ref = None async def __post_create__(self): self._loop = asyncio.get_running_loop() @@ -231,6 +246,15 @@ async def __post_create__(self): uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id), ) + if self.allow_batching_for_text_to_image(): + from ..model.image.scheduler.flux import FluxBatchSchedulerActor + + self._text_to_image_scheduler_ref = await xo.create_actor( + FluxBatchSchedulerActor, + address=self.address, + uid=FluxBatchSchedulerActor.gen_uid(self.model_uid()), + ) + async def _record_completion_metrics( self, duration, completion_tokens, prompt_tokens ): @@ -327,6 +351,26 @@ def allow_batching(self) -> bool: return False return condition + def allow_batching_for_text_to_image(self) -> bool: + from ..model.image.stable_diffusion.core import DiffusionModel + + condition = XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE is not None and isinstance( + self._model, DiffusionModel + ) + + if condition: + model_name = self._model._model_spec.model_name # type: ignore + if model_name in XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS: + return True + else: + logger.warning( + f"Currently for image models with text_to_image ability, " + f"xinference only supports {', '.join(XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS)} for batching. " + f"Your model {model_name} is disqualified." 
+ ) + return False + return condition + async def load(self): self._model.load() if self.allow_batching(): @@ -334,6 +378,11 @@ async def load(self): logger.debug( f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}" ) + if self.allow_batching_for_text_to_image(): + await self._text_to_image_scheduler_ref.set_model(self._model) + logger.debug( + f"Batching enabled for model: {self.model_uid()}, max_num_images: {self._model.get_max_num_images_for_batching()}" + ) def model_uid(self): return ( @@ -613,12 +662,16 @@ async def chat(self, messages: List[Dict], *args, **kwargs): ) async def abort_request(self, request_id: str) -> str: - from .scheduler import AbortRequestMessage + from .utils import AbortRequestMessage if self.allow_batching(): if self._scheduler_ref is None: return AbortRequestMessage.NOT_FOUND.name return await self._scheduler_ref.abort_request(request_id) + elif self.allow_batching_for_text_to_image(): + if self._text_to_image_scheduler_ref is None: + return AbortRequestMessage.NOT_FOUND.name + return await self._text_to_image_scheduler_ref.abort_request(request_id) return AbortRequestMessage.NO_OP.name @request_limit @@ -743,6 +796,22 @@ async def speech( f"Model {self._model.model_spec} is not for creating speech." ) + async def handle_image_batching_request(self, unique_id, *args, **kwargs): + size = args[2] + if XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE != size: + raise RuntimeError( + f"The image size: {size} of text_to_image for batching " + f"must be the same as the environment variable: {XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE} you set." + ) + assert self._loop is not None + future = ConcurrentFuture() + await self._text_to_image_scheduler_ref.add_request( + unique_id, future, *args, **kwargs + ) + fut = asyncio.wrap_future(future, loop=self._loop) + result = await fut + return await asyncio.to_thread(json_dumps, result) + @request_limit @log_async(logger=logger) async def text_to_image( @@ -755,19 +824,25 @@ async def text_to_image( **kwargs, ): if hasattr(self._model, "text_to_image"): - progressor = kwargs["progressor"] = await self._get_progressor( - kwargs.pop("request_id", None) - ) - with progressor: - return await self._call_wrapper_json( - self._model.text_to_image, - prompt, - n, - size, - response_format, - *args, - **kwargs, + if self.allow_batching_for_text_to_image(): + unique_id = kwargs.pop("request_id", None) + return await self.handle_image_batching_request( + unique_id, prompt, n, size, response_format, *args, **kwargs ) + else: + progressor = kwargs["progressor"] = await self._get_progressor( + kwargs.pop("request_id", None) + ) + with progressor: + return await self._call_wrapper_json( + self._model.text_to_image, + prompt, + n, + size, + response_format, + *args, + **kwargs, + ) raise AttributeError( f"Model {self._model.model_spec} is not for creating image." 
) diff --git a/xinference/core/scheduler.py b/xinference/core/scheduler.py index 657f668d0e..1b91d62e27 100644 --- a/xinference/core/scheduler.py +++ b/xinference/core/scheduler.py @@ -17,11 +17,12 @@ import logging import uuid from collections import deque -from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Union import xoscar as xo +from .utils import AbortRequestMessage + logger = logging.getLogger(__name__) XINFERENCE_STREAMING_DONE_FLAG = "" @@ -30,12 +31,6 @@ XINFERENCE_NON_STREAMING_ABORT_FLAG = "" -class AbortRequestMessage(Enum): - NOT_FOUND = 1 - DONE = 2 - NO_OP = 3 - - class InferenceRequest: def __init__( self, diff --git a/xinference/core/utils.py b/xinference/core/utils.py index 9f359fa315..47aa45b27d 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -16,6 +16,7 @@ import random import string import uuid +from enum import Enum from typing import Dict, Generator, List, Optional, Tuple, Union import orjson @@ -27,6 +28,12 @@ logger = logging.getLogger(__name__) +class AbortRequestMessage(Enum): + NOT_FOUND = 1 + DONE = 2 + NO_OP = 3 + + def truncate_log_arg(arg) -> str: s = str(arg) if len(s) > XINFERENCE_LOG_ARG_MAX_LENGTH: @@ -51,6 +58,8 @@ async def wrapped(*args, **kwargs): request_id_str = kwargs.get("request_id", "") if not request_id_str: request_id_str = uuid.uuid1() + if func_name == "text_to_image": + kwargs["request_id"] = request_id_str request_id_str = f"[request {request_id_str}]" formatted_args = ",".join(map(truncate_log_arg, args)) formatted_kwargs = ",".join( diff --git a/xinference/model/image/scheduler/__init__.py b/xinference/model/image/scheduler/__init__.py new file mode 100644 index 0000000000..09138b5b2a --- /dev/null +++ b/xinference/model/image/scheduler/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2024 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/xinference/model/image/scheduler/flux.py b/xinference/model/image/scheduler/flux.py new file mode 100644 index 0000000000..174acb82e3 --- /dev/null +++ b/xinference/model/image/scheduler/flux.py @@ -0,0 +1,533 @@ +# Copyright 2022-2024 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import asyncio +import logging +import os +import re +import typing +from collections import deque +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import xoscar as xo + +from ..utils import handle_image_result + +if TYPE_CHECKING: + from ..stable_diffusion.core import DiffusionModel + + +logger = logging.getLogger(__name__) +DEFAULT_MAX_SEQUENCE_LENGTH = 512 + + +class Text2ImageRequest: + def __init__( + self, + unique_id, + future, + prompt: str, + n: int, + size: str, + response_format: str, + *args, + **kwargs, + ): + self._unique_id = unique_id + self.future = future + self._prompt = prompt + self._n = n + self._size = size + self._response_format = response_format + self._args = args + self._kwargs = kwargs + self._width = -1 + self._height = -1 + self._generate_kwargs: Dict[str, Any] = {} + self._set_width_and_height() + self.is_encode = True + self.scheduler = None + self.done_steps = 0 + self.total_steps = 0 + self.static_tensors: Dict[str, torch.Tensor] = {} + self.timesteps = None + self.dtype = None + self.output = None + self.error_msg: Optional[str] = None + self.aborted = False + + def _set_width_and_height(self): + self._width, self._height = map(int, re.split(r"[^\d]+", self._size)) + + def set_generate_kwargs(self, generate_kwargs: Dict): + self._generate_kwargs = {k: v for k, v in generate_kwargs.items()} + + @property + def prompt(self): + return self._prompt + + @property + def n(self): + return self._n + + @property + def size(self): + return self._size + + @property + def response_format(self): + return self._response_format + + @property + def kwargs(self): + return self._kwargs + + @property + def width(self): + return self._width + + @property + def height(self): + return self._height + + @property + def generate_kwargs(self): + return self._generate_kwargs + + @property + def request_id(self): + return self._unique_id + + +class FluxBatchSchedulerActor(xo.StatelessActor): + @classmethod + def gen_uid(cls, model_uid: str): + return f"{model_uid}-scheduler-actor" + + def __init__(self): + from ....device_utils import get_available_device + + super().__init__() + self._waiting_queue: deque[Text2ImageRequest] = deque() # type: ignore + self._running_queue: deque[Text2ImageRequest] = deque() # type: ignore + self._model = None + self._available_device = get_available_device() + self._id_to_req: Dict[str, Text2ImageRequest] = {} + + def set_model(self, model): + """ + Must use `set_model`. Otherwise, the model will be copied once. + """ + self._model = model + + async def __post_create__(self): + from ....isolation import Isolation + + self._isolation = Isolation( + asyncio.new_event_loop(), threaded=True, daemon=True + ) + self._isolation.start() + asyncio.run_coroutine_threadsafe(self.run(), loop=self._isolation.loop) + + async def __pre_destroy__(self): + try: + assert self._isolation is not None + self._isolation.stop() + del self._isolation + except Exception as e: + logger.debug( + f"Destroy scheduler actor failed, address: {self.address}, error: {e}" + ) + + async def add_request(self, unique_id: str, future, *args, **kwargs): + req = Text2ImageRequest(unique_id, future, *args, **kwargs) + rid = req.request_id + if rid is not None: + if rid in self._id_to_req: + raise KeyError(f"Request id: {rid} has already existed!") + self._id_to_req[rid] = req + self._waiting_queue.append(req) + + async def abort_request(self, req_id: str) -> str: + """ + Abort a request. + Abort a submitted request. 
If the request is finished or not found, this method will be a no-op. + """ + from ....core.utils import AbortRequestMessage + + if req_id not in self._id_to_req: + logger.info(f"Request id: {req_id} not found. No-op for xinference.") + return AbortRequestMessage.NOT_FOUND.name + else: + self._id_to_req[req_id].aborted = True + logger.info(f"Request id: {req_id} found to be aborted.") + return AbortRequestMessage.DONE.name + + def _handle_request( + self, + ) -> Optional[Tuple[List[Text2ImageRequest], List[Text2ImageRequest]]]: + """ + Every request may generate `n>=1` images. + Here we need to decide whether to wait or not based on the value of `n` of each request. + """ + if self._model is None: + return None + max_num_images = self._model.get_max_num_images_for_batching() + cur_num_images = 0 + abort_list: List[Text2ImageRequest] = [] + # currently, FCFS strategy + running_list: List[Text2ImageRequest] = [] + while len(self._running_queue) > 0: + req = self._running_queue.popleft() + if req.aborted: + abort_list.append(req) + else: + running_list.append(req) + cur_num_images += req.n + + # Remove all the aborted requests in the waiting queue + waiting_tmp_list: List[Text2ImageRequest] = [] + while len(self._waiting_queue) > 0: + req = self._waiting_queue.popleft() + if req.aborted: + abort_list.append(req) + else: + waiting_tmp_list.append(req) + self._waiting_queue.extend(waiting_tmp_list) + + waiting_list: List[Text2ImageRequest] = [] + while len(self._waiting_queue) > 0: + req = self._waiting_queue[0] + if req.n + cur_num_images <= max_num_images: + waiting_list.append(self._waiting_queue.popleft()) + cur_num_images += req.n + else: + logger.warning( + f"Current queue is full, with an upper limit of max_num_images: {max_num_images}. " + f"Requests will continue to wait." + ) + break + + return waiting_list + running_list, abort_list + + @staticmethod + def _empty_cache(): + from ....device_utils import empty_cache + + empty_cache() + + async def step(self): + res = self._handle_request() + if res is None: + return + req_list, abort_list = res + # handle abort + if abort_list: + for r in abort_list: + r.future.set_exception( + RuntimeError( + f"Request: {r.request_id} has been cancelled by another `abort_request` request." 
+ ) + ) + self._id_to_req.pop(r.request_id, None) + if not req_list: + return + _batch_text_to_image(self._model, req_list, self._available_device) + # handle results + for r in req_list: + if r.error_msg is not None: + r.future.set_exception(ValueError(r.error_msg)) + self._id_to_req.pop(r.request_id, None) + continue + if r.output is not None: + r.future.set_result( + handle_image_result(r.response_format, r.output.images) + ) + self._id_to_req.pop(r.request_id, None) + else: + self._running_queue.append(r) + self._empty_cache() + + async def run(self): + try: + while True: + # wait 10ms + await asyncio.sleep(0.01) + await self.step() + except Exception as e: + logger.exception( + f"Scheduler actor uid: {self.uid}, address: {self.address} run with error: {e}" + ) + + +def _cat_tensors(infos: List[Dict]) -> Dict: + keys = infos[0].keys() + res = {} + for k in keys: + tmp = [info[k] for info in infos] + res[k] = torch.cat(tmp) + return res + + +@typing.no_type_check +@torch.inference_mode() +def _batch_text_to_image_internal( + model_cls: "DiffusionModel", + req_list: List[Text2ImageRequest], + available_device: str, +): + from diffusers.pipelines.flux.pipeline_flux import ( + calculate_shift, + retrieve_timesteps, + ) + from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput + from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, + ) + + device = model_cls._model._execution_device + height, width = req_list[0].height, req_list[0].width + cur_batch_max_sequence_length = [ + r.generate_kwargs.get("max_sequence_length", DEFAULT_MAX_SEQUENCE_LENGTH) + for r in req_list + if not r.is_encode + ] + for r in req_list: + if r.is_encode: + generate_kwargs = model_cls._model_spec.default_generate_config.copy() + generate_kwargs.update({k: v for k, v in r.kwargs.items() if v is not None}) + model_cls._filter_kwargs(model_cls._model, generate_kwargs) + r.set_generate_kwargs(generate_kwargs) + + # check max_sequence_length + max_sequence_length = r.generate_kwargs.get( + "max_sequence_length", DEFAULT_MAX_SEQUENCE_LENGTH + ) + if ( + cur_batch_max_sequence_length + and max_sequence_length != cur_batch_max_sequence_length[0] + ): + r.is_encode = False + r.error_msg = ( + f"The max_sequence_length of the current request: {max_sequence_length} is " + f"different from the setting in the running batch: {cur_batch_max_sequence_length[0]}, " + f"please be consistent." + ) + continue + + num_images_per_prompt = r.n + callback_on_step_end_tensor_inputs = r.generate_kwargs.get( + "callback_on_step_end_tensor_inputs", ["latents"] + ) + num_inference_steps = r.generate_kwargs.get("num_inference_steps", 28) + guidance_scale = r.generate_kwargs.get("guidance_scale", 7.0) + generator = None + seed = r.generate_kwargs.get("seed", None) + if seed is not None: + generator = torch.Generator(device=available_device) # type: ignore + if seed != -1: + generator = generator.manual_seed(seed) + latents = None + timesteps = None + + # Each request must build its own scheduler instance, + # otherwise the mixing of variables at `scheduler.STEP` will result in an error. 
+ r.scheduler = FlowMatchEulerDiscreteScheduler( + model_cls._model.scheduler.config.num_train_timesteps, + model_cls._model.scheduler.config.shift, + model_cls._model.scheduler.config.use_dynamic_shifting, + model_cls._model.scheduler.config.base_shift, + model_cls._model.scheduler.config.max_shift, + model_cls._model.scheduler.config.base_image_seq_len, + model_cls._model.scheduler.config.max_image_seq_len, + ) + + # check inputs + model_cls._model.check_inputs( + r.prompt, + None, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + # handle prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = model_cls._model.encode_prompt( + prompt=r.prompt, + prompt_2=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=None, + ) + + # Prepare latent variables + num_channels_latents = model_cls._model.transformer.config.in_channels // 4 + latents, latent_image_ids = model_cls._model.prepare_latents( + num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + + mu = calculate_shift( + image_seq_len, + r.scheduler.config["base_image_seq_len"], + r.scheduler.config["max_image_seq_len"], + r.scheduler.config["base_shift"], + r.scheduler.config["max_shift"], + ) + timesteps, num_inference_steps = retrieve_timesteps( + r.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + + # handle guidance + if model_cls._model.transformer.config.guidance_embeds: + guidance = torch.full( + [1], guidance_scale, device=device, dtype=torch.float32 + ) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + r.static_tensors["latents"] = latents + r.static_tensors["guidance"] = guidance + r.static_tensors["pooled_prompt_embeds"] = pooled_prompt_embeds + r.static_tensors["prompt_embeds"] = prompt_embeds + r.static_tensors["text_ids"] = text_ids + r.static_tensors["latent_image_ids"] = latent_image_ids + r.timesteps = timesteps + r.dtype = latents.dtype + r.total_steps = len(timesteps) + r.is_encode = False + + running_req_list = [r for r in req_list if r.error_msg is None] + static_tensors = _cat_tensors([r.static_tensors for r in running_req_list]) + + # Do a step + timestep_tmp = [] + for r in running_req_list: + timestep_tmp.append(r.timesteps[r.done_steps].expand(r.n).to(r.dtype)) + r.done_steps += 1 + timestep = torch.cat(timestep_tmp) + noise_pred = model_cls._model.transformer( + hidden_states=static_tensors["latents"], + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=static_tensors["guidance"], + pooled_projections=static_tensors["pooled_prompt_embeds"], + encoder_hidden_states=static_tensors["prompt_embeds"], + txt_ids=static_tensors["text_ids"], + img_ids=static_tensors["latent_image_ids"], + joint_attention_kwargs=None, + return_dict=False, + )[0] + + # update latents + start_idx = 0 + for r in running_req_list: + n = r.n + # handle diffusion scheduler step + _noise_pred = noise_pred[start_idx : start_idx + n, ::] + 
_timestep = timestep[start_idx] + latents_out = r.scheduler.step( + _noise_pred, _timestep, r.static_tensors["latents"], return_dict=False + )[0] + r.static_tensors["latents"] = latents_out + start_idx += n + + logger.info( + f"Request {r.request_id} has done {r.done_steps} / {r.total_steps} steps." + ) + + # process result + if r.done_steps == r.total_steps: + output_type = r.generate_kwargs.get("output_type", "pil") + _latents = r.static_tensors["latents"] + if output_type == "latent": + image = _latents + else: + _latents = model_cls._model._unpack_latents( + _latents, height, width, model_cls._model.vae_scale_factor + ) + _latents = ( + _latents / model_cls._model.vae.config.scaling_factor + ) + model_cls._model.vae.config.shift_factor + image = model_cls._model.vae.decode(_latents, return_dict=False)[0] + image = model_cls._model.image_processor.postprocess( + image, output_type=output_type + ) + + is_padded = r.generate_kwargs.get("is_padded", None) + origin_size = r.generate_kwargs.get("origin_size", None) + + if is_padded and origin_size: + new_images = [] + x, y = origin_size + for img in image: + new_images.append(img.crop((0, 0, x, y))) + image = new_images + + r.output = FluxPipelineOutput(images=image) + logger.info( + f"Request {r.request_id} has completed total {r.total_steps} steps." + ) + + +def _batch_text_to_image( + model_cls: "DiffusionModel", + req_list: List[Text2ImageRequest], + available_device: str, +): + from ....core.model import OutOfMemoryError + + try: + _batch_text_to_image_internal(model_cls, req_list, available_device) + except OutOfMemoryError: + logger.exception( + f"Batch text_to_image out of memory. " + f"Xinference will restart the model: {model_cls._model_uid}. " + f"Please be patient for a few moments." + ) + # Just kill the process and let xinference auto-recover the model + os._exit(1) + except Exception as e: + logger.exception(f"Internal error for batch text_to_image: {e}.") + # If internal error happens, just skip all the requests in this batch. + # If not handle here, the client will hang. + for r in req_list: + r.error_msg = str(e) diff --git a/xinference/model/image/stable_diffusion/core.py b/xinference/model/image/stable_diffusion/core.py index 3fe2c803c1..ae9b6e4bd4 100644 --- a/xinference/model/image/stable_diffusion/core.py +++ b/xinference/model/image/stable_diffusion/core.py @@ -12,31 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import base64 import contextlib import gc import inspect import itertools import logging -import os import re import sys -import time -import uuid import warnings -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from io import BytesIO from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import PIL.Image import torch from PIL import ImageOps -from ....constants import XINFERENCE_IMAGE_DIR from ....device_utils import get_available_device, move_model_to_available_device -from ....types import Image, ImageList, LoRA +from ....types import LoRA from ..sdapi import SDAPIDiffusionModelMixin +from ..utils import handle_image_result if TYPE_CHECKING: from ....core.progress_tracker import Progressor @@ -297,6 +290,9 @@ def _load_to_device(self, model): if self._kwargs.get("vae_tiling", False): model.enable_vae_tiling() + def get_max_num_images_for_batching(self): + return self._kwargs.get("max_num_images", 16) + @staticmethod def _get_scheduler(model: Any, sampler_name: str): if not sampler_name or sampler_name == "default": @@ -476,28 +472,7 @@ def _call_model( if return_images: return images - if response_format == "url": - os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True) - image_list = [] - with ThreadPoolExecutor() as executor: - for img in images: - path = os.path.join(XINFERENCE_IMAGE_DIR, uuid.uuid4().hex + ".jpg") - image_list.append(Image(url=path, b64_json=None)) - executor.submit(img.save, path, "jpeg") - return ImageList(created=int(time.time()), data=image_list) - elif response_format == "b64_json": - - def _gen_base64_image(_img): - buffered = BytesIO() - _img.save(buffered, format="jpeg") - return base64.b64encode(buffered.getvalue()).decode() - - with ThreadPoolExecutor() as executor: - results = list(map(partial(executor.submit, _gen_base64_image), images)) # type: ignore - image_list = [Image(url=None, b64_json=s.result()) for s in results] # type: ignore - return ImageList(created=int(time.time()), data=image_list) - else: - raise ValueError(f"Unsupported response format: {response_format}") + return handle_image_result(response_format, images) @classmethod def _filter_kwargs(cls, model, kwargs: dict): diff --git a/xinference/model/image/utils.py b/xinference/model/image/utils.py index bc4cbc350d..53df69e549 100644 --- a/xinference/model/image/utils.py +++ b/xinference/model/image/utils.py @@ -11,16 +11,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +import base64 +import os +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from io import BytesIO +from typing import TYPE_CHECKING, Optional -from .core import ImageModelFamilyV1 +from ...constants import XINFERENCE_IMAGE_DIR +from ...types import Image, ImageList + +if TYPE_CHECKING: + from .core import ImageModelFamilyV1 def get_model_version( - image_model: ImageModelFamilyV1, controlnet: Optional[ImageModelFamilyV1] + image_model: "ImageModelFamilyV1", controlnet: Optional["ImageModelFamilyV1"] ) -> str: return ( image_model.model_name if controlnet is None else f"{image_model.model_name}--{controlnet.model_name}" ) + + +def handle_image_result(response_format: str, images) -> ImageList: + if response_format == "url": + os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True) + image_list = [] + with ThreadPoolExecutor() as executor: + for img in images: + path = os.path.join(XINFERENCE_IMAGE_DIR, uuid.uuid4().hex + ".jpg") + image_list.append(Image(url=path, b64_json=None)) + executor.submit(img.save, path, "jpeg") + return ImageList(created=int(time.time()), data=image_list) + elif response_format == "b64_json": + + def _gen_base64_image(_img): + buffered = BytesIO() + _img.save(buffered, format="jpeg") + return base64.b64encode(buffered.getvalue()).decode() + + with ThreadPoolExecutor() as executor: + results = list(map(partial(executor.submit, _gen_base64_image), images)) # type: ignore + image_list = [Image(url=None, b64_json=s.result()) for s in results] # type: ignore + return ImageList(created=int(time.time()), data=image_list) + else: + raise ValueError(f"Unsupported response format: {response_format}")
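
As a closing note on ``handle_image_result``: when ``response_format`` is ``b64_json``, each returned entry carries the JPEG bytes base64-encoded, so a consumer can recover a ``PIL`` image with a plain decode. The helper below is only an illustrative sketch; ``decode_b64_image`` is not part of this patch.

.. code-block:: python

    import base64
    from io import BytesIO

    from PIL import Image as PILImage


    def decode_b64_image(b64_json: str) -> PILImage.Image:
        """Reverse handle_image_result's b64_json encoding: base64 -> JPEG bytes -> PIL image."""
        return PILImage.open(BytesIO(base64.b64decode(b64_json)))


    # Usage, assuming the ImageList produced by handle_image_result(...) behaves like a dict:
    # first_image = decode_b64_image(image_list["data"][0]["b64_json"])
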