Skip to content

Commit

Permalink
FEAT: add rembg flexible model to remove background of image (#1917)
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye authored Jul 26, 2024
1 parent d5562f8 commit aa51ff2
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 5 deletions.
2 changes: 2 additions & 0 deletions doc/source/models/builtin/image/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ The following is a list of built-in image models in Xinference:
stable-diffusion-xl-base-1.0

stable-diffusion-xl-inpainting

stable-diffusion-xl-inpainting

8 changes: 4 additions & 4 deletions xinference/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
self.future_or_queue = future_or_queue
# Record error message when this request has error.
# Must set stopped=True when this field is set.
self.error_msg: Optional[str] = None
self.error_msg: Optional[str] = None # type: ignore
# For compatibility. Record some extra parameters for some special cases.
self.extra_kwargs = {}

Expand Down Expand Up @@ -295,11 +295,11 @@ def gen_uid(cls, model_uid: str, replica_id: str):

def __init__(self):
super().__init__()
self._waiting_queue: deque[InferenceRequest] = deque()
self._running_queue: deque[InferenceRequest] = deque()
self._waiting_queue: deque[InferenceRequest] = deque() # type: ignore
self._running_queue: deque[InferenceRequest] = deque() # type: ignore
self._model = None
self._id_to_req = {}
self._abort_req_ids: Set[str] = set()
self._abort_req_ids: Set[str] = set() # type: ignore
self._isolation = None

async def __post_create__(self):
Expand Down
1 change: 1 addition & 0 deletions xinference/model/flexible/launchers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .image_process_launcher import launcher as image_process
from .transformers_launcher import launcher as transformers
70 changes: 70 additions & 0 deletions xinference/model/flexible/launchers/image_process_launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
from io import BytesIO

import PIL.Image
import PIL.ImageOps

from ....types import Image
from ..core import FlexibleModel, FlexibleModelSpec


class ImageRemoveBackgroundModel(FlexibleModel):
def infer(self, **kwargs):
invert = kwargs.get("invert", False)
b64_image: str = kwargs.get("image") # type: ignore
only_mask = kwargs.pop("only_mask", True)
image_format = kwargs.pop("image_format", "PNG")
if not b64_image:
raise ValueError("No image found to remove background")
image = base64.b64decode(b64_image)

try:
from rembg import remove
except ImportError:
error_message = "Failed to import module 'rembg'"
installation_guide = [
"Please make sure 'rembg' is installed. ",
"You can install it by visiting the installation section of the git repo:\n",
"https://github.com/danielgatis/rembg?tab=readme-ov-file#installation",
]

raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

im = PIL.Image.open(BytesIO(image))
om = remove(im, only_mask=only_mask, **kwargs)
if invert:
om = PIL.ImageOps.invert(om)

buffered = BytesIO()
om.save(buffered, format=image_format)
img_str = base64.b64encode(buffered.getvalue()).decode()
return Image(url=None, b64_json=img_str)


def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
task = kwargs.get("task")
device = kwargs.get("device")

if task == "remove_background":
return ImageRemoveBackgroundModel(
model_uid=model_uid,
model_path=model_spec.model_uri, # type: ignore
device=device,
config=kwargs,
)
else:
raise ValueError(f"Unknown Task for image processing: {task}")
2 changes: 1 addition & 1 deletion xinference/model/llm/pytorch/cogvlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
prompt, system_prompt=system_prompt, chat_history=chat_history
)

input_by_model: dict = self._model.build_conversation_input_ids(
input_by_model: dict = self._model.build_conversation_input_ids( # type: ignore
self._tokenizer,
query=query,
history=history,
Expand Down

0 comments on commit aa51ff2

Please sign in to comment.