Skip to content

Commit

Permalink
BUG: Make context_length optional in model family (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bojun-Feng authored Sep 1, 2023
1 parent 811b182 commit f43e4b8
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
4 changes: 2 additions & 2 deletions doc/source/models/custom.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ Define a custom model based on the following template:
"model_id": "TheBloke/Llama-2-7B-GGML",
"model_file_name_template": "llama-2-7b.ggmlv3.{quantization}.bin"
}
],
]
}
* model_name: A string defining the name of the model. The name must start with a letter or a digit and can only contain letters, digits, underscores, or dashes.
* context_length: An integer that specifies the maximum context size the model can accommodate, encompassing both the input and output lengths. It defines the boundary within which the model is designed to function optimally.
* context_length: An optional integer that specifies the maximum context size the model was trained to accommodate, encompassing both the input and output lengths. If not defined, the default value is 2048 tokens (~1,500 words).
* model_lang: A list of strings representing the supported languages for the model. Example: ["en"], which means that the model supports English.
* model_ability: A list of strings defining the abilities of the model. It could include options like "embed", "generate", and "chat". In this case, the model has the ability to "generate".
* model_specs: An array of objects defining the specifications of the model. These include:
Expand Down
3 changes: 2 additions & 1 deletion xinference/model/llm/llm_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
logger = logging.getLogger(__name__)

MAX_ATTEMPTS = 3
DEFAULT_CONTEXT_LENGTH = 2048


class GgmlLLMSpecV1(BaseModel):
Expand Down Expand Up @@ -60,7 +61,7 @@ class PromptStyleV1(BaseModel):

class LLMFamilyV1(BaseModel):
version: Literal[1]
context_length: int
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH
model_name: str
model_lang: List[Literal["en", "zh"]]
model_ability: List[Literal["embed", "generate", "chat"]]
Expand Down
14 changes: 13 additions & 1 deletion xinference/model/llm/tests/test_llm_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def test_serialize_llm_family_v1():
)
llm_family = LLMFamilyV1(
version=1,
context_length=2048,
model_type="LLM",
model_name="TestModel",
model_lang=["en"],
Expand All @@ -132,6 +131,19 @@ def test_serialize_llm_family_v1():
expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_specs": [{"model_format": "ggmlv3", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.ggmlv3.bin", "model_uri": null}, {"model_format": "pytorch", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}}"""
assert json.loads(llm_family.json()) == json.loads(expected)

llm_family_context_length = LLMFamilyV1(
version=1,
context_length=2048,
model_type="LLM",
model_name="TestModel",
model_lang=["en"],
model_ability=["embed", "generate"],
model_specs=[ggml_spec, pytorch_spec],
prompt_style=prompt_style,
)

assert json.loads(llm_family_context_length.json()) == json.loads(expected)


def test_builtin_llm_families():
import os
Expand Down

0 comments on commit f43e4b8

Please sign in to comment.