FEAT: speculative decoding (#509)
UranusSeven authored Oct 20, 2023
1 parent 866368b commit 14fc29c
Showing 20 changed files with 1,129 additions and 19 deletions.
Binary file added doc/source/_static/speculative_decoding.gif
Binary file added doc/source/_static/speculative_decoding.jpeg
1 change: 1 addition & 0 deletions doc/source/user_guide/index.rst
@@ -10,3 +10,4 @@ User Guide
cache_management
backends
client_api
spec_decoding
63 changes: 63 additions & 0 deletions doc/source/user_guide/spec_decoding.rst
@@ -0,0 +1,63 @@
.. _user_guide_spec_decoding:

===================================
Speculative Decoding (experimental)
===================================

.. image:: ../_static/speculative_decoding.gif

Speculative decoding is a technique for speeding up inference of large language models (LLMs). A smaller, faster "draft" model proposes several tokens in advance, and a larger "target" model then verifies them. When the target model accepts the draft tokens, several tokens are produced for a single target-model forward pass, which yields significant savings in memory bandwidth and processing time per token. Draft tokens that do not match the target model's predictions are discarded.

.. image:: ../_static/speculative_decoding.jpeg
:width: 400
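
The following pseudocode sketches a single draft-then-verify step. This is a simplified illustration of the general technique, not the implementation shipped in Xinference; ``draft_model`` and ``target_model`` are hypothetical stand-ins:

.. code-block:: python

    def speculative_step(draft_model, target_model, tokens, gamma=4):
        # 1. The small draft model cheaply proposes `gamma` candidate tokens.
        draft_tokens = []
        for _ in range(gamma):
            draft_tokens.append(draft_model.sample_next(tokens + draft_tokens))

        # 2. The large target model scores all candidates in a single forward
        #    pass and keeps the longest prefix that matches its own
        #    predictions; the remaining draft tokens are discarded.
        accepted = target_model.verify(tokens, draft_tokens)

        # 3. At least one token is produced per step, so the worst case
        #    degrades to regular decoding plus the draft-model overhead.
        return tokens + accepted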

Launching a speculative LLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Using speculative decoding in Xinference is straightforward. The only difference between speculative decoding and regular decoding is how the LLM is launched:

.. code-block:: python

    from xinference.client import Client

    client = Client("http://localhost:9997")
    model_uid = client.launch_speculative_llm(
        model_name="wizardcoder-python-v1.0",  # target model name
        model_size_in_billions=34,  # target model size
        quantization="none",  # target model quantization
        draft_model_name="wizardcoder-python-v1.0",  # draft model name
        draft_model_size_in_billions=7,  # draft model size
        draft_quantization="none",  # draft model quantization
        n_gpu=2,  # number of GPUs to use
    )

.. note::

   ``Client.launch_speculative_llm`` is an experimental API that may be removed in future releases.
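
Under the hood, ``launch_speculative_llm`` issues a single POST request to an experimental endpoint. If you prefer to call the RESTful API directly, the request looks roughly like the sketch below (the ``model_uid`` value is a placeholder; the Python client normally generates one for you):

.. code-block:: python

    import requests

    payload = {
        "model_uid": "my-speculative-llm",  # placeholder uid
        "model_name": "wizardcoder-python-v1.0",
        "model_size_in_billions": 34,
        "quantization": "none",
        "draft_model_name": "wizardcoder-python-v1.0",
        "draft_model_size_in_billions": 7,
        "draft_quantization": "none",
        "n_gpu": 2,
    }
    response = requests.post(
        "http://localhost:9997/experimental/speculative_llms", json=payload
    )
    response.raise_for_status()
    model_uid = response.json()["model_uid"]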

After launching the model, you can use it just like a regular model:

.. code-block:: python

    model = client.get_model(model_uid)
    model.chat(
        """Determine if a 9 x 9 Sudoku board is valid. Only the filled cells need to be validated according to the following rules:
    1. Each row must contain the digits 1-9 without repetition.
    2. Each column must contain the digits 1-9 without repetition.
    3. Each of the nine 3 x 3 sub-boxes of the grid must contain the digits 1-9 without repetition.
    Note:
    A Sudoku board (partially filled) could be valid but is not necessarily solvable. Only the filled cells need to be validated according to the mentioned rules."""
    )

Performance
~~~~~~~~~~~
The effectiveness of speculative decoding relies on:

- The size difference between the models: the larger, the better.
- The similarity between the logits produced by the draft model and the target model.

In the example above, the target model is about five times larger than the draft model, and the two models are well aligned. Approximately 86% of the draft tokens are accepted by the target model, resulting in a 25% increase in speed.
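
As a back-of-the-envelope estimate (an illustration, not a measurement; the number of draft tokens per step below is an assumed value), the expected number of tokens produced per target-model pass can be derived from the acceptance rate using the formula in reference [1]:

.. code-block:: python

    alpha = 0.86  # fraction of draft tokens accepted by the target model
    gamma = 4     # draft tokens proposed per verification step (assumed)

    # Expected tokens per target-model pass, following Leviathan et al. [1].
    expected_tokens = (1 - alpha ** (gamma + 1)) / (1 - alpha)
    print(f"~{expected_tokens:.1f} tokens per target-model pass")

    # The end-to-end speedup (about 1.25x here) is smaller than this token
    # multiplier because each step also pays for gamma draft-model forward
    # passes.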

References
~~~~~~~~~~
- [1] `Fast Inference from Transformers via Speculative Decoding <https://arxiv.org/abs/2211.17192>`_
- [2] `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/abs/2302.01318>`_
49 changes: 49 additions & 0 deletions xinference/client/restful/restful_client.py
@@ -13,6 +13,7 @@
# limitations under the License.

import uuid
import warnings
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

import requests
@@ -352,6 +353,54 @@ def list_models(self) -> Dict[str, Dict[str, Any]]:
        response_data = response.json()
        return response_data

    def launch_speculative_llm(
        self,
        model_name: str,
        model_size_in_billions: Optional[int],
        quantization: Optional[str],
        draft_model_name: str,
        draft_model_size_in_billions: Optional[int],
        draft_quantization: Optional[str],
        n_gpu: Optional[Union[int, str]] = "auto",
    ):
        """
        Launch the LLM along with a draft model based on the parameters on the server via RESTful APIs. This is an
        experimental feature and the API may change in the future.

        Returns
        -------
        str
            The unique model_uid for the launched model.
        """
        warnings.warn(
            "`launch_speculative_llm` is an experimental feature and the API may change in the future."
        )

        model_uid = self._gen_model_uid()

        payload = {
            "model_uid": model_uid,
            "model_name": model_name,
            "model_size_in_billions": model_size_in_billions,
            "quantization": quantization,
            "draft_model_name": draft_model_name,
            "draft_model_size_in_billions": draft_model_size_in_billions,
            "draft_quantization": draft_quantization,
            "n_gpu": n_gpu,
        }

        url = f"{self.base_url}/experimental/speculative_llms"
        response = requests.post(url, json=payload)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to launch model, detail: {response.json()['detail']}"
            )

        response_data = response.json()
        model_uid = response_data["model_uid"]
        return model_uid

    def launch_model(
        self,
        model_name: str,
7 changes: 6 additions & 1 deletion xinference/core/model.py
@@ -39,6 +39,7 @@

logger = logging.getLogger(__name__)

from .utils import log_async

T = TypeVar("T")

@@ -102,14 +103,15 @@ async def __pre_destroy__(self):
    def __init__(self, model: "LLM"):
        super().__init__()
        from ..model.llm.pytorch.core import PytorchModel
        from ..model.llm.pytorch.spec_model import SpeculativeModel
        from ..model.llm.vllm.core import VLLMModel

        self._model = model

        self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {}
        self._lock = (
            None
-           if isinstance(self._model, (PytorchModel, VLLMModel))
+           if isinstance(self._model, (PytorchModel, SpeculativeModel, VLLMModel))
            else asyncio.locks.Lock()
        )

@@ -141,6 +143,7 @@ async def _call_wrapper(self, _wrapper: Callable):
    async def _call_async_wrapper(self, _wrapper: Callable):
        return await asyncio.create_task(_wrapper())

    @log_async(logger=logger)
    async def generate(self, prompt: str, *args, **kwargs):
        if not hasattr(self._model, "generate") and not hasattr(
            self._model, "async_generate"
@@ -163,6 +166,7 @@ async def _async_wrapper():
        else:
            return await self._call_async_wrapper(_async_wrapper)

    @log_async(logger=logger)
    async def chat(self, prompt: str, *args, **kwargs):
        if not hasattr(self._model, "chat") and not hasattr(self._model, "async_chat"):
            raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
@@ -215,6 +219,7 @@ async def _wrapper():

        return await self._call_wrapper(_wrapper)

    @log_async(logger=logger)
    async def next(
        self, generator_uid: str
    ) -> Union["ChatCompletionChunk", "CompletionChunk"]:
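
The generate, chat, and next methods above gain a @log_async(logger=logger) decorator imported from .utils; its implementation is not included in this diff. A hypothetical sketch of a decorator with that shape (illustrative only, not the actual xinference.core.utils.log_async) could look like:

import functools
import logging
import time


def log_async(logger: logging.Logger):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Log entry, await the wrapped coroutine, then log elapsed time.
            logger.debug("Enter %s", func.__name__)
            start = time.time()
            result = await func(*args, **kwargs)
            logger.debug("Leave %s, elapsed: %.2fs", func.__name__, time.time() - start)
            return result

        return wrapper

    return decorator
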
48 changes: 47 additions & 1 deletion xinference/core/restful_api.py
@@ -75,7 +75,7 @@
)

repetition_penalty_field = Field(
-    default=1.1,
+    default=1.0,
    ge=0.0,
    description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
    + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
@@ -262,6 +262,11 @@ def serve(self):
            "/v1/models/{model_uid}", self.describe_model, methods=["GET"]
        )
        self._router.add_api_route("/v1/models", self.launch_model, methods=["POST"])
        self._router.add_api_route(
            "/experimental/speculative_llms",
            self.launch_speculative_llm,
            methods=["POST"],
        )
        self._router.add_api_route(
            "/v1/models/{model_uid}", self.terminate_model, methods=["DELETE"]
        )
@@ -385,6 +390,47 @@ async def describe_model(self, model_uid: str) -> Dict[str, Any]:
            logger.error(e, exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

    async def launch_speculative_llm(self, request: Request) -> JSONResponse:
        payload = await request.json()
        model_uid = payload.get("model_uid")
        model_name = payload.get("model_name")
        model_size_in_billions = payload.get("model_size_in_billions")
        quantization = payload.get("quantization")
        draft_model_name = payload.get("draft_model_name")
        draft_model_size_in_billions = payload.get("draft_model_size_in_billions")
        draft_quantization = payload.get("draft_quantization")
        n_gpu = payload.get("n_gpu", "auto")

        if model_uid is None or model_name is None:
            raise HTTPException(
                status_code=400,
                detail="Invalid input. Please specify the model UID and the model name",
            )

        try:
            model_uid = await self._supervisor_ref.launch_speculative_llm(
                model_uid=model_uid,
                model_name=model_name,
                model_size_in_billions=model_size_in_billions,
                quantization=quantization,
                draft_model_name=draft_model_name,
                draft_model_size_in_billions=draft_model_size_in_billions,
                draft_quantization=draft_quantization,
                n_gpu=n_gpu,
            )

        except ValueError as ve:
            logger.error(str(ve), exc_info=True)
            raise HTTPException(status_code=400, detail=str(ve))
        except RuntimeError as re:
            logger.error(str(re), exc_info=True)
            raise HTTPException(status_code=503, detail=str(re))
        except Exception as e:
            logger.error(str(e), exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

        return JSONResponse(content={"model_uid": model_uid})

    async def launch_model(self, request: Request) -> JSONResponse:
        payload = await request.json()
        model_uid = payload.get("model_uid")
58 changes: 58 additions & 0 deletions xinference/core/supervisor.py
@@ -201,6 +201,64 @@ async def unregister_model(self, model_type: str, model_name: str):
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

    async def launch_speculative_llm(
        self,
        model_uid: str,
        model_name: str,
        model_size_in_billions: Optional[int],
        quantization: Optional[str],
        draft_model_name: str,
        draft_model_size_in_billions: Optional[int],
        draft_quantization: Optional[str],
        n_gpu: Optional[Union[int, str]] = "auto",
    ) -> AsyncGenerator:
        logger.debug(
            (
                f"Enter launch_speculative_llm, model_uid: %s, model_name: %s, model_size: %s, "
                f"draft_model_name: %s, draft_model_size: %s"
            ),
            model_uid,
            model_name,
            str(model_size_in_billions) if model_size_in_billions else "",
            draft_model_name,
            draft_model_size_in_billions,
        )

        # TODO: the draft and target model must be on the same worker.
        if not self.is_local_deployment():
            raise ValueError(
                "Speculative model is not supported in distributed deployment yet."
            )

        if model_uid in self._model_uid_to_replica_info:
            raise ValueError(f"Model is already in the model list, uid: {model_uid}")

        worker_ref = await self._choose_worker()
        replica = 1
        self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
            replica=replica, scheduler=itertools.cycle(range(replica))
        )

        try:
            rep_model_uid = f"{model_uid}-{1}-{0}"
            yield worker_ref.launch_speculative_model(
                model_uid=rep_model_uid,
                model_name=model_name,
                model_size_in_billions=model_size_in_billions,
                quantization=quantization,
                draft_model_name=draft_model_name,
                draft_model_size_in_billions=draft_model_size_in_billions,
                draft_quantization=draft_quantization,
                n_gpu=n_gpu,
            )
            self._replica_model_uid_to_worker[rep_model_uid] = worker_ref

        except Exception:
            # terminate_model will remove the replica info.
            await self.terminate_model(model_uid, suppress_exception=True)
            raise
        raise xo.Return(model_uid)

    async def launch_builtin_model(
        self,
        model_uid: str,