From f43e4b84ce32cf3e9fb50c9c2fad960e0d3b5fb8 Mon Sep 17 00:00:00 2001 From: Bojun-Feng <102875484+Bojun-Feng@users.noreply.github.com> Date: Fri, 1 Sep 2023 04:58:30 -0500 Subject: [PATCH] BUG: Make context_length optional in model family (#394) --- doc/source/models/custom.rst | 4 ++-- xinference/model/llm/llm_family.py | 3 ++- xinference/model/llm/tests/test_llm_family.py | 14 +++++++++++++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/doc/source/models/custom.rst b/doc/source/models/custom.rst index 5da532884c..3a89400165 100644 --- a/doc/source/models/custom.rst +++ b/doc/source/models/custom.rst @@ -44,11 +44,11 @@ Define a custom model based on the following template: "model_id": "TheBloke/Llama-2-7B-GGML", "model_file_name_template": "llama-2-7b.ggmlv3.{quantization}.bin" } - ], + ] } * model_name: A string defining the name of the model. The name must start with a letter or a digit and can only contain letters, digits, underscores, or dashes. -* context_length: An integer that specifies the maximum context size the model can accommodate, encompassing both the input and output lengths. It defines the boundary within which the model is designed to function optimally. +* context_length: context_length: An optional integer that specifies the maximum context size the model was trained to accommodate, encompassing both the input and output lengths. If not defined, the default value is 2048 tokens (~1,500 words). * model_lang: A list of strings representing the supported languages for the model. Example: ["en"], which means that the model supports English. * model_ability: A list of strings defining the abilities of the model. It could include options like "embed", "generate", and "chat". In this case, the model has the ability to "generate". * model_specs: An array of objects defining the specifications of the model. These include: diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index 7489fe5542..b9ae88a80c 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) MAX_ATTEMPTS = 3 +DEFAULT_CONTEXT_LENGTH = 2048 class GgmlLLMSpecV1(BaseModel): @@ -60,7 +61,7 @@ class PromptStyleV1(BaseModel): class LLMFamilyV1(BaseModel): version: Literal[1] - context_length: int + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH model_name: str model_lang: List[Literal["en", "zh"]] model_ability: List[Literal["embed", "generate", "chat"]] diff --git a/xinference/model/llm/tests/test_llm_family.py b/xinference/model/llm/tests/test_llm_family.py index 2f23f4acc3..08a3dfce7d 100644 --- a/xinference/model/llm/tests/test_llm_family.py +++ b/xinference/model/llm/tests/test_llm_family.py @@ -120,7 +120,6 @@ def test_serialize_llm_family_v1(): ) llm_family = LLMFamilyV1( version=1, - context_length=2048, model_type="LLM", model_name="TestModel", model_lang=["en"], @@ -132,6 +131,19 @@ def test_serialize_llm_family_v1(): expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_specs": [{"model_format": "ggmlv3", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.ggmlv3.bin", "model_uri": null}, {"model_format": "pytorch", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}}""" assert json.loads(llm_family.json()) == json.loads(expected) + llm_family_context_length = LLMFamilyV1( + version=1, + context_length=2048, + model_type="LLM", + model_name="TestModel", + model_lang=["en"], + model_ability=["embed", "generate"], + model_specs=[ggml_spec, pytorch_spec], + prompt_style=prompt_style, + ) + + assert json.loads(llm_family_context_length.json()) == json.loads(expected) + def test_builtin_llm_families(): import os