diff --git a/doc/source/_static/speculative_decoding.gif b/doc/source/_static/speculative_decoding.gif
new file mode 100644
index 0000000000..b341c7f3f3
Binary files /dev/null and b/doc/source/_static/speculative_decoding.gif differ
diff --git a/doc/source/_static/speculative_decoding.jpeg b/doc/source/_static/speculative_decoding.jpeg
new file mode 100644
index 0000000000..d5b1e46cf1
Binary files /dev/null and b/doc/source/_static/speculative_decoding.jpeg differ
diff --git a/doc/source/user_guide/index.rst b/doc/source/user_guide/index.rst
index 7ed8f54384..bca65cec4f 100644
--- a/doc/source/user_guide/index.rst
+++ b/doc/source/user_guide/index.rst
@@ -10,3 +10,4 @@ User Guide
    cache_management
    backends
    client_api
+   spec_decoding
diff --git a/doc/source/user_guide/spec_decoding.rst b/doc/source/user_guide/spec_decoding.rst
new file mode 100644
index 0000000000..52f6865863
--- /dev/null
+++ b/doc/source/user_guide/spec_decoding.rst
@@ -0,0 +1,63 @@
+.. _user_guide_spec_decoding:
+
+===================================
+Speculative Decoding (experimental)
+===================================
+
+.. image:: ../_static/speculative_decoding.gif
+
+Speculative decoding is a technique for speeding up inference of large language models (LLMs). A smaller, faster "draft" model proposes several tokens in advance, and the larger "target" model then verifies them. Every draft token that the target model accepts saves memory bandwidth and per-token processing time; draft tokens that do not match the target model's predictions are discarded.
+
+.. image:: ../_static/speculative_decoding.jpeg
+    :width: 400
+
+Launching a speculative LLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Using speculative decoding in Xinference is straightforward. The only difference from regular decoding is the way the LLM is launched:
+
+.. code-block:: python
+
+    from xinference.client import Client
+
+    client = Client("http://localhost:9997")
+    model_uid = client.launch_speculative_llm(
+        model_name="wizardcoder-python-v1.0",  # target model name
+        model_size_in_billions=34,  # target model size
+        quantization="none",  # target model quantization
+        draft_model_name="wizardcoder-python-v1.0",  # draft model name
+        draft_model_size_in_billions=7,  # draft model size
+        draft_quantization="none",  # draft model quantization
+        n_gpu=2  # number of GPUs to use
+    )
+
+.. note::
+
+    ``Client.launch_speculative_llm`` is an experimental API that may be removed in future releases.
+
+After launching the model, you can use it just like a regular model:
+
+.. code-block:: python
+
+    model = client.get_model(model_uid)
+    model.chat(
+        """Determine if a 9 x 9 Sudoku board is valid. Only the filled cells need to be validated according to the following rules:
+        1. Each row must contain the digits 1-9 without repetition.
+        2. Each column must contain the digits 1-9 without repetition.
+        3. Each of the nine 3 x 3 sub-boxes of the grid must contain the digits 1-9 without repetition.
+        Note:
+        A Sudoku board (partially filled) could be valid but is not necessarily solvable. Only the filled cells need to be validated according to the mentioned rules."""
+    )
+
+Performance
+~~~~~~~~~~~
+The effectiveness of speculative decoding depends on:
+
+- The size difference between the draft and target models: the larger, the better.
+- The similarity between the logits produced by the draft model and the target model. + +In the example above, the target model is about five times larger than the draft model, and the two models are well aligned. Approximately 86% of the draft tokens are accepted by the target model, resulting in a 25% increase in speed. + +References +~~~~~~~~~~ +- [1] `Fast Inference from Transformers via Speculative Decoding `_ +- [2] `Accelerating Large Language Model Decoding with Speculative Sampling `_ diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 7dbe2f84c0..1a4c5cbf6a 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -13,6 +13,7 @@ # limitations under the License. import uuid +import warnings from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union import requests @@ -352,6 +353,54 @@ def list_models(self) -> Dict[str, Dict[str, Any]]: response_data = response.json() return response_data + def launch_speculative_llm( + self, + model_name: str, + model_size_in_billions: Optional[int], + quantization: Optional[str], + draft_model_name: str, + draft_model_size_in_billions: Optional[int], + draft_quantization: Optional[str], + n_gpu: Optional[Union[int, str]] = "auto", + ): + """ + Launch the LLM along with a draft model based on the parameters on the server via RESTful APIs. This is an + experimental feature and the API may change in the future. + + Returns + ------- + str + The unique model_uid for the launched model. + + """ + warnings.warn( + "`launch_speculative_llm` is an experimental feature and the API may change in the future." + ) + + model_uid = self._gen_model_uid() + + payload = { + "model_uid": model_uid, + "model_name": model_name, + "model_size_in_billions": model_size_in_billions, + "quantization": quantization, + "draft_model_name": draft_model_name, + "draft_model_size_in_billions": draft_model_size_in_billions, + "draft_quantization": draft_quantization, + "n_gpu": n_gpu, + } + + url = f"{self.base_url}/experimental/speculative_llms" + response = requests.post(url, json=payload) + if response.status_code != 200: + raise RuntimeError( + f"Failed to launch model, detail: {response.json()['detail']}" + ) + + response_data = response.json() + model_uid = response_data["model_uid"] + return model_uid + def launch_model( self, model_name: str, diff --git a/xinference/core/model.py b/xinference/core/model.py index 34953816e2..708f3c4182 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -39,6 +39,7 @@ logger = logging.getLogger(__name__) +from .utils import log_async T = TypeVar("T") @@ -102,6 +103,7 @@ async def __pre_destroy__(self): def __init__(self, model: "LLM"): super().__init__() from ..model.llm.pytorch.core import PytorchModel + from ..model.llm.pytorch.spec_model import SpeculativeModel from ..model.llm.vllm.core import VLLMModel self._model = model @@ -109,7 +111,7 @@ def __init__(self, model: "LLM"): self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {} self._lock = ( None - if isinstance(self._model, (PytorchModel, VLLMModel)) + if isinstance(self._model, (PytorchModel, SpeculativeModel, VLLMModel)) else asyncio.locks.Lock() ) @@ -141,6 +143,7 @@ async def _call_wrapper(self, _wrapper: Callable): async def _call_async_wrapper(self, _wrapper: Callable): return await asyncio.create_task(_wrapper()) + @log_async(logger=logger) async def generate(self, prompt: str, *args, **kwargs): if not 
hasattr(self._model, "generate") and not hasattr(
             self._model, "async_generate"
@@ -163,6 +166,7 @@ async def _async_wrapper():
         else:
             return await self._call_async_wrapper(_async_wrapper)

+    @log_async(logger=logger)
     async def chat(self, prompt: str, *args, **kwargs):
         if not hasattr(self._model, "chat") and not hasattr(self._model, "async_chat"):
             raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
@@ -215,6 +219,7 @@ async def _wrapper():
         return await self._call_wrapper(_wrapper)

+    @log_async(logger=logger)
     async def next(
         self, generator_uid: str
     ) -> Union["ChatCompletionChunk", "CompletionChunk"]:
diff --git a/xinference/core/restful_api.py b/xinference/core/restful_api.py
index 4f109fd62f..c5ef394116 100644
--- a/xinference/core/restful_api.py
+++ b/xinference/core/restful_api.py
@@ -75,7 +75,7 @@
 )

 repetition_penalty_field = Field(
-    default=1.1,
+    default=1.0,
     ge=0.0,
     description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
     + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
@@ -262,6 +262,11 @@ def serve(self):
             "/v1/models/{model_uid}", self.describe_model, methods=["GET"]
         )
         self._router.add_api_route("/v1/models", self.launch_model, methods=["POST"])
+        self._router.add_api_route(
+            "/experimental/speculative_llms",
+            self.launch_speculative_llm,
+            methods=["POST"],
+        )
         self._router.add_api_route(
             "/v1/models/{model_uid}", self.terminate_model, methods=["DELETE"]
         )
@@ -385,6 +390,47 @@ async def describe_model(self, model_uid: str) -> Dict[str, Any]:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))

+    async def launch_speculative_llm(self, request: Request) -> JSONResponse:
+        payload = await request.json()
+        model_uid = payload.get("model_uid")
+        model_name = payload.get("model_name")
+        model_size_in_billions = payload.get("model_size_in_billions")
+        quantization = payload.get("quantization")
+        draft_model_name = payload.get("draft_model_name")
+        draft_model_size_in_billions = payload.get("draft_model_size_in_billions")
+        draft_quantization = payload.get("draft_quantization")
+        n_gpu = payload.get("n_gpu", "auto")
+
+        if model_uid is None or model_name is None:
+            raise HTTPException(
+                status_code=400,
+                detail="Invalid input. 
Please specify the model UID and the model name", + ) + + try: + model_uid = await self._supervisor_ref.launch_speculative_llm( + model_uid=model_uid, + model_name=model_name, + model_size_in_billions=model_size_in_billions, + quantization=quantization, + draft_model_name=draft_model_name, + draft_model_size_in_billions=draft_model_size_in_billions, + draft_quantization=draft_quantization, + n_gpu=n_gpu, + ) + + except ValueError as ve: + logger.error(str(ve), exc_info=True) + raise HTTPException(status_code=400, detail=str(ve)) + except RuntimeError as re: + logger.error(str(re), exc_info=True) + raise HTTPException(status_code=503, detail=str(re)) + except Exception as e: + logger.error(str(e), exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + return JSONResponse(content={"model_uid": model_uid}) + async def launch_model(self, request: Request) -> JSONResponse: payload = await request.json() model_uid = payload.get("model_uid") diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index f73ef60af3..a163f3fca9 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -201,6 +201,64 @@ async def unregister_model(self, model_type: str, model_name: str): else: raise ValueError(f"Unsupported model type: {model_type}") + async def launch_speculative_llm( + self, + model_uid: str, + model_name: str, + model_size_in_billions: Optional[int], + quantization: Optional[str], + draft_model_name: str, + draft_model_size_in_billions: Optional[int], + draft_quantization: Optional[str], + n_gpu: Optional[Union[int, str]] = "auto", + ) -> AsyncGenerator: + logger.debug( + ( + f"Enter launch_speculative_llm, model_uid: %s, model_name: %s, model_size: %s, " + f"draft_model_name: %s, draft_model_size: %s" + ), + model_uid, + model_name, + str(model_size_in_billions) if model_size_in_billions else "", + draft_model_name, + draft_model_size_in_billions, + ) + + # TODO: the draft and target model must be on the same worker. + if not self.is_local_deployment(): + raise ValueError( + "Speculative model is not supported in distributed deployment yet." + ) + + if model_uid in self._model_uid_to_replica_info: + raise ValueError(f"Model is already in the model list, uid: {model_uid}") + + worker_ref = await self._choose_worker() + replica = 1 + self._model_uid_to_replica_info[model_uid] = ReplicaInfo( + replica=replica, scheduler=itertools.cycle(range(replica)) + ) + + try: + rep_model_uid = f"{model_uid}-{1}-{0}" + yield worker_ref.launch_speculative_model( + model_uid=rep_model_uid, + model_name=model_name, + model_size_in_billions=model_size_in_billions, + quantization=quantization, + draft_model_name=draft_model_name, + draft_model_size_in_billions=draft_model_size_in_billions, + draft_quantization=draft_quantization, + n_gpu=n_gpu, + ) + self._replica_model_uid_to_worker[rep_model_uid] = worker_ref + + except Exception: + # terminate_model will remove the replica info. 
+ await self.terminate_model(model_uid, suppress_exception=True) + raise + raise xo.Return(model_uid) + async def launch_builtin_model( self, model_uid: str, diff --git a/xinference/core/worker.py b/xinference/core/worker.py index bf92419744..d7b2830a71 100644 --- a/xinference/core/worker.py +++ b/xinference/core/worker.py @@ -113,12 +113,11 @@ async def _create_subpool( ) return subpool_address, [str(dev) for dev in devices] - def _check_model_is_valid(self, model_name): + def _check_model_is_valid(self, model_name: str, model_format: Optional[str]): # baichuan-base and baichuan-chat depend on `cpm_kernels` module, # but `cpm_kernels` cannot run on Darwin system. - if platform.system() == "Darwin": - # TODO: there's no baichuan-base. - if model_name in ["baichuan-base", "baichuan-chat"]: + if platform.system() == "Darwin" and model_format == "pytorch": + if "baichuan" in model_name: raise ValueError(f"{model_name} model can't run on Darwin system.") @log_sync(logger=logger) @@ -142,6 +141,58 @@ async def unregister_model(self, model_type: str, model_name: str): else: raise ValueError(f"Unsupported model type: {model_type}") + @log_async(logger=logger) + async def launch_speculative_model( + self, + model_uid: str, + model_name: str, + model_size_in_billions: Optional[int], + quantization: Optional[str], + draft_model_name: str, + draft_model_size_in_billions: Optional[int], + draft_quantization: Optional[str], + n_gpu: Optional[Union[int, str]] = "auto", + ): + if n_gpu is not None: + if isinstance(n_gpu, int) and (n_gpu <= 0 or n_gpu > cuda_count()): + raise ValueError( + f"The parameter `n_gpu` must be greater than 0 and " + f"not greater than the number of GPUs: {cuda_count()} on the machine." + ) + if isinstance(n_gpu, str) and n_gpu != "auto": + raise ValueError("Currently `n_gpu` only supports `auto`.") + + from ..model.llm.core import create_speculative_llm_model_instance + + model, model_description = create_speculative_llm_model_instance( + model_uid=model_uid, + model_name=model_name, + model_size_in_billions=model_size_in_billions, + quantization=quantization, + draft_model_name=draft_model_name, + draft_model_size_in_billions=draft_model_size_in_billions, + draft_quantization=draft_quantization, + is_local_deployment=True, + ) + + subpool_address, devices = await self._create_subpool(model_uid, n_gpu=n_gpu) + try: + model_ref = await xo.create_actor( + ModelActor, address=subpool_address, uid=model_uid, model=model + ) + await model_ref.load() + except: + logger.error(f"Failed to load model {model_uid}", exc_info=True) + await self._main_pool.remove_sub_pool(subpool_address) + raise + + self._model_uid_to_model[model_uid] = model_ref + self._model_uid_to_model_spec[model_uid] = model_description + for dev in devices: + self._gpu_to_model_uid[int(dev)] = model_uid + self._model_uid_to_addr[model_uid] = subpool_address + return model_ref + @log_async(logger=logger) async def launch_builtin_model( self, @@ -164,7 +215,7 @@ async def launch_builtin_model( raise ValueError("Currently `n_gpu` only supports `auto`.") assert model_uid not in self._model_uid_to_model - self._check_model_is_valid(model_name) + self._check_model_is_valid(model_name, model_format) assert self._supervisor_ref is not None is_local_deployment = await self._supervisor_ref.is_local_deployment() @@ -186,6 +237,7 @@ async def launch_builtin_model( ) await model_ref.load() except: + logger.error(f"Failed to load model {model_uid}", exc_info=True) await self._main_pool.remove_sub_pool(subpool_address) raise diff 
--git a/xinference/model/llm/core.py b/xinference/model/llm/core.py index 04d17c84c9..1b6fe450f3 100644 --- a/xinference/model/llm/core.py +++ b/xinference/model/llm/core.py @@ -147,3 +147,67 @@ def create_llm_model_instance( model = llm_cls(model_uid, llm_family, llm_spec, quantization, save_path, kwargs) return model, LLMDescription(llm_family, llm_spec, quantization) + + +def create_speculative_llm_model_instance( + model_uid: str, + model_name: str, + model_size_in_billions: Optional[int], + quantization: Optional[str], + draft_model_name: str, + draft_model_size_in_billions: Optional[int], + draft_quantization: Optional[str], + is_local_deployment: bool = False, +) -> Tuple[LLM, LLMDescription]: + from . import match_llm + from .llm_family import cache + + match_result = match_llm( + model_name, + "pytorch", + model_size_in_billions, + quantization, + is_local_deployment, + ) + + if not match_result: + raise ValueError( + f"Model not found, name: {model_name}, format: pytorch," + f" size: {model_size_in_billions}, quantization: {quantization}" + ) + llm_family, llm_spec, quantization = match_result + assert quantization is not None + save_path = cache(llm_family, llm_spec, quantization) + + draft_match_result = match_llm( + draft_model_name, + "pytorch", + draft_model_size_in_billions, + draft_quantization, + is_local_deployment, + ) + + if not draft_match_result: + raise ValueError( + f"Model not found, name: {draft_model_name}, format: pytorch," + f" size: {draft_model_size_in_billions}, quantization: {draft_quantization}" + ) + draft_llm_family, draft_llm_spec, draft_quantization = draft_match_result + assert draft_quantization is not None + draft_save_path = cache(draft_llm_family, draft_llm_spec, draft_quantization) + + from .pytorch.spec_model import SpeculativeModel + + model = SpeculativeModel( + model_uid, + model_family=llm_family, + model_spec=llm_spec, + quantization=quantization, + model_path=save_path, + draft_model_family=draft_llm_family, + draft_model_spec=draft_llm_spec, + draft_quantization=draft_quantization, + draft_model_path=draft_save_path, + ) + + return model, LLMDescription(llm_family, llm_spec, quantization) diff --git a/xinference/model/llm/pytorch/baichuan.py b/xinference/model/llm/pytorch/baichuan.py index a08234216d..706410d51c 100644 --- a/xinference/model/llm/pytorch/baichuan.py +++ b/xinference/model/llm/pytorch/baichuan.py @@ -38,7 +38,7 @@ def __init__( ) self._use_fast_tokenizer = False - def _load_model(self, kwargs: dict): + def _load_model(self, **kwargs): try: from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation.utils import GenerationConfig diff --git a/xinference/model/llm/pytorch/chatglm.py b/xinference/model/llm/pytorch/chatglm.py index b78407e9d7..b22524f864 100644 --- a/xinference/model/llm/pytorch/chatglm.py +++ b/xinference/model/llm/pytorch/chatglm.py @@ -37,7 +37,7 @@ def __init__( pytorch_model_config=pytorch_model_config, ) - def _load_model(self, kwargs: dict): + def _load_model(self, **kwargs): try: from transformers import AutoModel, AutoTokenizer except ImportError: diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index 87733fa89f..05a5d64d71 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -78,7 +78,7 @@ def _sanitize_generate_config( pytorch_generate_config["model"] = self.model_uid return pytorch_generate_config - def _load_model(self, kwargs: dict): + def _load_model(self, **kwargs): try: from 
transformers import AutoModelForCausalLM, AutoTokenizer except ImportError: @@ -177,7 +177,7 @@ def load(self): if num_gpus > 0 and self._device == "cuda": kwargs.update({"device_map": "auto"}) - self._model, self._tokenizer = self._load_model(kwargs) + self._model, self._tokenizer = self._load_model(**kwargs) if self._device == "mps": self._model.to(self._device) @@ -402,13 +402,13 @@ def _sanitize_generate_config( pytorch_generate_config ) if ( - "stop" not in pytorch_generate_config + pytorch_generate_config.get("stop", None) is None and self.model_family.prompt_style and self.model_family.prompt_style.stop ): pytorch_generate_config["stop"] = self.model_family.prompt_style.stop.copy() if ( - "stop_token_ids" not in pytorch_generate_config + pytorch_generate_config.get("stop_token_ids", None) is None and self.model_family.prompt_style and self.model_family.prompt_style.stop_token_ids ): diff --git a/xinference/model/llm/pytorch/falcon.py b/xinference/model/llm/pytorch/falcon.py index b26de065dc..104eaaa4bf 100644 --- a/xinference/model/llm/pytorch/falcon.py +++ b/xinference/model/llm/pytorch/falcon.py @@ -37,7 +37,7 @@ def __init__( pytorch_model_config=pytorch_model_config, ) - def _load_model(self, kwargs: dict): + def _load_model(self, **kwargs): try: from transformers import AutoModelForCausalLM, AutoTokenizer except ImportError: @@ -94,7 +94,7 @@ def __init__( pytorch_model_config=pytorch_model_config, ) - def _load_model(self, kwargs: dict): + def _load_model(self, **kwargs): try: from transformers import AutoModelForCausalLM, AutoTokenizer except ImportError: diff --git a/xinference/model/llm/pytorch/llama_2.py b/xinference/model/llm/pytorch/llama_2.py index b8223e07aa..2c9bae04cf 100644 --- a/xinference/model/llm/pytorch/llama_2.py +++ b/xinference/model/llm/pytorch/llama_2.py @@ -37,8 +37,8 @@ def __init__( pytorch_model_config=pytorch_model_config, ) - def _load_model(self, kwargs: dict): - model, tokenizer = super()._load_model(kwargs) + def _load_model(self, **kwargs): + model, tokenizer = super()._load_model(**kwargs) # Llama has no pad token by default # https://github.com/huggingface/transformers/blob/07998ef39926b76d3f6667025535d0859eed61c3/docs/source/en/llm_tutorial.md?plain=1#L125 tokenizer.pad_token = tokenizer.eos_token @@ -79,8 +79,8 @@ def __init__( ) self._use_fast_tokenizer = False - def _load_model(self, kwargs: dict): - model, tokenizer = super()._load_model(kwargs) + def _load_model(self, **kwargs): + model, tokenizer = super()._load_model(**kwargs) # Llama has no pad token by default # https://github.com/huggingface/transformers/blob/07998ef39926b76d3f6667025535d0859eed61c3/docs/source/en/llm_tutorial.md?plain=1#L125 tokenizer.pad_token = tokenizer.eos_token diff --git a/xinference/model/llm/pytorch/spec_decoding_utils.py b/xinference/model/llm/pytorch/spec_decoding_utils.py new file mode 100644 index 0000000000..5e8bd1590c --- /dev/null +++ b/xinference/model/llm/pytorch/spec_decoding_utils.py @@ -0,0 +1,528 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import logging +import time +import uuid +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple + +try: + import torch + from torch.nn import functional as F +except ImportError: + raise ImportError( + f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n" + ) + +try: + from transformers import PreTrainedModel, PreTrainedTokenizer + from transformers.generation.logits_process import ( + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) +except ImportError: + error_message = "Failed to import module 'transformers'" + installation_guide = [ + "Please make sure 'transformers' is installed. ", + "You can install it by `pip install transformers`\n", + ] + + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + + +from ....types import CompletionChoice, CompletionChunk, CompletionUsage + +logger = logging.getLogger(__name__) + + +def prepare_logits_processor( + temperature: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +def get_context_length(config): + """Get the context length of a model from a huggingface model config.""" + if ( + hasattr(config, "max_sequence_length") + and config.max_sequence_length is not None + ): + return config.max_sequence_length + elif hasattr(config, "seq_length") and config.seq_length is not None: + return config.seq_length + elif ( + hasattr(config, "max_position_embeddings") + and config.max_position_embeddings is not None + ): + return config.max_position_embeddings + else: + return 2048 + + +def normalize_logits( + logits_processor: LogitsProcessorList, + input_ids: List[int], + logits: torch.FloatTensor, # [1, n_seq, n_vocab] +) -> torch.Tensor: + """ + Parameters + ---------- + logits : torch.Tensor + Logits of shape `(n_batch, n_seq, n_vocab)`. + + Returns + ------- + torch.Tensor + Normalized logits of shape `(n_batch, n_seq, n_vocab)`. + """ + + def _helper( + _input_ids: torch.LongTensor, _logits: torch.FloatTensor # [1, n_vocab] + ) -> torch.Tensor: + if logits_processor: + last_token_logits = logits_processor( + _input_ids, + _logits, + )[0] + else: + return _logits[0] + + return last_token_logits # [n_vocab,] + + input_ids = torch.as_tensor([input_ids], device=logits.device).long() + for i in range(logits.shape[1]): + normalized = _helper( + input_ids[ + : -logits.shape[1] + i + ], # input_ids may not equal logits.shape[1] + logits[:, i, :], + ) + logits[:, i, :] = normalized.clone() + return F.softmax(logits, dim=-1) + + +def sample( + last_token_logits: torch.FloatTensor, temperature: float, top_p: float +) -> int: + """ + Parameters + ---------- + last_token_logits : torch.FloatTensor + Last token logits of shape [n_vocab,] + + Returns + ------- + int + Token ID. 
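+
+    Notes
+    -----
+    Two candidate token ids are drawn (the top-2 ids in greedy mode, two multinomial
+    samples otherwise), but only the first one is returned to the caller.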
+ """ + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + indices = torch.multinomial(last_token_logits, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + return tokens[0] + + +def rollback_kv_cache( + kv_cache: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], n: int +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + ret = [] + for k_cache, v_cache in kv_cache: + k_cache = k_cache[:, :, :-n, :] # [1, n_head, n_seq - n, n_dim] + v_cache = v_cache[:, :, :-n, :] + + assert isinstance(k_cache, torch.Tensor) + assert isinstance(v_cache, torch.Tensor) + ret.append((k_cache, v_cache)) + + return tuple(ret) + + +def rollback_logits(logits: torch.Tensor, n: int): + return logits[:, :-n, :] # [1, n_seq, n_vocab] + + +def is_partial_stop(output: str, stop_str: str): + """Check whether the output contains a partial stop str.""" + for i in range(0, min(len(output), len(stop_str))): + if stop_str.startswith(output[-i:]): + return True + return False + + +def draft( + input_ids: List[int], + kv_cache: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]], + logits: Optional[torch.FloatTensor], + draft_model: "PreTrainedModel", + gamma: int, + logits_processor: LogitsProcessorList, + temperature: float, + top_p: float, +): + """ + Parameters + ---------- + input_ids : List[int] + On the prefill stage, `input_ids` are the prompt tokens. + + On the decode stage. It includes the prompt tokens, the token generated by the original model + at the end of each full iteration, or the token generated by the draft model draft + iteration. + + Returns + ------- + int + The number of generated draft tokens. + List[int] + Outputs, including the draft tokens. + Tuple[Tuple[torch.Tensor, torch.Tensor], ...] + KV cache. + torch.FloatTensor + Logits. + """ + draft_output_ids = input_ids.copy() + + if kv_cache is not None: + input_ids = draft_output_ids[-2:] + + num_draft_tokens = 0 + while num_draft_tokens < gamma: + if kv_cache is None: + # prefill. + draft_model_out = draft_model( + torch.as_tensor([input_ids], device=draft_model.device), + use_cache=True, + ) + logits = normalize_logits( + logits_processor, input_ids, draft_model_out.logits + ) + else: + draft_model_out = draft_model( + torch.as_tensor([input_ids], device=draft_model.device), + use_cache=True, + past_key_values=kv_cache, + ) + normalized = normalize_logits( + logits_processor, draft_output_ids, draft_model_out.logits + ) + assert logits is not None + logits = torch.cat((logits, normalized), dim=1) + kv_cache = draft_model_out.past_key_values + draft_token = sample( + logits[0, -1, :], + temperature, + top_p, + ) + draft_output_ids.append(draft_token) + input_ids = [draft_token] + num_draft_tokens += 1 + + assert kv_cache is not None + return num_draft_tokens, draft_output_ids, kv_cache, logits + + +@torch.inference_mode() +def speculative_generate_stream( + draft_model: "PreTrainedModel", + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + prompt: str, + generate_config: Dict[str, Any], +) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]: + logger.debug( + f"Enter speculative_generate_stream, prompt: {prompt}, generate_config: {generate_config}" + ) + + # TODO: currently, repetition penalty leads to garbled outputs. 
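+    # generate_config keys consumed below (with their defaults): gamma (4, the number
+    # of draft tokens proposed per round), stream (False), temperature (1.0),
+    # top_p (1.0), top_k (-1, disabled), max_tokens (256), echo (False), stop and
+    # stop_token_ids. A non-default repetition_penalty is rejected for now because of
+    # the issue noted above.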
+ if float(generate_config.get("repetition_penalty", 1.0)) != 1.0: + raise ValueError( + "repetition penalty is not supported by speculative decoding yet" + ) + + gamma = generate_config.get("gamma", 4) + stream = generate_config.get("stream", False) + temperature = float(generate_config.get("temperature", 1.0)) + top_p = float(generate_config.get("top_p", 1.0)) + top_k = int(generate_config.get("top_k", -1)) # -1 means disable + max_new_tokens = int(generate_config.get("max_tokens", 256)) + echo = bool(generate_config.get("echo", False)) + stop_str = generate_config.get("stop", None) + stop_token_ids = generate_config.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor(temperature, top_p, top_k) + request_id = str(uuid.uuid1()) + + if "qwen" in str(type(model)).lower(): + # TODO: hacky. + input_ids = tokenizer(prompt, allowed_special="all").input_ids + else: + input_ids = tokenizer(prompt).input_ids + + num_prompt_tokens = len(input_ids) + output_ids = list(input_ids) + + # internal states. + draft_kv_cache = None + draft_logits = None + kv_cache = None + logits = None + next_token = ( + None # the token generated by the original model at each full iteration. + ) + last_output_length = 0 + finish_reason = "stop" + + # performance stats. + total_seconds_on_drafting = 0.0 + total_seconds_on_eval = 0.0 + total_seconds_on_accepting = 0.0 + total_num_draft_tokens = 0 + total_num_accepted_tokens = 0 + + while len(output_ids) < max_new_tokens + num_prompt_tokens: + # allow the draft model to generate more than max_tokens since some of the generated + # tokens could be rejected. + start = time.time() + num_draft_tokens, output_ids, draft_kv_cache, draft_logits = draft( + input_ids=output_ids, + kv_cache=draft_kv_cache, + logits=draft_logits, + draft_model=draft_model, + gamma=gamma, + logits_processor=logits_processor, + temperature=temperature + * 0.5, # make the draft model outputs less random for better quality. + top_p=top_p, + ) + total_seconds_on_drafting += time.time() - start + total_num_draft_tokens += num_draft_tokens + + # eval stage. + start = time.time() + if kv_cache is None: + # prefill. + out = model( + torch.as_tensor([output_ids], device=model.device), use_cache=True + ) + logits = normalize_logits(logits_processor, output_ids, out.logits) + else: + out = model( + torch.as_tensor( + [[next_token] + output_ids[-num_draft_tokens:]], device=model.device + ), + use_cache=True, + past_key_values=kv_cache, + ) + normalized = normalize_logits(logits_processor, output_ids, out.logits) + logits = torch.cat((logits, normalized), dim=1) + kv_cache = out.past_key_values + total_seconds_on_eval += time.time() - start + + # accepting stage. 
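+        # Standard speculative-sampling acceptance: each draft token x is kept with
+        # probability min(1, p_target(x) / p_draft(x)). On the first rejection the
+        # remaining draft tokens are dropped, the KV caches and logits are rolled back,
+        # and a replacement token is sampled from norm(max(0, p_target - p_draft)).
+        # If every draft token is accepted, one extra token is sampled from the target
+        # model's distribution.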
+ start = time.time() + assert draft_logits is not None + assert draft_kv_cache is not None + accepted = 0 + stopped = False + for draft_token_idx in range(-num_draft_tokens, 0): + r = torch.rand(1, device=logits.device) + draft_token = output_ids[draft_token_idx] + token_logits = logits[:, draft_token_idx - 1, :] # [1, n_vocab,] + draft_token_logits = draft_logits[:, draft_token_idx, :].to( + logits.device + ) # [1, n_vocab,] + if token_logits[0, draft_token] / draft_token_logits[0, draft_token] > r: + accepted += 1 + total_num_accepted_tokens += 1 + if draft_token in stop_token_ids: + stopped = True + else: + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.debug( + f"Accepted ({accepted}/{num_draft_tokens}): '{tokenizer.decode(output_ids[-num_draft_tokens: draft_token_idx])}'" + ) + logger.debug( + f"Rejected: '{tokenizer.decode(output_ids[draft_token_idx:])}'" + ) + # rollback. + output_ids = output_ids[:draft_token_idx] + draft_kv_cache = rollback_kv_cache( + draft_kv_cache, num_draft_tokens - accepted + ) + kv_cache = rollback_kv_cache(kv_cache, num_draft_tokens - accepted) + draft_logits = rollback_logits( + draft_logits, num_draft_tokens - accepted + ) + logits = rollback_logits(logits, num_draft_tokens - accepted) + + # sample the next token according to the modified distribution of shape [1, n_vocab] + modified_dist = token_logits - draft_token_logits + modified_dist = torch.where( + modified_dist > 0, modified_dist, torch.zeros_like(modified_dist) + ) + normalized = normalize_logits( + logits_processor, + output_ids, + modified_dist.unsqueeze(1), # [1, 1, n_vocab] + ) + next_token = sample( + normalized[0, -1, :], + 0, # must be 0, since the dist is quiet unified, higher temperature results in garbled text + top_p, + ) + output_ids.append(next_token) + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.debug(f"Generated: '{tokenizer.decode([next_token])}'") + if next_token in stop_token_ids: + stopped = True + break + + if accepted == num_draft_tokens: + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.debug( + f"Accepted ({accepted}/{num_draft_tokens}): '{tokenizer.decode(output_ids[-num_draft_tokens:])}'" + ) + next_token = sample( + logits[0, -1, :], + temperature, + top_p, + ) + output_ids.append(next_token) + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.debug(f"Generated: '{tokenizer.decode([next_token])}'") + if next_token in stop_token_ids: + stopped = True + + total_seconds_on_accepting += time.time() - start + + if ( + accepted > 0 # more than 2 tokens has been generated, flush. + or len(output_ids) >= max_new_tokens + or stopped + ): + output = tokenizer.decode( + output_ids if echo else output_ids[num_prompt_tokens:], + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + rfind_start = len(prompt) if echo else 0 + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError(f"Invalid stop field type {type(stop_str)}") + + if stream: + # return the delta. 
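+                # Only the text decoded since the last flush is emitted, and a chunk
+                # whose tail could be the start of a stop string is held back until
+                # the ambiguity is resolved.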
+ output_length = len(output) + output = output[last_output_length:] + last_output_length = output_length + + # prevent yielding partial stop sequence. + if not partially_stopped: + completion_choice = CompletionChoice( + text=output, index=0, logprobs=None, finish_reason=None + ) + completion_chunk = CompletionChunk( + id=request_id, + object="text_completion", + created=int(time.time()), + model=generate_config["model"], + choices=[completion_choice], + ) + completion_usage = CompletionUsage( + prompt_tokens=num_prompt_tokens, + completion_tokens=len(output_ids) - num_prompt_tokens, + total_tokens=len(output_ids), + ) + + yield completion_chunk, completion_usage + if stopped: + break + else: + finish_reason = "length" + + logger.info( + f"In total, {total_num_accepted_tokens}/{total_num_draft_tokens} draft tokens are " + f"accepted, acceptance rate: {total_num_accepted_tokens / total_num_draft_tokens:.2f}" + ) + total_seconds = ( + total_seconds_on_drafting + total_seconds_on_eval + total_seconds_on_accepting + ) + logger.info( + f"In total, {total_seconds_on_drafting:.2f}s, {total_seconds_on_eval:.2f}s and " + f"{total_seconds_on_accepting:.2f}s are spent on drafting, eval, and accepting " + f"respectively. Average generation speed: {(len(output_ids) - num_prompt_tokens) / total_seconds:.2f} tokens/s." + ) + + if stream: + completion_choice = CompletionChoice( + text="", index=0, logprobs=None, finish_reason=finish_reason + ) + else: + completion_choice = CompletionChoice( + text=output, index=0, logprobs=None, finish_reason=finish_reason + ) + + completion_chunk = CompletionChunk( + id=request_id, + object="text_completion", + created=int(time.time()), + model=generate_config["model"], + choices=[completion_choice], + ) + completion_usage = CompletionUsage( + prompt_tokens=num_prompt_tokens, + completion_tokens=len(output_ids) - num_prompt_tokens, + total_tokens=len(output_ids), + ) + + yield completion_chunk, completion_usage + + # clean up. + del kv_cache + del draft_kv_cache + gc.collect() + torch.cuda.empty_cache() diff --git a/xinference/model/llm/pytorch/spec_model.py b/xinference/model/llm/pytorch/spec_model.py new file mode 100644 index 0000000000..aedd6af2a9 --- /dev/null +++ b/xinference/model/llm/pytorch/spec_model.py @@ -0,0 +1,180 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Iterator, List, Optional, Union + +from ....types import Completion, CompletionChunk, Embedding +from .. 
import LLMFamilyV1, LLMSpecV1 +from .core import PytorchChatModel, PytorchGenerateConfig, PytorchModelConfig + +logger = logging.getLogger(__name__) + + +class SpeculativeModel(PytorchChatModel): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + draft_model_family: "LLMFamilyV1", + draft_model_spec: "LLMSpecV1", + draft_quantization: str, + draft_model_path: str, + ): + super().__init__(model_uid, model_family, model_spec, quantization, model_path) + self._pytorch_model_config: PytorchModelConfig = self._sanitize_model_config( + PytorchModelConfig() + ) + self._draft_model_family = draft_model_family + self._draft_model_spec = draft_model_spec + self._draft_quantization = draft_quantization + self._draft_model_path = draft_model_path + + def _load_model(self, model_path, **kwargs): + try: + from transformers import AutoModelForCausalLM, AutoTokenizer + except ImportError: + error_message = "Failed to import module 'transformers'" + installation_guide = [ + "Please make sure 'transformers' is installed. ", + "You can install it by `pip install transformers`\n", + ] + + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self._use_fast_tokenizer, + trust_remote_code=kwargs["trust_remote_code"], + revision=kwargs["revision"], + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **kwargs, + ) + return model, tokenizer + + def load(self): + try: + import torch + except ImportError: + raise ImportError( + f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n" + ) + + cuda_visible_devices_env = os.getenv("CUDA_VISIBLE_DEVICES", None) + cuda_visible_devices = ( + cuda_visible_devices_env.split(",") if cuda_visible_devices_env else [] + ) + + num_gpus = len(cuda_visible_devices) if cuda_visible_devices_env != "-1" else 0 + device = self._pytorch_model_config.get("device", "auto") + self._pytorch_model_config["device"] = self._select_device(device) + self._device = self._pytorch_model_config["device"] + + if self._device == "cpu": + kwargs = {"torch_dtype": torch.float32} + elif self._device == "cuda": + kwargs = {"torch_dtype": torch.float16} + elif self._device == "mps": + kwargs = {"torch_dtype": torch.float16} + else: + raise ValueError(f"Device {self._device} is not supported in temporary") + kwargs["trust_remote_code"] = self._pytorch_model_config.get( + "trust_remote_code" + ) + + if self.quantization != "none": + raise ValueError( + "Quantization is not supported by speculative decoding yet" + ) + + if num_gpus > 0 and self._device == "cuda": + kwargs.update({"device_map": "auto"}) + + self._model, self._tokenizer = self._load_model( + model_path=self.model_path, + revision=self.model_spec.model_revision, + **kwargs, + ) + if self._device == "mps": + self._model.to(self._device) + logger.debug( + f"Model {self.model_uid} memory footprint: {self._model.get_memory_footprint()}" + ) + + self._draft_model, _ = self._load_model( + model_path=self._draft_model_path, + revision=self._draft_model_spec.model_revision, + **kwargs, + ) + if self._device == "mps": + self._model.to(self._device) + logger.debug( + f"Draft model {self.model_uid} memory footprint: {self._model.get_memory_footprint()}" + ) + + def generate( + self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None + ) -> Union[Completion, Iterator[CompletionChunk]]: + def 
generator_wrapper( + _prompt: str, _generate_config: PytorchGenerateConfig + ) -> Iterator[CompletionChunk]: + for _completion_chunk, _completion_usage in speculative_generate_stream( + draft_model=self._draft_model, + model=self._model, + tokenizer=self._tokenizer, + prompt=_prompt, + generate_config=_generate_config, + ): + yield _completion_chunk + + from .spec_decoding_utils import speculative_generate_stream + + generate_config = self._sanitize_generate_config(generate_config) + + assert self._draft_model is not None + assert self._model is not None + assert self._tokenizer is not None + + stream = generate_config.get("stream", False) + if not stream: + for completion_chunk, completion_usage in speculative_generate_stream( + draft_model=self._draft_model, + model=self._model, + tokenizer=self._tokenizer, + prompt=prompt, + generate_config=generate_config, + ): + pass + + completion = Completion( + id=completion_chunk["id"], + object=completion_chunk["object"], + created=completion_chunk["created"], + model=completion_chunk["model"], + choices=completion_chunk["choices"], + usage=completion_usage, + ) + return completion + else: + return generator_wrapper(prompt, generate_config) + + def create_embedding(self, input: Union[str, List[str]]) -> Embedding: + raise NotImplementedError diff --git a/xinference/model/llm/pytorch/tests/test_spec_decoding.py b/xinference/model/llm/pytorch/tests/test_spec_decoding.py new file mode 100644 index 0000000000..483de448f6 --- /dev/null +++ b/xinference/model/llm/pytorch/tests/test_spec_decoding.py @@ -0,0 +1,56 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from ..spec_decoding_utils import speculative_generate_stream + +logging.basicConfig(level=logging.DEBUG) + + +@pytest.mark.skip(reason="Temporary disabled") +def test_spec_decoding(): + """ + Use the draft model itself as the target model. If the decoding works, all the draft tokens + should be accepted, and the result of speculative decoding should be the same as the regular + decoding, which starts with "The largest animal ever recorded is the Tyrannosaurus Rex". + """ + + model_id = "PY007/TinyLlama-1.1B-Chat-v0.3" + draft_model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.float16, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + prompt = "What is the largest animal?" 
+ formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + for completion_chunk, completion_usage in speculative_generate_stream( + draft_model=draft_model, + model=draft_model, + tokenizer=tokenizer, + prompt=formatted_prompt, + generate_config={"model": "test", "temperature": 0, "max_tokens": 64}, + ): + pass + + completion = completion_chunk["choices"][0]["text"] + assert completion.startswith( + "The largest animal ever recorded is the Tyrannosaurus Rex" + ) diff --git a/xinference/model/llm/pytorch/utils.py b/xinference/model/llm/pytorch/utils.py index 3dd03f13f8..d2837c1324 100644 --- a/xinference/model/llm/pytorch/utils.py +++ b/xinference/model/llm/pytorch/utils.py @@ -13,6 +13,7 @@ # limitations under the License. import gc +import logging import re import time import uuid @@ -32,6 +33,8 @@ from ....types import CompletionChoice, CompletionChunk, CompletionUsage +logger = logging.getLogger(__name__) + def is_sentence_complete(output: str): """Check whether the output is a complete sentence.""" @@ -135,6 +138,7 @@ def generate_stream( device=device, ) + start = time.time() past_key_values = out = None sent_interrupt = False token = None @@ -281,6 +285,9 @@ def generate_stream( if stopped: break + elapsed_time = time.time() - start + logger.info(f"Average generation speed: {i / elapsed_time:.2f} tokens/s.") + # finish stream event, which contains finish reason if stopped: finish_reason = "stop" diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 772e043f2b..4713f45492 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -121,6 +121,7 @@ def _sanitize_model_config( model_config.setdefault("block_size", 16) model_config.setdefault("swap_space", 4) model_config.setdefault("gpu_memory_utilization", 0.90) + # TODO: remove model_config.setdefault("max_num_batched_tokens", 2560) model_config.setdefault("max_num_seqs", 256)
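
For reference, the experimental ``/experimental/speculative_llms`` route registered in ``restful_api.py`` can also be called directly over HTTP, without going through ``Client.launch_speculative_llm``. The snippet below is a minimal sketch rather than part of the patch: it assumes an Xinference server at ``http://localhost:9997``, the WizardCoder models used in the user guide, and a client-generated ``model_uid`` (the RESTful client generates one itself before posting).

.. code-block:: python

    import uuid

    import requests

    base_url = "http://localhost:9997"  # assumed local Xinference endpoint

    payload = {
        # a client-generated uid; the server echoes back the uid of the launched model
        "model_uid": uuid.uuid4().hex,
        "model_name": "wizardcoder-python-v1.0",        # target model
        "model_size_in_billions": 34,
        "quantization": "none",
        "draft_model_name": "wizardcoder-python-v1.0",  # draft model
        "draft_model_size_in_billions": 7,
        "draft_quantization": "none",
        "n_gpu": "auto",
    }

    response = requests.post(f"{base_url}/experimental/speculative_llms", json=payload)
    response.raise_for_status()
    model_uid = response.json()["model_uid"]  # pass this to client.get_model(...)

On success, the returned ``model_uid`` can be used with the regular client APIs, exactly as in the user guide.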