FEAT: LLaMA-2 (xorbitsai#203)
UranusSeven authored Jul 19, 2023
1 parent aabf3aa commit 1fe3bb2
Showing 8 changed files with 153 additions and 30 deletions.
22 changes: 12 additions & 10 deletions README.md
@@ -161,21 +161,23 @@ To view the builtin models, run the following command:
$ xinference list --all
```

| Name | Type | Language | Format | Size (in billions) | Quantization |
| -------------------- |------------------|----------|--------|--------------------|----------------------------------------|
| baichuan | Foundation Model | en, zh | ggmlv3 | 7 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| chatglm | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| chatglm2 | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| wizardlm-v1.0 | SFT Model | en | ggmlv3 | 7, 13, 33 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| wizardlm-v1.1 | SFT Model | en | ggmlv3 | 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| vicuna-v1.3 | SFT Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| orca | SFT Model | en | ggmlv3 | 3, 7, 13 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| Name | Type | Language | Format | Size (in billions) | Quantization |
|---------------|------------------|----------|---------|--------------------|-----------------------------------------|
| llama-2 | Foundation Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| baichuan | Foundation Model | en, zh | ggmlv3 | 7 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| llama-2-chat | RLHF Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| chatglm | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| chatglm2 | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| wizardlm-v1.0 | SFT Model | en | ggmlv3 | 7, 13, 33 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| wizardlm-v1.1 | SFT Model | en | ggmlv3 | 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| vicuna-v1.3 | SFT Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| orca | SFT Model | en | ggmlv3 | 3, 7, 13 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |


**NOTE**:
- Xinference will download models automatically for you, and by default the models will be saved under `${USER}/.xinference/cache`.
- Foundation models only provide the `generate` interface.
- SFT models provide both `generate` and `chat`.
- RLHF and SFT models provide both `generate` and `chat`.
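To make these notes concrete, here is a minimal sketch of exercising the two new families through the Python client once a local endpoint is running. The endpoint URL, the `launch_model` parameters, and the exact `generate`/`chat` signatures are assumptions based on the surrounding README rather than something this diff adds:

```python
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed local endpoint

# llama-2 is a foundation model, so its handle only exposes `generate`.
base_uid = client.launch_model(
    model_name="llama-2", model_size_in_billions=7, quantization="q4_0"
)
base_model = client.get_model(base_uid)
print(base_model.generate("Q: Name the largest ocean.\nA:", generate_config={"max_tokens": 64}))

# llama-2-chat is an RLHF model, so its handle exposes `chat` as well.
chat_uid = client.launch_model(
    model_name="llama-2-chat", model_size_in_billions=7, quantization="q4_0"
)
chat_model = client.get_model(chat_uid)
print(chat_model.chat("Name the largest ocean."))
```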

## Roadmap
Xinference is currently under active development. Here's a roadmap outlining our planned
22 changes: 12 additions & 10 deletions README_zh_CN.md
@@ -152,21 +152,23 @@ model.chat(
$ xinference list --all
```

| Name | Type | Language | Format | Size (in billions) | Quantization |
| -------------------- |------------------|----------|--------|--------------------|----------------------------------------|
| baichuan | Foundation Model | en, zh | ggmlv3 | 7 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| chatglm | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| chatglm2 | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| wizardlm-v1.0 | SFT Model | en | ggmlv3 | 7, 13, 33 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| wizardlm-v1.1 | SFT Model | en | ggmlv3 | 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| vicuna-v1.3 | SFT Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| orca | SFT Model | en | ggmlv3 | 3, 7, 13 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| Name | Type | Language | Format | Size (in billions) | Quantization |
|---------------|------------------|----------|---------|--------------------|-----------------------------------------|
| llama-2 | Foundation Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| baichuan | Foundation Model | en, zh | ggmlv3 | 7 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| llama-2-chat | RLHF Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| chatglm | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| chatglm2 | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |
| wizardlm-v1.0 | SFT Model | en | ggmlv3 | 7, 13, 33 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| wizardlm-v1.1 | SFT Model | en | ggmlv3 | 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| vicuna-v1.3 | SFT Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |
| orca | SFT Model | en | ggmlv3 | 3, 7, 13 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |


**注意**:
- Xinference 会自动为你下载模型,默认的模型存放路径为 `${USER}/.xinference/cache`。
- 基础模型仅提供 `generate` 接口。
- SFT 模型提供 `generate` 与 `chat` 接口。
- RLHF 与 SFT 模型提供 `generate` 与 `chat` 接口。

## 近期开发计划
Xinference 目前正在快速迭代。我们近期的开发计划包括:
10 changes: 2 additions & 8 deletions xinference/client.py
@@ -313,10 +313,7 @@ def get_model(self, model_uid: str) -> "ModelHandle":

if model_spec.model_name == "chatglm" or model_spec.model_name == "chatglm2":
return ChatglmCppChatModelHandle(model_ref, self._isolation)
elif (
model_spec.model_name == "baichuan"
or model_spec.model_name == "baichuan-base"
):
elif model_spec.model_name in ["baichuan", "baichuan-base", "llama-2"]:
return GenerateModelHandle(model_ref, self._isolation)
else:
return ChatModelHandle(model_ref, self._isolation)
@@ -407,10 +404,7 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
or model_spec["model_name"] == "chatglm2"
):
return RESTfulChatglmCppChatModelHandle(model_uid, self.base_url)
elif (
model_spec["model_name"] == "baichuan"
or model_spec["model_name"] == "baichuan-base"
):
elif model_spec["model_name"] in ["baichuan", "baichuan-base", "llama-2"]:
return RESTfulGenerateModelHandle(model_uid, self.base_url)
else:
return RESTfulChatModelHandle(model_uid, self.base_url)
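As a result of this dispatch, the handle type you get back depends on whether the llama-2 variant can chat. A short hedged sketch (the endpoint and `launch_model` arguments are assumptions; the handle classes are the ones returned above):

```python
from xinference.client import (
    RESTfulChatModelHandle,
    RESTfulClient,
    RESTfulGenerateModelHandle,
)

client = RESTfulClient("http://127.0.0.1:9997")  # assumed local endpoint

base_uid = client.launch_model(model_name="llama-2")       # generate-only foundation model
chat_uid = client.launch_model(model_name="llama-2-chat")  # RLHF chat model

assert isinstance(client.get_model(base_uid), RESTfulGenerateModelHandle)
assert isinstance(client.get_model(chat_uid), RESTfulChatModelHandle)
```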
2 changes: 1 addition & 1 deletion xinference/core/gradio.py
@@ -29,7 +29,7 @@
MODEL_TO_FAMILIES = dict(
(model_family.model_name, model_family)
for model_family in MODEL_FAMILIES
if model_family.model_name not in ["baichuan", "baichuan-base"]
if model_family.model_name not in ["baichuan", "baichuan-base", "llama-2"]
)


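The added entry in the Gradio exclusion list follows the same split: presumably because the web UI is chat-oriented, families that only support `generate` are hidden from the model picker, so `llama-2` joins `baichuan` there while `llama-2-chat` remains selectable. A toy restatement of the filter (the family list below is illustrative, not the real `MODEL_FAMILIES`):

```python
# Illustrative only: mimic the comprehension above on a hand-written family list.
all_families = ["baichuan", "llama-2", "llama-2-chat", "chatglm2", "vicuna-v1.3", "orca"]
generate_only = ["baichuan", "baichuan-base", "llama-2"]

chat_ui_families = [name for name in all_families if name not in generate_only]
print(chat_ui_families)  # ['llama-2-chat', 'chatglm2', 'vicuna-v1.3', 'orca']
```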
61 changes: 61 additions & 0 deletions xinference/model/llm/__init__.py
@@ -17,6 +17,7 @@ def install():
from .. import MODEL_FAMILIES, ModelFamily
from .chatglm import ChatglmCppChatModel
from .core import LlamaCppModel
from .llama2 import Llama2ChatGgml
from .orca import OrcaMiniGgml
from .pytorch.baichuan import BaichuanPytorch, BaichuanPytorchChat
from .pytorch.vicuna import VicunaCensoredPytorch
@@ -213,6 +214,66 @@ def install():
)
)

llama2_chat_url_generator = lambda model_size, quantization: (
f"https://huggingface.co/TheBloke/Llama-2-{model_size}B-chat-GGML/resolve/main/llama-2-"
f"{model_size}b-chat.ggmlv3.{quantization}.bin"
)
MODEL_FAMILIES.append(
ModelFamily(
model_name="llama-2-chat",
model_sizes_in_billions=[7, 13],
model_format="ggmlv3",
quantizations=[
"q2_K",
"q3_K_L",
"q3_K_M",
"q3_K_S",
"q4_0",
"q4_1",
"q4_K_M",
"q4_K_S",
"q5_0",
"q5_1",
"q5_K_M",
"q5_K_S",
"q6_K",
"q8_0",
],
url_generator=llama2_chat_url_generator,
cls=Llama2ChatGgml,
)
)

llama2_url_generator = lambda model_size, quantization: (
f"https://huggingface.co/TheBloke/Llama-2-{model_size}B-GGML/resolve/main/llama-2-"
f"{model_size}b.ggmlv3.{quantization}.bin"
)
MODEL_FAMILIES.append(
ModelFamily(
model_name="llama-2",
model_sizes_in_billions=[7, 13],
model_format="ggmlv3",
quantizations=[
"q2_K",
"q3_K_L",
"q3_K_M",
"q3_K_S",
"q4_0",
"q4_1",
"q4_K_M",
"q4_K_S",
"q5_0",
"q5_1",
"q5_K_M",
"q5_K_S",
"q6_K",
"q8_0",
],
url_generator=llama2_url_generator,
cls=LlamaCppModel,
)
)

pytorch_baichuan_name_generator = lambda model_size, quantization: (
f"baichuan-inc/Baichuan-{model_size}B"
)
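For reference, both URL generators above resolve to TheBloke's GGML conversions on Hugging Face. The snippet below just reproduces the same f-string pattern for a 7B `q4_0` build so the resulting file names are easy to eyeball:

```python
model_size, quantization = 7, "q4_0"

chat_url = (
    f"https://huggingface.co/TheBloke/Llama-2-{model_size}B-chat-GGML/resolve/main/"
    f"llama-2-{model_size}b-chat.ggmlv3.{quantization}.bin"
)
base_url = (
    f"https://huggingface.co/TheBloke/Llama-2-{model_size}B-GGML/resolve/main/"
    f"llama-2-{model_size}b.ggmlv3.{quantization}.bin"
)
print(chat_url)  # .../Llama-2-7B-chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_0.bin
print(base_url)  # .../Llama-2-7B-GGML/resolve/main/llama-2-7b.ggmlv3.q4_0.bin
```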
10 changes: 10 additions & 0 deletions xinference/model/llm/core.py
@@ -234,13 +234,23 @@ def __init__(
sep: str,
user_name: str,
assistant_name: str,
stop: Optional[Union[str, List[str]]] = None,
llamacpp_model_config: Optional[LlamaCppModelConfig] = None,
):
super().__init__(model_uid, model_spec, model_path, llamacpp_model_config)
self._system_prompt: str = system_prompt
self._sep: str = sep
self._user_name: str = user_name
self._assistant_name: str = assistant_name
self._stop = stop

def _sanitize_generate_config(
self, generate_config: Optional[LlamaCppGenerateConfig]
) -> LlamaCppGenerateConfig:
generate_config = super()._sanitize_generate_config(generate_config)
if self._stop:
generate_config["stop"] = self._stop
return generate_config

def chat(
self,
51 changes: 51 additions & 0 deletions xinference/model/llm/llama2.py
@@ -0,0 +1,51 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from .. import ModelSpec
from .core import LlamaCppChatModel, LlamaCppModelConfig


class Llama2ChatGgml(LlamaCppChatModel):
_system_prompt = (
"System: You are a helpful, respectful and honest assistant. Always answer as helpfully as"
" possible, while being safe. Your answers should not include any harmful, unethical,"
" racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses "
"are socially unbiased and positive in nature. If a question does not make any sense, or "
"is not factually coherent, explain why instead of answering something not correct. If you "
"don't know the answer to a question, please don't share false information."
)
_sep = "\n"
_user_name = "User"
_assistant_name = "Assistant"
_stop = "\n"

def __init__(
self,
model_uid: str,
model_spec: "ModelSpec",
model_path: str,
llamacpp_model_config: Optional[LlamaCppModelConfig] = None,
):
super().__init__(
model_uid,
model_spec,
model_path,
system_prompt=self._system_prompt,
sep=self._sep,
user_name=self._user_name,
assistant_name=self._assistant_name,
llamacpp_model_config=llamacpp_model_config,
)
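For context on the class attributes above: with `\n` as both the separator and the stop sequence, each turn sits on its own line, so cutting generation at the first newline ends the assistant's reply instead of letting the model continue into an invented `User:` turn. A hypothetical illustration of the resulting layout (the actual prompt assembly happens in `LlamaCppChatModel` and is not shown in this diff):

```python
# Hypothetical illustration only; not the actual prompt builder.
system_prompt = "System: You are a helpful, respectful and honest assistant. ..."
sep, user_name, assistant_name = "\n", "User", "Assistant"

prompt = (
    f"{system_prompt}{sep}"
    f"{user_name}: What is the capital of France?{sep}"
    f"{assistant_name}:"
)
print(prompt)
# Intended effect of stop="\n": llama.cpp returns only the assistant's single-line answer.
```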
5 changes: 4 additions & 1 deletion xinference/tests/test_client.py
@@ -14,7 +14,7 @@

import pytest

from ..client import Client, RESTfulClient
from ..client import ChatModelHandle, Client, RESTfulChatModelHandle, RESTfulClient


@pytest.mark.asyncio
@@ -29,6 +29,7 @@ async def test_sync_client(setup):
assert len(client.list_models()) == 1

model = client.get_model(model_uid=model_uid)
assert isinstance(model, ChatModelHandle)

with pytest.raises(RuntimeError):
model.create_embedding("The food was delicious and the waiter...")
@@ -67,9 +68,11 @@ async def test_RESTful_client(setup):
assert len(client.list_models()) == 1

model = client.get_model(model_uid=model_uid)
assert isinstance(model, RESTfulChatModelHandle)

with pytest.raises(RuntimeError):
model = client.get_model(model_uid="test")
assert isinstance(model, RESTfulChatModelHandle)

with pytest.raises(RuntimeError):
completion = model.generate({"max_tokens": 64})
