From 5e3f254d48383f37d849dd16db564ad9449e5163 Mon Sep 17 00:00:00 2001
From: hainaweiben <112739514+hainaweiben@users.noreply.github.com>
Date: Fri, 12 Jul 2024 17:56:13 +0800
Subject: [PATCH] ENH: Added the parameter 'worker_ip' to the 'register' model. (#1773)

Co-authored-by: wuzhaoxin <15667065080@162.com>
---
 xinference/api/restful_api.py               |   4 +-
 xinference/client/restful/restful_client.py |  12 +-
 xinference/core/supervisor.py               |  73 +++++++--
 xinference/core/worker.py                   | 155 +++++++++++++++++++-
 xinference/deploy/cmdline.py                |   5 +
 xinference/model/image/custom.py            |   2 +-
 6 files changed, 226 insertions(+), 25 deletions(-)

diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index 3c70902da2..7d28ffe8b1 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -133,6 +133,7 @@ class SpeechRequest(BaseModel):

 class RegisterModelRequest(BaseModel):
     model: str
+    worker_ip: Optional[str]
     persist: bool


@@ -1639,11 +1640,12 @@ async def query_engines_by_model_name(self, model_name: str) -> JSONResponse:
     async def register_model(self, model_type: str, request: Request) -> JSONResponse:
         body = RegisterModelRequest.parse_obj(await request.json())
         model = body.model
+        worker_ip = body.worker_ip
         persist = body.persist

         try:
             await (await self._get_supervisor_ref()).register_model(
-                model_type, model, persist
+                model_type, model, persist, worker_ip
             )
         except ValueError as re:
             logger.error(re, exc_info=True)
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index ec5b2dac04..54be488748 100644
--- a/xinference/client/restful/restful_client.py
+++ b/xinference/client/restful/restful_client.py
@@ -1101,7 +1101,13 @@ def describe_model(self, model_uid: str):
             )
         return response.json()

-    def register_model(self, model_type: str, model: str, persist: bool):
+    def register_model(
+        self,
+        model_type: str,
+        model: str,
+        persist: bool,
+        worker_ip: Optional[str] = None,
+    ):
         """
         Register a custom model.

@@ -1111,6 +1117,8 @@ def register_model(self, model_type: str, model: str, persist: bool):
             The type of model.
         model: str
             The model definition. (refer to: https://inference.readthedocs.io/en/latest/models/custom.html)
+        worker_ip: Optional[str]
+            The IP address of the worker on which the model is running.
         persist: bool

         Raises
         ------
         RuntimeError
             Report failure to register the custom model. Provide details of failure through error message.
""" url = f"{self.base_url}/v1/model_registrations/{model_type}" - request_body = {"model": model, "persist": persist} + request_body = {"model": model, "worker_ip": worker_ip, "persist": persist} response = requests.post(url, json=request_body, headers=self._headers) if response.status_code != 200: raise RuntimeError( diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index db4105c8ee..873a0b6ab8 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -513,10 +513,15 @@ def sort_helper(item): assert isinstance(item["model_name"], str) return item.get("model_name").lower() + ret = [] + if not self.is_local_deployment(): + workers = list(self._worker_address_to_worker.values()) + for worker in workers: + ret.extend(await worker.list_model_registrations(model_type, detailed)) + if model_type == "LLM": from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families - ret = [] for family in BUILTIN_LLM_FAMILIES: if detailed: ret.append(await self._to_llm_reg(family, True)) @@ -535,7 +540,6 @@ def sort_helper(item): from ..model.embedding import BUILTIN_EMBEDDING_MODELS from ..model.embedding.custom import get_user_defined_embeddings - ret = [] for model_name, family in BUILTIN_EMBEDDING_MODELS.items(): if detailed: ret.append( @@ -560,7 +564,6 @@ def sort_helper(item): from ..model.image import BUILTIN_IMAGE_MODELS from ..model.image.custom import get_user_defined_images - ret = [] for model_name, family in BUILTIN_IMAGE_MODELS.items(): if detailed: ret.append(await self._to_image_model_reg(family, is_builtin=True)) @@ -583,7 +586,6 @@ def sort_helper(item): from ..model.audio import BUILTIN_AUDIO_MODELS from ..model.audio.custom import get_user_defined_audios - ret = [] for model_name, family in BUILTIN_AUDIO_MODELS.items(): if detailed: ret.append(await self._to_audio_model_reg(family, is_builtin=True)) @@ -606,7 +608,6 @@ def sort_helper(item): from ..model.rerank import BUILTIN_RERANK_MODELS from ..model.rerank.custom import get_user_defined_reranks - ret = [] for model_name, family in BUILTIN_RERANK_MODELS.items(): if detailed: ret.append(await self._to_rerank_model_reg(family, is_builtin=True)) @@ -646,7 +647,15 @@ def sort_helper(item): raise ValueError(f"Unsupported model type: {model_type}") @log_sync(logger=logger) - def get_model_registration(self, model_type: str, model_name: str) -> Any: + async def get_model_registration(self, model_type: str, model_name: str) -> Any: + # search in worker first + if not self.is_local_deployment(): + workers = list(self._worker_address_to_worker.values()) + for worker in workers: + f = await worker.get_model_registration(model_type, model_name) + if f is not None: + return f + if model_type == "LLM": from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families @@ -705,6 +714,13 @@ async def query_engines_by_model_name(self, model_name: str): from ..model.llm.llm_family import LLM_ENGINES + # search in worker first + workers = list(self._worker_address_to_worker.values()) + for worker in workers: + res = await worker.query_engines_by_model_name(model_name) + if res is not None: + return res + if model_name not in LLM_ENGINES: raise ValueError(f"Model {model_name} not found") @@ -718,7 +734,13 @@ async def query_engines_by_model_name(self, model_name: str): return engine_params @log_async(logger=logger) - async def register_model(self, model_type: str, model: str, persist: bool): + async def register_model( + self, + model_type: str, + model: str, + persist: bool, + 
+        worker_ip: Optional[str] = None,
+    ):
         if model_type in self._custom_register_type_to_cls:
             (
                 model_spec_cls,
                 register_fn,
                 unregister_fn,
                 generate_fn,
             ) = self._custom_register_type_to_cls[model_type]

-            if not self.is_local_deployment():
-                workers = list(self._worker_address_to_worker.values())
-                for worker in workers:
-                    await worker.register_model(model_type, model, persist)
+            target_ip_worker_ref = (
+                self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+            )
+            if (
+                worker_ip is not None
+                and not self.is_local_deployment()
+                and target_ip_worker_ref is None
+            ):
+                raise ValueError(
+                    f"Worker ip address {worker_ip} is not in the cluster."
+                )
+
+            if target_ip_worker_ref:
+                await target_ip_worker_ref.register_model(model_type, model, persist)
+                return

             model_spec = model_spec_cls.parse_raw(model)
             try:
@@ -738,6 +771,8 @@ async def register_model(self, model_type: str, model: str, persist: bool):
                 await self._cache_tracker_ref.record_model_version(
                     generate_fn(model_spec), self.address
                 )
+            except ValueError as e:
+                raise e
             except Exception as e:
                 unregister_fn(model_spec.model_name, raise_error=False)
                 raise e
@@ -748,13 +783,14 @@ async def register_model(self, model_type: str, model: str, persist: bool):
     async def unregister_model(self, model_type: str, model_name: str):
         if model_type in self._custom_register_type_to_cls:
             _, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
-            unregister_fn(model_name)
-            await self._cache_tracker_ref.unregister_model_version(model_name)
+            unregister_fn(model_name, False)

             if not self.is_local_deployment():
                 workers = list(self._worker_address_to_worker.values())
                 for worker in workers:
-                    await worker.unregister_model(model_name)
+                    await worker.unregister_model(model_type, model_name)
+
+            await self._cache_tracker_ref.unregister_model_version(model_name)
         else:
             raise ValueError(f"Unsupported model type: {model_type}")

@@ -825,6 +861,14 @@ async def launch_builtin_model(
         download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
         **kwargs,
     ) -> str:
+        # search in worker first
+        if not self.is_local_deployment():
+            workers = list(self._worker_address_to_worker.values())
+            for worker in workers:
+                res = await worker.get_model_registration(model_type, model_name)
+                if res is not None:
+                    worker_ip = worker.address.split(":")[0]
+
         target_ip_worker_ref = (
             self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
         )
@@ -877,6 +921,7 @@ async def _launch_one_model(_replica_model_uid):
                 )
                 replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
                 nonlocal model_type
+
                 worker_ref = (
                     target_ip_worker_ref
                     if target_ip_worker_ref is not None
diff --git a/xinference/core/worker.py b/xinference/core/worker.py
index 286440e1a9..e36d32e99e 100644
--- a/xinference/core/worker.py
+++ b/xinference/core/worker.py
@@ -212,12 +212,14 @@ async def __post_create__(self):
         from ..model.audio import (
             CustomAudioModelFamilyV1,
+            generate_audio_description,
             get_audio_model_descriptions,
             register_audio,
             unregister_audio,
         )
         from ..model.embedding import (
             CustomEmbeddingModelSpec,
+            generate_embedding_description,
             get_embedding_model_descriptions,
             register_embedding,
             unregister_embedding,
         )
@@ -230,36 +232,56 @@ async def __post_create__(self):
         from ..model.image import (
             CustomImageModelFamilyV1,
+            generate_image_description,
             get_image_model_descriptions,
             register_image,
             unregister_image,
         )
         from ..model.llm import (
             CustomLLMFamilyV1,
+            generate_llm_description,
             get_llm_model_descriptions,
             register_llm,
             unregister_llm,
         )
         from ..model.rerank import (
             CustomRerankModelSpec,
+            generate_rerank_description,
             get_rerank_model_descriptions,
             register_rerank,
             unregister_rerank,
         )

         self._custom_register_type_to_cls: Dict[str, Tuple] = {  # type: ignore
-            "LLM": (CustomLLMFamilyV1, register_llm, unregister_llm),
+            "LLM": (
+                CustomLLMFamilyV1,
+                register_llm,
+                unregister_llm,
+                generate_llm_description,
+            ),
             "embedding": (
                 CustomEmbeddingModelSpec,
                 register_embedding,
                 unregister_embedding,
+                generate_embedding_description,
+            ),
+            "rerank": (
+                CustomRerankModelSpec,
+                register_rerank,
+                unregister_rerank,
+                generate_rerank_description,
             ),
-            "rerank": (CustomRerankModelSpec, register_rerank, unregister_rerank),
-            "audio": (CustomAudioModelFamilyV1, register_audio, unregister_audio),
             "image": (
                 CustomImageModelFamilyV1,
                 register_image,
                 unregister_image,
+                generate_image_description,
+            ),
+            "audio": (
+                CustomAudioModelFamilyV1,
+                register_audio,
+                unregister_audio,
+                generate_audio_description,
             ),
             "flexible": (
                 FlexibleModelSpec,
@@ -526,17 +548,23 @@ def _check_model_is_valid(self, model_name: str, model_format: Optional[str]):
             raise ValueError(f"{model_name} model can't run on Darwin system.")

     @log_sync(logger=logger)
-    def register_model(self, model_type: str, model: str, persist: bool):
+    async def register_model(self, model_type: str, model: str, persist: bool):
         # TODO: centralized model registrations
         if model_type in self._custom_register_type_to_cls:
             (
                 model_spec_cls,
                 register_fn,
                 unregister_fn,
+                generate_fn,
             ) = self._custom_register_type_to_cls[model_type]
             model_spec = model_spec_cls.parse_raw(model)
             try:
                 register_fn(model_spec, persist)
+                await self._cache_tracker_ref.record_model_version(
+                    generate_fn(model_spec), self.address
+                )
+            except ValueError as e:
+                raise e
             except Exception as e:
                 unregister_fn(model_spec.model_name, raise_error=False)
                 raise e
@@ -544,14 +572,127 @@ def register_model(self, model_type: str, model: str, persist: bool):
             raise ValueError(f"Unsupported model type: {model_type}")

     @log_sync(logger=logger)
-    def unregister_model(self, model_type: str, model_name: str):
+    async def unregister_model(self, model_type: str, model_name: str):
         # TODO: centralized model registrations
         if model_type in self._custom_register_type_to_cls:
-            _, _, unregister_fn = self._custom_register_type_to_cls[model_type]
-            unregister_fn(model_name)
+            _, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
+            unregister_fn(model_name, False)
         else:
             raise ValueError(f"Unsupported model type: {model_type}")

+    @log_async(logger=logger)
+    async def list_model_registrations(
+        self, model_type: str, detailed: bool = False
+    ) -> List[Dict[str, Any]]:
+        def sort_helper(item):
+            assert isinstance(item["model_name"], str)
+            return item.get("model_name").lower()
+
+        if model_type == "LLM":
+            from ..model.llm import get_user_defined_llm_families
+
+            ret = []
+
+            for family in get_user_defined_llm_families():
+                ret.append({"model_name": family.model_name, "is_builtin": False})
+
+            ret.sort(key=sort_helper)
+            return ret
+        elif model_type == "embedding":
+            from ..model.embedding.custom import get_user_defined_embeddings
+
+            ret = []
+
+            for model_spec in get_user_defined_embeddings():
+                ret.append({"model_name": model_spec.model_name, "is_builtin": False})
+
+            ret.sort(key=sort_helper)
+            return ret
+        elif model_type == "image":
+            from ..model.image.custom import get_user_defined_images
+
+            ret = []
+
+            for model_spec in get_user_defined_images():
+                ret.append({"model_name": model_spec.model_name, "is_builtin": False})
+
+            ret.sort(key=sort_helper)
+            return ret
+        elif model_type == "audio":
+            from ..model.audio.custom import get_user_defined_audios
+
+            ret = []
+
+            for model_spec in get_user_defined_audios():
+                ret.append({"model_name": model_spec.model_name, "is_builtin": False})
+
+            ret.sort(key=sort_helper)
+            return ret
+        elif model_type == "rerank":
+            from ..model.rerank.custom import get_user_defined_reranks
+
+            ret = []
+
+            for model_spec in get_user_defined_reranks():
+                ret.append({"model_name": model_spec.model_name, "is_builtin": False})
+
+            ret.sort(key=sort_helper)
+            return ret
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_sync(logger=logger)
+    async def get_model_registration(self, model_type: str, model_name: str) -> Any:
+        if model_type == "LLM":
+            from ..model.llm import get_user_defined_llm_families
+
+            for f in get_user_defined_llm_families():
+                if f.model_name == model_name:
+                    return f
+        elif model_type == "embedding":
+            from ..model.embedding.custom import get_user_defined_embeddings
+
+            for f in get_user_defined_embeddings():
+                if f.model_name == model_name:
+                    return f
+        elif model_type == "image":
+            from ..model.image.custom import get_user_defined_images
+
+            for f in get_user_defined_images():
+                if f.model_name == model_name:
+                    return f
+        elif model_type == "audio":
+            from ..model.audio.custom import get_user_defined_audios
+
+            for f in get_user_defined_audios():
+                if f.model_name == model_name:
+                    return f
+        elif model_type == "rerank":
+            from ..model.rerank.custom import get_user_defined_reranks
+
+            for f in get_user_defined_reranks():
+                if f.model_name == model_name:
+                    return f
+        return None
+
+    @log_async(logger=logger)
+    async def query_engines_by_model_name(self, model_name: str):
+        from copy import deepcopy
+
+        from ..model.llm.llm_family import LLM_ENGINES
+
+        if model_name not in LLM_ENGINES:
+            return None
+
+        # filter llm_class
+        engine_params = deepcopy(LLM_ENGINES[model_name])
+        for engine in engine_params:
+            params = engine_params[engine]
+            for param in params:
+                del param["llm_class"]
+
+        return engine_params
+
     async def _get_model_ability(self, model: Any, model_type: str) -> List[str]:
         from ..model.llm.core import LLM
diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py
index 2fb84d95c9..fcff9b32ac 100644
--- a/xinference/deploy/cmdline.py
+++ b/xinference/deploy/cmdline.py
@@ -370,6 +370,9 @@ def worker(
     help="Type of model to register (default is 'LLM').",
 )
 @click.option("--file", "-f", type=str, help="Path to the model configuration file.")
+@click.option(
+    "--worker-ip", "-w", type=str, help="Specify the ip address of the worker."
+)
 @click.option(
     "--persist",
     "-p",
@@ -387,6 +390,7 @@ def register_model(
     endpoint: Optional[str],
     model_type: str,
     file: str,
+    worker_ip: str,
     persist: bool,
     api_key: Optional[str],
 ):
@@ -400,6 +404,7 @@ def register_model(
         client.register_model(
             model_type=model_type,
             model=model,
+            worker_ip=worker_ip,
             persist=persist,
         )
diff --git a/xinference/model/image/custom.py b/xinference/model/image/custom.py
index b002f6ded1..ff66ff8aa7 100644
--- a/xinference/model/image/custom.py
+++ b/xinference/model/image/custom.py
@@ -66,7 +66,7 @@ def register_image(model_spec: CustomImageModelFamilyV1, persist: bool):
             raise ValueError(f"Invalid model URI {model_uri}")

     persist_path = os.path.join(
-        XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_id}.json"
+        XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_name}.json"
     )
     os.makedirs(os.path.dirname(persist_path), exist_ok=True)
     with open(persist_path, "w") as f:
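
Usage note (illustrative, not part of the patch): with this change, a custom model can be pinned to a single worker at registration time instead of being broadcast to every worker. The sketch below assumes a supervisor endpoint at http://127.0.0.1:9997, a worker at 192.168.0.2, and a definition file named my-model.json written per https://inference.readthedocs.io/en/latest/models/custom.html; all three values are placeholders, not values taken from the patch.

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # assumed supervisor endpoint
    with open("my-model.json") as fd:  # hypothetical model definition file
        model = fd.read()
    # worker_ip must name a worker already in the cluster; per the supervisor
    # change above, an unknown IP raises
    # "Worker ip address ... is not in the cluster."
    client.register_model(
        model_type="LLM",
        model=model,
        persist=False,
        worker_ip="192.168.0.2",  # assumed worker address
    )

The equivalent CLI call would use the --worker-ip/-w flag added in cmdline.py, e.g.:

    xinference register --model-type LLM --file my-model.json --worker-ip 192.168.0.2

Omitting worker_ip keeps the previous behavior, in which the registration is visible cluster-wide.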