On Device Sampling #350
@@ -9,7 +9,7 @@
 import warnings
 from pathlib import Path
 from time import perf_counter
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 import numpy as np
 import torch

@@ -38,6 +38,7 @@
     CustomOpsTransform,
     KVCacheModuleMethodMapperTransform,
     KVCacheTransform,
+    SamplerTransform,
     SpDTransform,
     VlmKVOffloadTransform,
     VlmNoKVOffloadTransform,
@@ -1281,8 +1282,17 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
     ``Mandatory`` Args:
         :model (nn.Module): PyTorch model
         :continuous_batching (bool): Whether this model will be used for continuous batching in the future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
-        :qaic_config (Optional[dict]): QAIC config with the supported key `speculative_model_type` to specify speculative decoding TLM models.

+    ``Optional`` Args:
+        :qaic_config (dict): QAIC config dictionary with the following supported keys:
+            :speculative_model_type (str): To specify Speculative Decoding Target Language Models.
+            :include_sampler (bool): Enable/disable sampling of next tokens.
+            :return_pdfs (bool): Return probability distributions along with sampled
+                next tokens. For a Speculative Decoding Target Language Model,
+                `return_pdfs`=True always. Otherwise, `return_pdfs`=True for a Speculative
+                Decoding Draft Language Model and `return_pdfs`=False for a regular model.
+            :max_top_k_ids (int): Specify the maximum number of top K tokens
+                (<= vocab size) to consider during sampling. The values provided in the
+                `top_ks` tensor must be less than this maximum limit.

     .. code-block:: python
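To make the new keys concrete, here is a minimal sketch of a `qaic_config` that enables the on-device sampler; the values are illustrative, not defaults from this PR:

```python
# Hypothetical config for illustration; key names follow the docstring above.
qaic_config = {
    "include_sampler": True,   # add sampler nodes after the existing transforms
    "return_pdfs": False,      # regular model: return only sampled next tokens
    "max_top_k_ids": 512,      # upper bound for any value passed in the `top_ks` tensor
}
```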
@@ -1336,10 +1346,20 @@ def __init__(
         self.model.config.use_cache = True
         self.num_layers = model.config.num_hidden_layers
         self.continuous_batching = continuous_batching
+        self.model.qaic_config = qaic_config

         self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
         self.is_tlm = transformed
         self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)

+        # ---Sampling---
+        # Note: SamplerTransform should be applied after all other transforms
+        # are done. The role of the sampler is to just add nodes at the output of the
+        # previous transform function.
+        self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs)
+        if self.is_tlm:
+            self.model.qaic_config["return_pdfs"] = True

     @property
     def model_name(self) -> str:
         mname = self.model.__class__.__name__
@@ -1368,9 +1388,17 @@ def from_pretrained(
     Args:
         :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
         :continuous_batching (bool): Whether this model will be used for continuous batching in the future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.

-        :qaic_config (Optional[dict]): QAIC config with the supported key `speculative_model_type` to specify speculative decoding TLM models.
         :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM.
+    ``Optional`` Args:
+        :qaic_config (dict): QAIC config dictionary with the following supported keys:
+            :speculative_model_type (str): To specify Speculative Decoding Target Language Models.
+            :include_sampler (bool): Enable/disable sampling of next tokens.
+            :return_pdfs (bool): Return probability distributions along with sampled
+                next tokens. For a Speculative Decoding Target Language Model,
+                `return_pdfs`=True always. Otherwise, `return_pdfs`=True for a Speculative
+                Decoding Draft Language Model and `return_pdfs`=False for a regular model.
+            :max_top_k_ids (int): Specify the maximum number of top K tokens
+                (<= vocab size) to consider during sampling. The values provided in the
+                `top_ks` tensor must be less than this maximum limit.

     .. code-block:: python
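As a usage sketch (the model card and parameter values are illustrative; the import path assumes the QEfficient package layout used elsewhere in this repository):

```python
from QEfficient import QEFFAutoModelForCausalLM

# Load a causal LM with on-device sampling enabled at load time.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "gpt2",                      # hypothetical model card
    continuous_batching=True,
    qaic_config={
        "include_sampler": True,
        "return_pdfs": False,
        "max_top_k_ids": 512,
    },
)
```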
@@ -1428,6 +1456,7 @@ def model_hash(self) -> str:
         mhash.update(to_hashable(self.model.config.to_diff_dict()))
         mhash.update(to_hashable({"continuous_batching": self.continuous_batching}))
         mhash.update(to_hashable({"is_tlm": self.is_tlm}))
+        mhash.update(to_hashable({"qaic_config": self.model.qaic_config}))
         mhash.update(to_hashable(self._transform_names()))
         mhash.update(to_hashable(self.pretrained_model_name_or_path))
         mhash = mhash.hexdigest()[:16]
@@ -1449,7 +1478,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
         """
         bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
         seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
-        fbs = constants.ONNX_EXPORT_EXAMPLE_FBS
+        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
         kv_cache_shape = get_padding_shape_from_config(
             self.model.config, fbs if self.continuous_batching else bs, seq_len
         )
@@ -1472,7 +1501,13 @@ def export(self, export_dir: Optional[str] = None) -> str:
             0: "full_batch_size" if self.continuous_batching else "batch_size",
             2: "ctx_len",
         }
-        output_names = ["logits"]
+        output_names = []
+        if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
+            if self.model.qaic_config.get("return_pdfs", False):
+                output_names.append("probs")
+            output_names.append("next_tokens")
+        else:
+            output_names.append("logits")

         for i in range(self.num_layers):
             for kv in ["key", "value"]:
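The resulting ONNX output names depend on the config. A small standalone sketch of the same selection logic, for clarity:

```python
def select_output_names(qaic_config) -> list:
    # Mirrors the export logic above: with the sampler included, the graph
    # emits sampled tokens (and optionally probability distributions)
    # instead of raw logits.
    output_names = []
    if qaic_config is not None and qaic_config.get("include_sampler", False):
        if qaic_config.get("return_pdfs", False):
            output_names.append("probs")
        output_names.append("next_tokens")
    else:
        output_names.append("logits")
    return output_names

assert select_output_names(None) == ["logits"]
assert select_output_names({"include_sampler": True}) == ["next_tokens"]
assert select_output_names({"include_sampler": True, "return_pdfs": True}) == ["probs", "next_tokens"]
```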
@@ -1489,13 +1524,84 @@ def export(self, export_dir: Optional[str] = None) -> str:
             example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1)
             dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}

+        if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
+            example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
+                example_inputs=example_inputs,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+            )

         return self._export(
             example_inputs,
             output_names,
             dynamic_axes,
             export_dir=export_dir,
         )

+    def get_sampling_inputs_and_outputs(
+        self,
+        example_inputs: Dict[str, torch.Tensor],
+        output_names: List[str],
+        dynamic_axes: Dict[str, Dict[int, str]],
+    ):

Reviewer: Could you please add a docstring to enhance clarity for other developers? Additionally, could we relocate this to the sampler folder? This will help keep modeling_auto.py streamlined and avoid unnecessary complexity.

Author: Added the docstring. Cannot relocate this function to the sampler folder, as it is a member function of QEFFAutoModelForCausalLM.
""" | ||
Update the example inputs and outputs with respect to the On Device Sampler | ||
for the ONNX export. | ||
""" | ||
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE | ||
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS | ||
|
||
example_inputs["last_accepted_output_tokens"] = torch.zeros( | ||
(bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 | ||
) | ||
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} | ||
|
||
example_inputs["past_repetition_penalty_buffer"] = torch.zeros( | ||
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool | ||
) | ||
dynamic_axes["past_repetition_penalty_buffer"] = { | ||
0: "full_batch_size" if self.continuous_batching else "batch_size", | ||
} | ||
output_names.append("past_repetition_penalty_buffer_RetainedState") | ||
|
||
example_inputs["repetition_penalties"] = ( | ||
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES | ||
) | ||
dynamic_axes["repetition_penalties"] = {0: "batch_size"} | ||
|
||
example_inputs["past_presence_penalty_buffer"] = torch.zeros( | ||
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool | ||
) | ||
dynamic_axes["past_presence_penalty_buffer"] = { | ||
0: "full_batch_size" if self.continuous_batching else "batch_size", | ||
} | ||
output_names.append("past_presence_penalty_buffer_RetainedState") | ||
|
||
example_inputs["presence_penalties"] = ( | ||
torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES | ||
) | ||
dynamic_axes["presence_penalties"] = {0: "batch_size"} | ||
|
||
example_inputs["temperatures"] = ( | ||
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES | ||
) | ||
dynamic_axes["temperatures"] = {0: "batch_size"} | ||
|
||
max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) | ||
example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) | ||
dynamic_axes["top_ks"] = {0: "batch_size"} | ||
|
||
example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS | ||
dynamic_axes["top_ps"] = {0: "batch_size"} | ||
|
||
example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS | ||
dynamic_axes["min_ps"] = {0: "batch_size"} | ||
|
||
example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) | ||
dynamic_axes["random_numbers"] = {0: "batch_size"} | ||
|
||
return example_inputs, output_names, dynamic_axes | ||
|
||
def build_prefill_specialization( | ||
self, | ||
prefill_seq_len: int = 32, | ||
|
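For reference, a host-side caller would feed tensors with these names and shapes at inference time. A hedged sketch (shapes follow the example inputs above; the values are arbitrary):

```python
import torch

bs = 2  # decode batch size; one row of sampling parameters per sequence

sampling_inputs = {
    "repetition_penalties": torch.full((bs, 1), 1.1),        # >1.0 penalizes repeats
    "presence_penalties": torch.zeros((bs, 1)),              # 0.0 disables the penalty
    "temperatures": torch.full((bs, 1), 0.8),
    "top_ks": torch.tensor([[40], [1]], dtype=torch.int32),  # must be < max_top_k_ids
    "top_ps": torch.full((bs, 1), 0.95),
    "min_ps": torch.zeros((bs, 1)),
    "random_numbers": torch.rand((bs, 1)),                   # per-row RNG draws
}
```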
@@ -1608,6 +1714,14 @@ def compile(
                 "enable `continuous_batching=True` in `from_pretrained`."
             )

+        if (
+            self.model.qaic_config is not None
+            and self.model.qaic_config.get("include_sampler", False)
+            and num_speculative_tokens is not None
+            and num_speculative_tokens > 0
+        ):
+            raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.")

         # Infer kv_cache_batch_size if not provided
         kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
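Given this guard, combining the sampler with speculative decoding fails fast. A hypothetical call sketch:

```python
# Assuming `model` was created with qaic_config={"include_sampler": True, ...}:
try:
    model.compile(num_speculative_tokens=2)
except ValueError as err:
    print(err)  # Currently, sampler does not support `num_speculative_tokens` > 0.
```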
@@ -1617,12 +1731,21 @@ def compile(
         if prefill_only is None or prefill_only or prefill_seq_len == 1:
             specializations.append(
                 self.build_prefill_specialization(
-                    prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
+                    prefill_seq_len=prefill_seq_len,
+                    ctx_len=ctx_len,
+                    batch_size=batch_size,
+                    kv_cache_batch_size=kv_cache_batch_size,
+                    full_batch_size=full_batch_size,
                 )
             )
         if prefill_only is None or not prefill_only:
             decode_spec = self.build_decode_specialization(
-                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+                prefill_seq_len=prefill_seq_len,
+                ctx_len=ctx_len,
+                batch_size=batch_size,
+                kv_cache_batch_size=kv_cache_batch_size,
+                full_batch_size=full_batch_size,
+                num_speculative_tokens=num_speculative_tokens,
             )
             if decode_spec:
                 specializations.append(decode_spec)
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------