BUG: cache status missing for model id with quantization placeholder (#1849)

Co-authored-by: ChengjieLi <[email protected]>
Zihann73 and ChengjieLi28 authored Jul 12, 2024
1 parent 9bb548a commit e916d05
Showing 5 changed files with 72 additions and 59 deletions.
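
For background, some built-in specs template the quantization into the model id itself (for example a repo id ending in "-{quantization}"), so each quantization maps to its own repository and therefore its own cache directory. The snippet below is a minimal illustration of that placeholder idea, with made-up spec values rather than ones taken from this repository:

# Hypothetical spec values: the repo id embeds the quantization via a placeholder.
model_id_template = "org/awesome-llm-7b-gptq-{quantization}"
quantizations = ["Int4", "Int8"]

# Each quantization resolves to a different repository, hence a different cache dir.
resolved = [model_id_template.format(quantization=q) for q in quantizations]
print(resolved)  # ['org/awesome-llm-7b-gptq-Int4', 'org/awesome-llm-7b-gptq-Int8']

Per the commit title, the cache status for such placeholder-style ids was previously reported as missing; the changes below thread the requested quantization through the cache-dir and cache-status helpers.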
99 changes: 54 additions & 45 deletions xinference/model/llm/llm_family.py
@@ -14,7 +14,6 @@
 
 import logging
 import os
-import platform
 import shutil
 from threading import Lock
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
@@ -541,15 +540,20 @@ def _get_cache_dir_for_model_mem(
 def _get_cache_dir(
     llm_family: LLMFamilyV1,
     llm_spec: "LLMSpecV1",
+    quantization: Optional[str] = None,
     create_if_not_exist=True,
 ):
     # If the model id contains quantization, then we should give each
     # quantization a dedicated cache dir.
     quant_suffix = ""
-    for q in llm_spec.quantizations:
-        if llm_spec.model_id and q in llm_spec.model_id:
-            quant_suffix = q
-            break
+    if llm_spec.model_id and "{" in llm_spec.model_id and quantization is not None:
+        quant_suffix = quantization
+    else:
+        for q in llm_spec.quantizations:
+            if llm_spec.model_id and q in llm_spec.model_id:
+                quant_suffix = q
+                break
+
     cache_dir_name = (
         f"{llm_family.model_name}-{llm_spec.model_format}"
         f"-{llm_spec.model_size_in_billions}b"
@@ -900,6 +904,7 @@ def _check_revision(
     llm_spec: "LLMSpecV1",
     builtin: list,
     meta_path: str,
+    quantization: Optional[str] = None,
 ) -> bool:
     for family in builtin:
         if llm_family.model_name == family.model_name:
@@ -908,59 +913,63 @@
                 if (
                     spec.model_format == "pytorch"
                     and spec.model_size_in_billions == llm_spec.model_size_in_billions
+                    and (quantization is None or quantization in spec.quantizations)
                 ):
                     return valid_model_revision(meta_path, spec.model_revision)
     return False
 
 
 def get_cache_status(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
+    llm_family: LLMFamilyV1, llm_spec: "LLMSpecV1", quantization: Optional[str] = None
 ) -> Union[bool, List[bool]]:
     """
-    When calling this function from above, `llm_family` is constructed only from BUILTIN_LLM_FAMILIES,
-    so we should check both huggingface and modelscope cache files.
+    Checks if a model's cache status is available based on the model format and quantization.
+    Supports different directories and model formats.
     """
-    cache_dir = _get_cache_dir(llm_family, llm_spec, create_if_not_exist=False)
-    # check revision for pytorch model
-    if llm_spec.model_format == "pytorch":
-        hf_meta_path = _get_meta_path(cache_dir, "pytorch", "huggingface", "none")
-        ms_meta_path = _get_meta_path(cache_dir, "pytorch", "modelscope", "none")
-        revisions = [
-            _check_revision(llm_family, llm_spec, BUILTIN_LLM_FAMILIES, hf_meta_path),
-            _check_revision(
-                llm_family, llm_spec, BUILTIN_MODELSCOPE_LLM_FAMILIES, ms_meta_path
-            ),
-        ]
-        return any(revisions)
-    # just check meta file for ggml and gptq model
-    elif llm_spec.model_format in ["ggmlv3", "ggufv2", "gptq", "awq", "mlx"]:
-        ret = []
-        for q in llm_spec.quantizations:
-            assert q is not None
-            hf_meta_path = _get_meta_path(
-                cache_dir, llm_spec.model_format, "huggingface", q
-            )
-            ms_meta_path = _get_meta_path(
-                cache_dir, llm_spec.model_format, "modelscope", q
-            )
-            results = [os.path.exists(hf_meta_path), os.path.exists(ms_meta_path)]
-            ret.append(any(results))
-        return ret
-    else:
-        raise ValueError(f"Unsupported model format: {llm_spec.model_format}")
 
+    def check_file_status(meta_path: str) -> bool:
+        return os.path.exists(meta_path)
 
-def _is_linux():
-    return platform.system() == "Linux"
+    def check_revision_status(
+        meta_path: str, families: list, quantization: Optional[str] = None
+    ) -> bool:
+        return _check_revision(llm_family, llm_spec, families, meta_path, quantization)
 
+    def handle_quantization(q: Union[str, None]) -> bool:
+        specific_cache_dir = _get_cache_dir(
+            llm_family, llm_spec, q, create_if_not_exist=False
+        )
+        meta_paths = {
+            "huggingface": _get_meta_path(
+                specific_cache_dir, llm_spec.model_format, "huggingface", q
+            ),
+            "modelscope": _get_meta_path(
+                specific_cache_dir, llm_spec.model_format, "modelscope", q
+            ),
+        }
+        if llm_spec.model_format == "pytorch":
+            return check_revision_status(
+                meta_paths["huggingface"], BUILTIN_LLM_FAMILIES, q
+            ) or check_revision_status(
+                meta_paths["modelscope"], BUILTIN_MODELSCOPE_LLM_FAMILIES, q
+            )
+        else:
+            return check_file_status(meta_paths["huggingface"]) or check_file_status(
+                meta_paths["modelscope"]
+            )
 
-def _has_cuda_device():
-    # `cuda_count` method already contains the logic for the
-    # number of GPUs specified by `CUDA_VISIBLE_DEVICES`.
-    from ...utils import cuda_count
-
-    return cuda_count() > 0
+    if llm_spec.model_id and "{" in llm_spec.model_id:
+        return (
+            [handle_quantization(q) for q in llm_spec.quantizations]
+            if quantization is None
+            else handle_quantization(quantization)
+        )
+    else:
+        return (
+            [handle_quantization(q) for q in llm_spec.quantizations]
+            if llm_spec.model_format != "pytorch"
+            else handle_quantization(None)
+        )
 
 
 def get_user_defined_llm_families():
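
With the rewrite above, the return shape of get_cache_status depends on whether the model id carries a placeholder and on the model format: either a single bool or a list of bools aligned with llm_spec.quantizations. The following is a standalone sketch of that decision rule only, with an invented check_one callback standing in for the real meta-file and revision checks:

# Sketch of the return-shape rule implemented above (not the library code itself).
from typing import Callable, List, Optional, Union

def cache_status_shape(
    model_id: Optional[str],
    model_format: str,
    quantizations: List[str],
    quantization: Optional[str],
    check_one: Callable[[Optional[str]], bool],
) -> Union[bool, List[bool]]:
    if model_id and "{" in model_id:
        # Placeholder ids: one answer per quantization, or a single answer
        # when the caller asked about a specific quantization.
        if quantization is None:
            return [check_one(q) for q in quantizations]
        return check_one(quantization)
    # Non-placeholder ids keep the old behavior: pytorch collapses to one
    # answer, every other format reports per quantization.
    if model_format != "pytorch":
        return [check_one(q) for q in quantizations]
    return check_one(None)

# Example: pretend only Int4 is cached.
print(cache_status_shape("org/llm-{quantization}", "gptq", ["Int4", "Int8"], None,
                         lambda q: q == "Int4"))  # [True, False]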
2 changes: 1 addition & 1 deletion xinference/model/llm/llm_family_modelscope.json
@@ -4143,7 +4143,7 @@
             "zh"
         ],
         "model_ability": [
-            "generate"
+            "chat"
         ],
         "model_description": "Aquila2-chat series models are the chat models",
         "model_specs": [
6 changes: 4 additions & 2 deletions xinference/model/llm/utils.py
@@ -779,8 +779,10 @@ def get_full_prompt(cls, model_family, prompt, system_prompt, chat_history, tool
 def get_file_location(
     llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str
 ) -> Tuple[str, bool]:
-    cache_dir = _get_cache_dir(llm_family, spec, create_if_not_exist=False)
-    cache_status = get_cache_status(llm_family, spec)
+    cache_dir = _get_cache_dir(
+        llm_family, spec, quantization, create_if_not_exist=False
+    )
+    cache_status = get_cache_status(llm_family, spec, quantization)
     if isinstance(cache_status, list):
         is_cached = None
         for q, cs in zip(spec.quantizations, cache_status):
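
The tail of get_file_location is truncated above; the visible lines zip each quantization with its cache flag. A small, hypothetical illustration of that pairing step (the loop body here is assumed for the example, not copied from the function):

# Invented values: pair quantizations with per-quantization cache flags.
quantizations = ["Int4", "Int8"]
cache_status = [True, False]   # e.g. the list form returned by get_cache_status
requested = "Int8"

is_cached = None
for q, cs in zip(quantizations, cache_status):
    if q == requested:
        is_cached = cs
        break
print(is_cached)  # False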
6 changes: 3 additions & 3 deletions xinference/web/ui/src/scenes/launch_model/launchLLM.js
@@ -79,10 +79,10 @@ const LaunchLLM = ({ gpuAvailable }) => {
   }
 
   const filterCache = (spec) => {
-    if (spec.model_format === 'pytorch') {
-      return spec.cache_status && spec.cache_status === true
+    if (Array.isArray(spec.cache_status)) {
+      return spec.cache_status.some((cs) => cs)
     } else {
-      return spec.cache_status && spec.cache_status.some((cs) => cs)
+      return spec.cache_status === true
     }
   }
 
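
The UI now branches on the shape of cache_status instead of on model_format, since after this change the field may be a single bool or a per-quantization list regardless of format. A hedged illustration of the two shapes and of an equivalent check, written in Python with invented payloads (the real check lives in the JavaScript above):

# Invented example payloads for the spec objects the frontend receives.
pytorch_spec = {"model_format": "pytorch", "quantizations": ["none"], "cache_status": True}
gptq_spec = {"model_format": "gptq", "quantizations": ["Int4", "Int8"], "cache_status": [True, False]}

def filter_cache(spec) -> bool:
    # Mirrors the updated filterCache/isCached logic: branch on the value's shape.
    cs = spec.get("cache_status")
    return any(cs) if isinstance(cs, list) else cs is True

print(filter_cache(pytorch_spec), filter_cache(gptq_spec))  # True True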
18 changes: 10 additions & 8 deletions xinference/web/ui/src/scenes/launch_model/modelCard.js
@@ -144,10 +144,10 @@ const ModelCard = ({
   }
 
   const isCached = (spec) => {
-    if (spec.model_format === 'pytorch') {
-      return spec.cache_status && spec.cache_status === true
+    if (Array.isArray(spec.cache_status)) {
+      return spec.cache_status.some((cs) => cs)
     } else {
-      return spec.cache_status && spec.cache_status.some((cs) => cs)
+      return spec.cache_status === true
     }
   }
 
@@ -1196,7 +1196,7 @@ const ModelCard = ({
                  onChange={(e) => setQuantization(e.target.value)}
                  label="Quantization"
                >
-                  {quantizationOptions.map((quant, index) => {
+                  {quantizationOptions.map((quant) => {
                    const specs = modelData.model_specs
                      .filter((spec) => spec.model_format === modelFormat)
                      .filter(
@@ -1205,10 +1205,12 @@ const ModelCard = ({
                          convertModelSize(modelSize)
                      )
 
-                    const cached =
-                      modelFormat === 'pytorch'
-                        ? specs[0]?.cache_status ?? false === true
-                        : specs[0]?.cache_status?.[index] ?? false === true
+                    const spec = specs.find((s) => {
+                      return s.quantizations.includes(quant)
+                    })
+                    const cached = Array.isArray(spec.cache_status)
+                      ? spec.cache_status[spec.quantizations.indexOf(quant)]
+                      : spec.cache_status
 
                    const displayedQuant = cached
                      ? quant + ' (cached)'
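
Similarly, the dropdown change above looks up the flag for one specific quantization inside the spec that contains it. Roughly this selection rule, sketched with invented data:

# Invented data: pick the cached flag for one quantization out of a list of specs.
specs = [
    {"model_format": "gptq", "quantizations": ["Int4", "Int8"], "cache_status": [True, False]},
]
quant = "Int8"

spec = next(s for s in specs if quant in s["quantizations"])
cached = (
    spec["cache_status"][spec["quantizations"].index(quant)]
    if isinstance(spec["cache_status"], list)
    else spec["cache_status"]
)
print(cached)  # False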
