diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index cc7b452694..2385c48477 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -143,6 +143,7 @@ def _install(): ) from .transformers.deepseek_vl import DeepSeekVLChatModel from .transformers.glm4v import Glm4VModel + from .transformers.glm_edge_v import GlmEdgeVModel from .transformers.intern_vl import InternVLChatModel from .transformers.internlm2 import Internlm2PytorchChatModel from .transformers.minicpmv25 import MiniCPMV25Model @@ -193,6 +194,7 @@ def _install(): DeepSeekV2PytorchModel, DeepSeekV2PytorchChatModel, OptPytorchModel, + GlmEdgeVModel, ] ) if OmniLMMModel: # type: ignore diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 93da847654..9f53f45ae8 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -8596,5 +8596,232 @@ "<|im_start|>", "<|im_end|>" ] + }, + { + "version": 1, + "context_length": 8192, + "model_name": "glm-edge-chat", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "chat" + ], + "model_description": "The GLM-Edge series is our attempt to face the end-side real-life scenarios, which consists of two sizes of large-language dialogue models and multimodal comprehension models (GLM-Edge-1.5B-Chat, GLM-Edge-4B-Chat, GLM-Edge-V-2B, GLM-Edge-V-5B). 
Among them, the 1.5B / 2B model is mainly for platforms such as mobile phones and cars, and the 4B / 5B model is mainly for platforms such as PCs.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "1_5", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "THUDM/glm-edge-1.5b-chat" + }, + { + "model_format": "pytorch", + "model_size_in_billions": "4", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "THUDM/glm-edge-4b-chat" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "1_5", + "quantizations": [ + "Q4_0", + "Q4_1", + "Q4_K", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_1", + "Q5_K", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_file_name_template": "ggml-model-{quantization}.gguf", + "model_id": "THUDM/glm-edge-1.5b-chat-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "1_5", + "quantizations": [ + "F16" + ], + "model_file_name_template": "glm-edge-1.5B-chat-{quantization}.gguf", + "model_id": "THUDM/glm-edge-1.5b-chat-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "4", + "quantizations": [ + "Q4_0", + "Q4_1", + "Q4_K", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_1", + "Q5_K", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_file_name_template": "ggml-model-{quantization}.gguf", + "model_id": "THUDM/glm-edge-4b-chat-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "4", + "quantizations": [ + "F16" + ], + "model_file_name_template": "glm-edge-4B-chat-{quantization}.gguf", + "model_id": "THUDM/glm-edge-4b-chat-gguf" + } + ], + "chat_template": "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}", + "stop_token_ids": [ + 59246, + 59253, + 59255 + 
], + "stop": [ + "<|endoftext|>", + "<|user|>", + "<|observation|>" + ] + }, + { + "version": 1, + "context_length": 8192, + "model_name": "glm-edge-v", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "chat", + "vision" + ], + "model_description": "The GLM-Edge series is our attempt to face the end-side real-life scenarios, which consists of two sizes of large-language dialogue models and multimodal comprehension models (GLM-Edge-1.5B-Chat, GLM-Edge-4B-Chat, GLM-Edge-V-2B, GLM-Edge-V-5B). Among them, the 1.5B / 2B model is mainly for platforms such as mobile phones and cars, and the 4B / 5B model is mainly for platforms such as PCs.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "2", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "THUDM/glm-edge-v-2b" + }, + { + "model_format": "pytorch", + "model_size_in_billions": "5", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "THUDM/glm-edge-v-5b" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "2", + "quantizations": [ + "Q4_0", + "Q4_1", + "Q4_K", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_1", + "Q5_K", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_file_name_template": "ggml-model-{quantization}.gguf", + "model_id": "THUDM/glm-edge-v-2b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "2", + "quantizations": [ + "F16" + ], + "model_file_name_template": "glm-edge-v-2B-{quantization}.gguf", + "model_id": "THUDM/glm-edge-v-2b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "2", + "quantizations": [ + "f16" + ], + "model_file_name_template": "mmproj-model-{quantization}.gguf", + "model_id": "THUDM/glm-edge-v-2b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "5", + "quantizations": [ + "Q4_0", + "Q4_1", + "Q4_K", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_1", + "Q5_K", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + 
"model_file_name_template": "ggml-model-{quantization}.gguf", + "model_id": "THUDM/glm-edge-v-5b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "5", + "quantizations": [ + "F16" + ], + "model_file_name_template": "glm-edge-v-5B-{quantization}.gguf", + "model_id": "THUDM/glm-edge-v-5b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "5", + "quantizations": [ + "f16" + ], + "model_file_name_template": "mmproj-model-{quantization}.gguf", + "model_id": "THUDM/glm-edge-v-5b-gguf" + } + ], + "chat_template": "{% for item in messages %}{% if item['role'] != 'system' %}<|{{ item['role'] }}|>\n{% for content in item['content'] %}{% if content['type'] == 'image' %}{% for _ in range(578) %}<|begin_of_image|>{% endfor %}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}", + "stop_token_ids": [ + 59246, + 59253, + 59255 + ], + "stop": [ + "<|endoftext|>", + "<|user|>", + "<|observation|>" + ] } ] diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index d225b2abcb..c9e670a440 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -6334,5 +6334,246 @@ "<|im_start|>", "<|im_end|>" ] + }, + { + "version": 1, + "context_length": 8192, + "model_name": "glm-edge-chat", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "chat" + ], + "model_description": "The GLM-Edge series is our attempt to face the end-side real-life scenarios, which consists of two sizes of large-language dialogue models and multimodal comprehension models (GLM-Edge-1.5B-Chat, GLM-Edge-4B-Chat, GLM-Edge-V-2B, GLM-Edge-V-5B). 
Among them, the 1.5B / 2B model is mainly for platforms such as mobile phones and cars, and the 4B / 5B model is mainly for platforms such as PCs.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "1_5", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "ZhipuAI/glm-edge-1.5b-chat", + "model_hub": "modelscope" + }, + { + "model_format": "pytorch", + "model_size_in_billions": "4", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "ZhipuAI/glm-edge-4b-chat", + "model_hub": "modelscope" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "1_5", + "quantizations": [ + "Q4_0", + "Q4_1", + "Q4_K", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_1", + "Q5_K", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_file_name_template": "ggml-model-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-1.5b-chat-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "1_5", + "quantizations": [ + "F16" + ], + "model_file_name_template": "glm-edge-1.5B-chat-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-1.5b-chat-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "4", + "quantizations": [ + "Q4_0", + "Q4_1", + "Q4_K", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_1", + "Q5_K", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_file_name_template": "ggml-model-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-4b-chat-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "4", + "quantizations": [ + "F16" + ], + "model_file_name_template": "glm-edge-4B-chat-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-4b-chat-gguf" + } + ], + "chat_template": "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif 
item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}", + "stop_token_ids": [ + 59246, + 59253, + 59255 + ], + "stop": [ + "<|endoftext|>", + "<|user|>", + "<|observation|>" + ] + }, + { + "version": 1, + "context_length": 8192, + "model_name": "glm-edge-v", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "chat", + "vision" + ], + "model_description": "The GLM-Edge series is our attempt to face the end-side real-life scenarios, which consists of two sizes of large-language dialogue models and multimodal comprehension models (GLM-Edge-1.5B-Chat, GLM-Edge-4B-Chat, GLM-Edge-V-2B, GLM-Edge-V-5B). Among them, the 1.5B / 2B model is mainly for platforms such as mobile phones and cars, and the 4B / 5B model is mainly for platforms such as PCs.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "2", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "ZhipuAI/glm-edge-v-2b", + "model_hub": "modelscope" + }, + { + "model_format": "pytorch", + "model_size_in_billions": "5", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "ZhipuAI/glm-edge-v-5b", + "model_hub": "modelscope" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "2", + "quantizations": [ + "Q4_0", + "Q4_1", + "Q4_K", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_1", + "Q5_K", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_file_name_template": "ggml-model-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-v-2b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "2", + "quantizations": [ + "F16" + ], + "model_file_name_template": "glm-edge-v-2B-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-v-2b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "2", + "quantizations": [ + "f16" + ], + "model_file_name_template": 
"mmproj-model-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-v-2b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "5", + "quantizations": [ + "Q4_0", + "Q4_1", + "Q4_K", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_1", + "Q5_K", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_file_name_template": "ggml-model-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-v-5b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "5", + "quantizations": [ + "F16" + ], + "model_file_name_template": "glm-edge-v-5B-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-v-5b-gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "5", + "quantizations": [ + "f16" + ], + "model_file_name_template": "mmproj-model-{quantization}.gguf", + "model_hub": "modelscope", + "model_id": "ZhipuAI/glm-edge-v-5b-gguf" + } + ], + "chat_template": "{% for item in messages %}{% if item['role'] != 'system' %}<|{{ item['role'] }}|>\n{% for content in item['content'] %}{% if content['type'] == 'image' %}{% for _ in range(578) %}<|begin_of_image|>{% endfor %}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}", + "stop_token_ids": [ + 59246, + 59253, + 59255 + ], + "stop": [ + "<|endoftext|>", + "<|user|>", + "<|observation|>" + ] } ] diff --git a/xinference/model/llm/transformers/core.py b/xinference/model/llm/transformers/core.py index 9d48c6f005..1494ec88ad 100644 --- a/xinference/model/llm/transformers/core.py +++ b/xinference/model/llm/transformers/core.py @@ -68,6 +68,7 @@ "deepseek-v2-chat", "deepseek-v2.5", "deepseek-v2-chat-0628", + "glm-edge-v", ] diff --git a/xinference/model/llm/transformers/glm_edge_v.py b/xinference/model/llm/transformers/glm_edge_v.py new file mode 100644 index 0000000000..e9fab5a58d --- /dev/null +++ 
b/xinference/model/llm/transformers/glm_edge_v.py
@@ -0,0 +1,271 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from threading import Thread
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+import torch
+
+from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
+from ...utils import select_device
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import (
+    _decode_image_without_rgb,
+    generate_chat_completion,
+    generate_completion_chunk,
+)
+from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean
+
+logger = logging.getLogger(__name__)
+
+
+class GlmEdgeVModel(PytorchChatModel):
+    """Transformers-based chat model wrapper for the GLM-Edge-V vision family."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # All of these are populated by load().
+        self._device = None
+        self._tokenizer = None
+        self._model = None
+        self._processor = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        """Return True when the family (or model) name contains ``glm-edge-v``."""
+        family = model_family.model_family or model_family.model_name
+        return "glm-edge-v" in family.lower()
+
+    def load(self):
+        """Load the image processor, model and tokenizer from ``self.model_path``.
+
+        Raises:
+            ValueError: if the requested quantization is not supported on the
+                selected device / platform.
+        """
+        from transformers import AutoImageProcessor, AutoModelForCausalLM, AutoTokenizer
+
+        device = self._pytorch_model_config.get("device", "auto")
+        self._device = select_device(device)
+
+        kwargs = {"device_map": self._device}
+        quantization = self.quantization
+
+        # referenced from PytorchModel.load
+        if quantization != "none":
+            if self._device == "cuda" and self._is_linux():
+                kwargs["device_map"] = "auto"
+                if quantization == "4-bit":
+                    kwargs["load_in_4bit"] = True
+                elif quantization == "8-bit":
+                    kwargs["load_in_8bit"] = True
+                else:
+                    raise ValueError(
+                        f"Quantization {quantization} is not supported in temporary"
+                    )
+            else:
+                if quantization != "8-bit":
+                    raise ValueError(
+                        "Only 8-bit quantization is supported if it is not linux system or cuda device"
+                    )
+
+        processor = AutoImageProcessor.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        self._processor = processor
+
+        # Forward the device / quantization options computed above; they were
+        # previously built but never handed to from_pretrained().
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
+            **kwargs,
+        )
+
+        self._model = model
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        self._tokenizer = tokenizer
+
+    @staticmethod
+    def _get_processed_msgs(
+        messages: List[Dict],
+    ) -> Tuple[List[Dict[str, Any]], List[Any]]:
+        """Convert OpenAI-style messages into chat-template messages plus images.
+
+        String contents pass through unchanged. For list contents the text parts
+        are joined with a space and at most one ``image_url`` part is decoded.
+
+        Returns:
+            ``(msgs, images)``: the rewritten messages and the decoded images
+            in message order.
+
+        Raises:
+            RuntimeError: if a message carries more than one image.
+        """
+        res = []
+        img = []
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+            if isinstance(content, str):
+                res.append({"role": role, "content": content})
+            else:
+                texts = []
+                image_urls = []
+                for c in content:
+                    c_type = c.get("type")
+                    if c_type == "text":
+                        texts.append(c["text"])
+                    else:
+                        assert (
+                            c_type == "image_url"
+                        ), "Please follow the image input of the OpenAI API."
+                        image_urls.append(c["image_url"]["url"])
+                if len(image_urls) > 1:
+                    raise RuntimeError("Only one image per message is supported")
+                image_futures = []
+                with ThreadPoolExecutor() as executor:
+                    for image_url in image_urls:
+                        fut = executor.submit(_decode_image_without_rgb, image_url)
+                        image_futures.append(fut)
+                images = [fut.result() for fut in image_futures]
+                assert len(images) <= 1
+                text = " ".join(texts)
+                img.extend(images)
+                if images:
+                    res.append(
+                        {
+                            "role": role,
+                            "content": [
+                                {"type": "image"},
+                                {"type": "text", "text": text},
+                            ],
+                        }
+                    )
+                else:
+                    res.append({"role": role, "content": text})
+        return res, img
+
+    @cache_clean
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        """Answer a chat request, optionally as a stream of chunks.
+
+        Only the last image found in ``messages`` is passed to the model.
+
+        NOTE(review): only ``stream`` is read from ``generate_config``; length
+        and sampling options are not forwarded to ``generate()`` - confirm
+        whether that is intended.
+        """
+        from transformers import TextIteratorStreamer
+
+        if not generate_config:
+            generate_config = {}
+
+        stream = generate_config.get("stream", False)
+        msgs, imgs = self._get_processed_msgs(messages)
+
+        inputs = self._tokenizer.apply_chat_template(
+            msgs,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+        )  # chat mode
+        inputs = inputs.to(self._model.device)
+
+        generate_kwargs = {
+            **inputs,
+        }
+        if len(imgs) > 0:
+            generate_kwargs["pixel_values"] = torch.tensor(
+                self._processor(imgs[-1]).pixel_values
+            ).to(self._model.device)
+        stop_str = "<|endoftext|>"
+
+        if stream:
+            streamer = TextIteratorStreamer(
+                tokenizer=self._tokenizer,
+                timeout=60,
+                skip_prompt=True,
+                skip_special_tokens=True,
+            )
+            generate_kwargs = {
+                **generate_kwargs,
+                "streamer": streamer,
+            }
+            # generate() runs in a worker thread and feeds the streamer.
+            t = Thread(target=self._model.generate, kwargs=generate_kwargs)
+            t.start()
+
+            it = self.chat_stream(streamer, stop_str)
+            return self._to_chat_completion_chunks(it)
+        else:
+            with torch.no_grad():
+                outputs = self._model.generate(**generate_kwargs)
+            # Keep only the newly generated tokens (drop the prompt).
+            outputs = outputs[0][len(inputs["input_ids"][0]) :]
+            response = self._tokenizer.decode(outputs)
+            if response.endswith(stop_str):
+                response = response[: -len(stop_str)]
+            return generate_chat_completion(self.model_uid, response)
+
+    def chat_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
+        """Turn streamed text pieces into completion chunks plus a final stop chunk.
+
+        A trailing ``stop_str`` is removed from a piece; the text in front of it
+        is still emitted.
+        """
+        completion_id = str(uuid.uuid1())
+        for new_text in streamer:
+            if new_text.endswith(stop_str):
+                # Strip only the stop marker instead of dropping the whole piece.
+                new_text = new_text[: -len(stop_str)]
+                if not new_text:
+                    continue
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
+                has_choice=True,
+                has_content=True,
+            )
+
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+            has_choice=True,
+            has_content=False,
+        )
diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py
index 55b2b02ae4..e92e2cb3d4 100644
--- a/xinference/model/llm/utils.py
+++ b/xinference/model/llm/utils.py
@@ -569,6 +569,35 @@ def _decode_image(_url):
             return Image.open(BytesIO(response.content)).convert("RGB")
 
 
+def _decode_image_without_rgb(_url):
+    """Open an image from a data URL, a remote URL or a local path, keeping its mode.
+
+    Unlike ``_decode_image`` the image is not converted to RGB.
+
+    Raises:
+        ValueError: if a ``data:`` URL has no ``,`` payload separator.
+    """
+    if _url.startswith("data:"):
+        logging.info("Parse url by base64 decoder.")
+        # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+        # e.g. f"data:image/jpeg;base64,{base64_image}"
+        # The payload is everything after the first comma; splitting there (not
+        # on ";") also accepts headers that carry extra parameters.
+        _, sep, data = _url.partition(",")
+        if not sep:
+            raise ValueError("Invalid data URL: no ',' before the payload")
+        data = base64.b64decode(data.encode("utf-8"))
+        return Image.open(BytesIO(data))
+    else:
+        try:
+            response = requests.get(_url)
+        except requests.exceptions.MissingSchema:
+            # No URL scheme: hand the value to Image.open directly (e.g. a path).
+            return Image.open(_url)
+        else:
+            return Image.open(BytesIO(response.content))
+
+
 @typing.no_type_check
 def generate_completion_chunk(
     chunk_text: Optional[str],