BUG: fix custom vision model (#1280)
qinxuye authored Apr 11, 2024
1 parent a7f0c3b commit e3a947e
Showing 3 changed files with 51 additions and 28 deletions.
15 changes: 15 additions & 0 deletions xinference/model/llm/llm_family.py
@@ -199,6 +199,21 @@ def parse_raw(
                 )
             llm_spec.prompt_style = BUILTIN_LLM_PROMPT_STYLE[prompt_style_name]
 
+        # check model ability, registering LLM only provides generate and chat
+        # but for vision models, we add back the abilities so that
+        # gradio chat interface can be generated properly
+        if (
+            llm_spec.model_family != "other"
+            and llm_spec.model_family
+            in {
+                family.model_name
+                for family in BUILTIN_LLM_FAMILIES
+                if "vision" in family.model_ability
+            }
+            and "vision" not in llm_spec.model_ability
+        ):
+            llm_spec.model_ability.append("vision")
+
         return llm_spec


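The hunk above restores the "vision" ability for custom registrations whose model_family is a builtin vision family. A minimal, self-contained sketch of that logic, using stand-in data structures (BuiltinFamily and the sample entries below are illustrative, not xinference's actual types):

    # Sketch of the ability-restoration check; BuiltinFamily and the sample
    # families are hypothetical stand-ins for xinference's real family objects.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class BuiltinFamily:
        model_name: str
        model_ability: List[str] = field(default_factory=list)

    BUILTIN_LLM_FAMILIES = [
        BuiltinFamily("llama-2-chat", ["generate", "chat"]),
        BuiltinFamily("qwen-vl-chat", ["chat", "vision"]),
    ]

    custom_family = "qwen-vl-chat"         # model_family of the custom registration
    custom_ability = ["chat", "generate"]  # registration alone only provides these

    # Collect the builtin families that carry the "vision" ability ...
    vision_families = {
        f.model_name for f in BUILTIN_LLM_FAMILIES if "vision" in f.model_ability
    }
    # ... and add "vision" back if the custom model derives from one of them.
    if (
        custom_family != "other"
        and custom_family in vision_families
        and "vision" not in custom_ability
    ):
        custom_ability.append("vision")

    assert custom_ability == ["chat", "generate", "vision"]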
50 changes: 22 additions & 28 deletions xinference/model/llm/pytorch/core.py
@@ -42,6 +42,25 @@
 
 logger = logging.getLogger(__name__)
 
+NON_DEFAULT_MODEL_LIST: List[str] = [
+    "baichuan-chat",
+    "baichuan-2-chat",
+    "vicuna-v1.3",
+    "falcon",
+    "falcon-instruct",
+    "chatglm",
+    "chatglm2",
+    "chatglm2-32k",
+    "chatglm2-128k",
+    "llama-2",
+    "llama-2-chat",
+    "internlm2-chat",
+    "qwen-vl-chat",
+    "OmniLMM",
+    "yi-vl-chat",
+    "deepseek-vl-chat",
+]
+
 
 class PytorchModel(LLM):
     def __init__(
@@ -233,6 +252,7 @@ def match(
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         model_family = llm_family.model_family or llm_family.model_name
-        if model_family in [
-            "baichuan-chat",
-            "vicuna-v1.3",
-            "falcon",
-            "falcon-instruct",
-            "chatglm",
-            "chatglm2",
-            "chatglm2-32k",
-            "llama-2",
-            "llama-2-chat",
-        ]:
+        if model_family in NON_DEFAULT_MODEL_LIST:
             return False
         if "generate" not in llm_family.model_ability:
             return False
@@ -452,23 +461,8 @@ def match(
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
-        if llm_family.model_name in [
-            "baichuan-chat",
-            "baichuan-2-chat",
-            "vicuna-v1.3",
-            "falcon",
-            "falcon-instruct",
-            "chatglm",
-            "chatglm2",
-            "chatglm2-32k",
-            "llama-2",
-            "llama-2-chat",
-            "internlm2-chat",
-            "qwen-vl-chat",
-            "OmniLMM",
-            "yi-vl-chat",
-            "deepseek-vl-chat",
-        ]:
+        model_family = llm_family.model_family or llm_family.model_name
+        if model_family in NON_DEFAULT_MODEL_LIST:
             return False
         if "chat" not in llm_family.model_ability:
             return False
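The second match hunk is the actual bug fix: the chat path previously compared llm_family.model_name against the exclusion list, so a custom model whose model_name differs from its builtin model_family slipped past the check and was claimed by the generic PytorchChatModel instead of the dedicated vision implementation. Both match methods now also share the single NON_DEFAULT_MODEL_LIST constant instead of two divergent inline lists. A minimal sketch of the distinction, with a hypothetical custom model name and a trimmed copy of the list:

    # Sketch of the model_name vs. model_family distinction; the custom model
    # is hypothetical and the list is trimmed from NON_DEFAULT_MODEL_LIST.
    from dataclasses import dataclass
    from typing import Optional

    NON_DEFAULT_MODEL_LIST = ["qwen-vl-chat", "yi-vl-chat", "deepseek-vl-chat"]

    @dataclass
    class FamilyInfo:
        model_name: str
        model_family: Optional[str] = None

    custom = FamilyInfo(model_name="my-qwen-vl", model_family="qwen-vl-chat")

    # Old check consulted only model_name, so the custom model was not
    # excluded and wrongly matched the generic chat implementation.
    assert custom.model_name not in NON_DEFAULT_MODEL_LIST

    # New check falls back from model_family to model_name (mirroring the
    # generate path), so the custom vision model is correctly excluded here.
    model_family = custom.model_family or custom.model_name
    assert model_family in NON_DEFAULT_MODEL_LIST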
14 changes: 14 additions & 0 deletions xinference/model/llm/tests/test_llm_family.py
@@ -992,6 +992,20 @@ def test_parse_prompt_style():
     model_spec = CustomLLMFamilyV1.parse_raw(bytes(llm_family.json(), "utf8"))
     assert model_spec.model_name == llm_family.model_name
 
+    # test vision
+    llm_family = CustomLLMFamilyV1(
+        version=1,
+        model_type="LLM",
+        model_name="test_LLM",
+        model_lang=["en"],
+        model_ability=["chat", "generate"],
+        model_specs=[hf_spec, ms_spec],
+        model_family="qwen-vl-chat",
+        prompt_style="qwen-vl-chat",
+    )
+    model_spec = CustomLLMFamilyV1.parse_raw(bytes(llm_family.json(), "utf-8"))
+    assert "vision" in model_spec.model_ability
+
     # error: missing model_family
     llm_family = CustomLLMFamilyV1(
         version=1,
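For context, a sketch of how a custom family like the one in this test might be registered at runtime. The endpoint, model name, and spec values below are illustrative stand-ins, not part of this commit; the snippet assumes the RESTful client's register_model API of this era:

    # Hypothetical registration sketch -- endpoint and model JSON are stand-ins.
    from xinference.client import Client

    model_json = """{
      "version": 1,
      "context_length": 2048,
      "model_name": "my-qwen-vl",
      "model_lang": ["en", "zh"],
      "model_ability": ["chat", "generate"],
      "model_family": "qwen-vl-chat",
      "model_specs": [
        {
          "model_format": "pytorch",
          "model_size_in_billions": 7,
          "quantizations": ["none"],
          "model_uri": "file:///path/to/qwen-vl-weights"
        }
      ],
      "prompt_style": "qwen-vl-chat"
    }"""

    client = Client("http://localhost:9997")
    client.register_model(model_type="LLM", model=model_json, persist=False)
    # With this fix, parse_raw restores the "vision" ability for the custom
    # family, the generic chat path no longer claims it, and the Gradio UI
    # can expose an image-capable chat interface.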
