From e3a947ebddfc53b5e8ec723c1f632c2b895edef1 Mon Sep 17 00:00:00 2001
From: "Xuye (Chris) Qin"
Date: Thu, 11 Apr 2024 15:35:46 +0800
Subject: [PATCH] BUG: fix custom vision model (#1280)

---
 xinference/model/llm/llm_family.py            | 15 ++++++
 xinference/model/llm/pytorch/core.py          | 50 ++++++++----------
 xinference/model/llm/tests/test_llm_family.py | 14 ++++++
 3 files changed, 51 insertions(+), 28 deletions(-)

diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py
index 15ff0db84c..a131a25169 100644
--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -199,6 +199,21 @@ def parse_raw(
                 )
             llm_spec.prompt_style = BUILTIN_LLM_PROMPT_STYLE[prompt_style_name]
 
+        # check model ability, registering LLM only provides generate and chat
+        # but for vision models, we add back the abilities so that
+        # gradio chat interface can be generated properly
+        if (
+            llm_spec.model_family != "other"
+            and llm_spec.model_family
+            in {
+                family.model_name
+                for family in BUILTIN_LLM_FAMILIES
+                if "vision" in family.model_ability
+            }
+            and "vision" not in llm_spec.model_ability
+        ):
+            llm_spec.model_ability.append("vision")
+
         return llm_spec

diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py
index 88ca44ee7b..8cac269289 100644
--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -42,6 +42,25 @@
 
 logger = logging.getLogger(__name__)
 
+NON_DEFAULT_MODEL_LIST: List[str] = [
+    "baichuan-chat",
+    "baichuan-2-chat",
+    "vicuna-v1.3",
+    "falcon",
+    "falcon-instruct",
+    "chatglm",
+    "chatglm2",
+    "chatglm2-32k",
+    "chatglm2-128k",
+    "llama-2",
+    "llama-2-chat",
+    "internlm2-chat",
+    "qwen-vl-chat",
+    "OmniLMM",
+    "yi-vl-chat",
+    "deepseek-vl-chat",
+]
+
 
 class PytorchModel(LLM):
     def __init__(
@@ -233,17 +252,7 @@ def match(
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         model_family = llm_family.model_family or llm_family.model_name
-        if model_family in [
-            "baichuan-chat",
-            "vicuna-v1.3",
-            "falcon",
-            "falcon-instruct",
-            "chatglm",
-            "chatglm2",
-            "chatglm2-32k",
-            "llama-2",
-            "llama-2-chat",
-        ]:
+        if model_family in NON_DEFAULT_MODEL_LIST:
             return False
         if "generate" not in llm_family.model_ability:
             return False
@@ -452,23 +461,8 @@ def match(
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
-        if llm_family.model_name in [
-            "baichuan-chat",
-            "baichuan-2-chat",
-            "vicuna-v1.3",
-            "falcon",
-            "falcon-instruct",
-            "chatglm",
-            "chatglm2",
-            "chatglm2-32k",
-            "llama-2",
-            "llama-2-chat",
-            "internlm2-chat",
-            "qwen-vl-chat",
-            "OmniLMM",
-            "yi-vl-chat",
-            "deepseek-vl-chat",
-        ]:
+        model_family = llm_family.model_family or llm_family.model_name
+        if model_family in NON_DEFAULT_MODEL_LIST:
             return False
         if "chat" not in llm_family.model_ability:
             return False

diff --git a/xinference/model/llm/tests/test_llm_family.py b/xinference/model/llm/tests/test_llm_family.py
index fe093822a2..b1b6662dc1 100644
--- a/xinference/model/llm/tests/test_llm_family.py
+++ b/xinference/model/llm/tests/test_llm_family.py
@@ -992,6 +992,20 @@ def test_parse_prompt_style():
     model_spec = CustomLLMFamilyV1.parse_raw(bytes(llm_family.json(), "utf8"))
     assert model_spec.model_name == llm_family.model_name
 
+    # test vision
+    llm_family = CustomLLMFamilyV1(
+        version=1,
+        model_type="LLM",
+        model_name="test_LLM",
+        model_lang=["en"],
+        model_ability=["chat", "generate"],
+        model_specs=[hf_spec, ms_spec],
+        model_family="qwen-vl-chat",
+        prompt_style="qwen-vl-chat",
+    )
+    model_spec = CustomLLMFamilyV1.parse_raw(bytes(llm_family.json(), "utf-8"))
+    assert "vision" in model_spec.model_ability
+
     # error: missing model_family
     llm_family = CustomLLMFamilyV1(
         version=1,
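
Note: below is a minimal standalone sketch of the "vision" ability backfill
that this patch adds to CustomLLMFamilyV1.parse_raw. The Family dataclass,
the BUILTIN_FAMILIES sample data, and the backfill_vision helper are
illustrative stand-ins rather than xinference's real classes; only the
control flow mirrors the patch.

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class Family:
        model_name: str
        model_ability: List[str] = field(default_factory=list)


    # Stand-in for BUILTIN_LLM_FAMILIES: one vision family, one chat-only family.
    BUILTIN_FAMILIES = [
        Family("qwen-vl-chat", ["chat", "vision"]),
        Family("llama-2-chat", ["chat"]),
    ]


    def backfill_vision(custom: Family, declared_family: str) -> None:
        # Mirrors the parse_raw change: if a custom model declares a builtin
        # vision family but was registered without "vision", add the ability
        # back so downstream consumers (e.g. the gradio chat UI) see it.
        vision_families = {
            f.model_name for f in BUILTIN_FAMILIES if "vision" in f.model_ability
        }
        if (
            declared_family != "other"
            and declared_family in vision_families
            and "vision" not in custom.model_ability
        ):
            custom.model_ability.append("vision")


    # Usage: a custom model registered with only chat/generate gains "vision".
    custom = Family("my-custom-vl", ["chat", "generate"])
    backfill_vision(custom, "qwen-vl-chat")
    assert custom.model_ability == ["chat", "generate", "vision"]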