On Device Sampling #350

Status: Closed · wants to merge 58 commits

Commits (58), all authored by quic-sanising:
718d763  Initial commit (Apr 8, 2025)
b8d099e  Reformat code (Apr 8, 2025)
544c0dd  Fix bug (Apr 8, 2025)
0b4d0a9  Add Gumbel-Max trick based random sampling (Apr 8, 2025)
24efc93  Bring up to date (Apr 8, 2025)
2af43c6  Use Gumbel-Max Trick based Random Sampling as default (Apr 8, 2025)
3eca771  Clip k to max value (Apr 8, 2025)
b0e9162  Add docstring for sampling parameters (Apr 8, 2025)
0486e42  Fix bug (Apr 8, 2025)
e7dda72  Add support for continuous batching (Apr 8, 2025)
f94c657  Fix ONNX error for batch_size 1 treated as a Constant (Apr 8, 2025)
fa026a4  Undo docstring deletion (Apr 8, 2025)
eff2007  Remove device and unnecessary reshapes (Apr 8, 2025)
ebfbaea  Revert batch_size to 1 (Apr 8, 2025)
83d33ac  Remove vocab_size from dynamic axes (Apr 8, 2025)
fc3dc82  Change condition (Apr 8, 2025)
abbaf53  Change size of each sampling parameter to (batch_size, 1) (Apr 8, 2025)
f5f5e2d  Reformat code (Apr 8, 2025)
05c0bf0  Fix bug (Apr 8, 2025)
3b63ecb  Allow chunked prompts during prefill (Apr 8, 2025)
0b6873c  Merge remote-tracking branch 'upstream/main' into on-device-sampling (Apr 9, 2025)
1691a08  Add missing params (Apr 9, 2025)
02389f8  Update retain state names with past keyword (Apr 18, 2025)
7dfdda4  Add output_names for sampler (Apr 18, 2025)
d48d084  Optimizations (#2) (Apr 24, 2025)
bf367a6  Merge branch 'main' into on-device-sampling (Apr 24, 2025)
aa7206d  Fix bugs (Apr 24, 2025)
14eefb9  Add files via upload (Apr 25, 2025)
813a644  Handle invalid position_ids (Apr 25, 2025)
929f51f  Use CtxScatterFuncCB3D instead of scatter_ (May 2, 2025)
f40f8c0  Fix missing import (May 2, 2025)
f9f4ac9  Remove unsupported nonzero function (May 2, 2025)
19f8d49  Reformat code (May 2, 2025)
e328d8e  Reformat code (May 2, 2025)
9f2c061  Merge branch 'main' into on-device-sampling (May 8, 2025)
c59e1ab  Add include_sampler check (May 8, 2025)
40f176e  Merge branch 'main' into on-device-sampling (May 8, 2025)
52a077f  Update doc-strings (May 8, 2025)
0858e4d  Merge branch 'on-device-sampling' into optimizations2 (May 8, 2025)
49dc0f3  Mask invalid tokens with INT_MAX (May 8, 2025)
b5245de  Add predication logic for prefill and decode (May 20, 2025)
1a0d2ca  Extend last_accepted_output_tokens up to sequence_length (May 21, 2025)
6a4e970  Make generic sampler (May 23, 2025)
411900e  Reformat code (May 23, 2025)
de2c209  Add sampling constants (May 23, 2025)
bcdb2f0  Merge branch 'main' into optimizations2 (May 23, 2025)
ee850b2  Fix bug (May 23, 2025)
a58e9af  Move max_top_k_ids to qaic_config (May 27, 2025)
485330c  Create function to get sampling inputs and outputs for ONNX export (May 30, 2025)
0a079cd  Fix bug (Jun 3, 2025)
413195c  Fix scalar tensor error and revert batch_size to 1 (Jun 4, 2025)
f987bde  Merge branch 'main' into on-device-sampling (Jun 5, 2025)
d46f35f  Add Qualcomm signature and license (Jun 10, 2025)
c20a49d  Sort imports (Jun 10, 2025)
4923ba6  Update doc strings (Jun 11, 2025)
1f96a9c  Remove false check (Jun 11, 2025)
68313a3  Merge branch 'main' into on-device-sampling (Jun 11, 2025)
d87a99d  Run linter (Jun 11, 2025)
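
Several commits above refer to Gumbel-Max-trick-based random sampling. For background, here is a minimal, self-contained sketch of that trick in PyTorch; it is a generic illustration and not the on-device sampler implementation from this PR.

```python
# Generic illustration of the Gumbel-Max trick referenced in the commit
# messages above; not the sampler implementation from this PR.
import torch


def gumbel_max_sample(logits: torch.Tensor) -> torch.Tensor:
    """Draw one token id per row from softmax(logits) without computing softmax.

    Adding i.i.d. Gumbel(0, 1) noise to the logits and taking the argmax is
    distributionally equivalent to sampling from the categorical distribution
    defined by softmax(logits), avoiding an explicit softmax and multinomial step.
    """
    uniform = torch.rand_like(logits).clamp_min(1e-10)  # avoid log(0)
    gumbel_noise = -torch.log(-torch.log(uniform))
    return torch.argmax(logits + gumbel_noise, dim=-1)


logits = torch.randn(2, 32)              # (batch_size, vocab_size)
next_tokens = gumbel_max_sample(logits)  # shape: (2,)
```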
143 changes: 133 additions & 10 deletions QEfficient/transformers/models/modeling_auto.py
@@ -9,7 +9,7 @@
import warnings
from pathlib import Path
from time import perf_counter
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import numpy as np
import torch
@@ -38,6 +38,7 @@
CustomOpsTransform,
KVCacheModuleMethodMapperTransform,
KVCacheTransform,
SamplerTransform,
SpDTransform,
VlmKVOffloadTransform,
VlmNoKVOffloadTransform,
@@ -1281,8 +1282,17 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
``Mandatory`` Args:
:model (nn.Module): PyTorch model
:continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.
:qaic_config (Optional[dict]): Qaic config with supported keys of `speculative_model_type` to specify speculative decoding TLM models.

``Optional`` Args:
:qaic_config (dict): QAIC config dictionary with the following supported keys:
:speculative_model_type (str): To specify Speculative Decoding Target Language Models.
:include_sampler (bool): Enable/Disable sampling of next tokens.
:return_pdfs (bool): Return probability distributions along with sampled
next tokens. For Speculative Decoding Target Language Model,
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
Decoding Draft Language Model and `return_pdfs`=False for regular model.
:max_top_k_ids (int): Specify the maximum number of top K tokens
(<= vocab size) to consider during sampling. The values provided in
`top_ks` tensor must be less than this maximum limit.

.. code-block:: python

@@ -1336,10 +1346,20 @@ def __init__(
self.model.config.use_cache = True
self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching
self.model.qaic_config = qaic_config

self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
self.is_tlm = transformed
self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)

# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
# are done. The role of the sampler is to just add nodes at the output of the
# previous transform function.
self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs)
if self.is_tlm:
self.model.qaic_config["return_pdfs"] = True

@property
def model_name(self) -> str:
mname = self.model.__class__.__name__
@@ -1368,9 +1388,17 @@ def from_pretrained(
Args:
:pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
:continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.

:qaic_config (Optional[dict]): Qaic config with supported keys of `speculative_model_type` to specify speculative decoding TLM models.
:args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM.
``Optional`` Args:
:qaic_config (dict): QAIC config dictionary with the following supported keys:
:speculative_model_type (str): To specify Speculative Decoding Target Language Models.
:include_sampler (bool): Enable/Disable sampling of next tokens.
:return_pdfs (bool): Return probability distributions along with sampled
next tokens. For Speculative Decoding Target Language Model,
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
Decoding Draft Language Model and `return_pdfs`=False for regular model.
:max_top_k_ids (int): Specify the maximum number of top K tokens
(<= vocab size) to consider during sampling. The values provided in
`top_ks` tensor must be less than this maximum limit.

.. code-block:: python

@@ -1428,6 +1456,7 @@ def model_hash(self) -> str:
mhash.update(to_hashable(self.model.config.to_diff_dict()))
mhash.update(to_hashable({"continuous_batching": self.continuous_batching}))
mhash.update(to_hashable({"is_tlm": self.is_tlm}))
mhash.update(to_hashable({"qaic_config": self.model.qaic_config}))
mhash.update(to_hashable(self._transform_names()))
mhash.update(to_hashable(self.pretrained_model_name_or_path))
mhash = mhash.hexdigest()[:16]
@@ -1449,7 +1478,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
"""
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
fbs = constants.ONNX_EXPORT_EXAMPLE_FBS
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
kv_cache_shape = get_padding_shape_from_config(
self.model.config, fbs if self.continuous_batching else bs, seq_len
)
Expand All @@ -1472,7 +1501,13 @@ def export(self, export_dir: Optional[str] = None) -> str:
0: "full_batch_size" if self.continuous_batching else "batch_size",
2: "ctx_len",
}
output_names = ["logits"]
output_names = []
if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
if self.model.qaic_config.get("return_pdfs", False):
output_names.append("probs")
output_names.append("next_tokens")
else:
output_names.append("logits")

for i in range(self.num_layers):
for kv in ["key", "value"]:
@@ -1489,13 +1524,84 @@ def export(self, export_dir: Optional[str] = None) -> str:
example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1)
dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}

if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
example_inputs=example_inputs,
output_names=output_names,
dynamic_axes=dynamic_axes,
)

return self._export(
example_inputs,
output_names,
dynamic_axes,
export_dir=export_dir,
)

def get_sampling_inputs_and_outputs(
self,
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
Review comment (Contributor): Could you please add a docstring to enhance clarity for other developers? Additionally, could we relocate this to the sampler folder? This will help keep modeling_auto.py streamlined and avoid unnecessary complexity.

Reply (Contributor, Author): Added the docstring. This function cannot be relocated to the sampler module because it is a member function of QEFFAutoModelForCausalLM.

):
"""
Update the example inputs and outputs with respect to the On Device Sampler
for the ONNX export.
"""
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS

example_inputs["last_accepted_output_tokens"] = torch.zeros(
(bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
)
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}

example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_repetition_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_repetition_penalty_buffer_RetainedState")

example_inputs["repetition_penalties"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
)
dynamic_axes["repetition_penalties"] = {0: "batch_size"}

example_inputs["past_presence_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_presence_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_presence_penalty_buffer_RetainedState")

example_inputs["presence_penalties"] = (
torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
)
dynamic_axes["presence_penalties"] = {0: "batch_size"}

example_inputs["temperatures"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
)
dynamic_axes["temperatures"] = {0: "batch_size"}

max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
dynamic_axes["top_ks"] = {0: "batch_size"}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
dynamic_axes["top_ps"] = {0: "batch_size"}

example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
dynamic_axes["min_ps"] = {0: "batch_size"}

example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

return example_inputs, output_names, dynamic_axes
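
For intuition about how the inputs prepared above (penalty buffers, temperatures, top_ks) are typically consumed at decode time, a generic sketch follows. The helper name and the exact penalty formulas are illustrative assumptions, not this PR's sampler_forward.

```python
# Generic illustration of how the exported sampling parameters typically act
# on a batch of logits; not the sampler_forward implementation from this PR.
import torch


def apply_sampling_params(
    logits: torch.Tensor,                 # (batch_size, vocab_size)
    seen_mask: torch.Tensor,              # bool (batch_size, vocab_size), like past_repetition_penalty_buffer
    repetition_penalties: torch.Tensor,   # (batch_size, 1)
    presence_penalties: torch.Tensor,     # (batch_size, 1)
    temperatures: torch.Tensor,           # (batch_size, 1)
    top_ks: torch.Tensor,                 # int (batch_size, 1)
) -> torch.Tensor:
    # Repetition penalty: dampen logits of tokens that already appeared.
    penalized = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties)
    logits = torch.where(seen_mask, penalized, logits)

    # Presence penalty: flat subtraction for tokens that already appeared.
    logits = logits - presence_penalties * seen_mask.float()

    # Temperature scaling.
    logits = logits / temperatures

    # Top-k: for simplicity, use one k for the whole batch and mask everything
    # below the k-th largest logit in each row.
    k = int(top_ks.max().item())
    kth_largest = torch.topk(logits, k, dim=-1).values[:, -1:]
    return torch.where(logits < kth_largest, torch.full_like(logits, float("-inf")), logits)
```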

def build_prefill_specialization(
self,
prefill_seq_len: int = 32,
@@ -1608,6 +1714,14 @@ def compile(
"enable `continuous_batching=True` in `from_pretrained`."
)

if (
self.model.qaic_config is not None
and self.model.qaic_config.get("include_sampler", False)
and num_speculative_tokens is not None
and num_speculative_tokens > 0
):
raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.")

# Infer kv_cache_batch_size if not provided
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

@@ -1617,12 +1731,21 @@
if prefill_only is None or prefill_only or prefill_seq_len == 1:
specializations.append(
self.build_prefill_specialization(
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
)
)
if prefill_only is None or not prefill_only:
decode_spec = self.build_decode_specialization(
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
)
if decode_spec:
specializations.append(decode_spec)
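
To summarize the API surface this file change adds, here is a hedged end-to-end sketch of how the new qaic_config keys might be used. The import path, model card name, and compile settings are illustrative assumptions rather than values taken from this PR.

```python
# Illustrative usage of the qaic_config keys documented above; the import path,
# model card, and compile settings are assumptions, not values from this PR.
from QEfficient import QEFFAutoModelForCausalLM

qaic_config = {
    "include_sampler": True,  # attach on-device sampling nodes via SamplerTransform
    "return_pdfs": False,     # regular model: return only the sampled next tokens
    "max_top_k_ids": 512,     # upper bound for values supplied in the top_ks tensor
}

model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # hypothetical model card
    continuous_batching=True,
    qaic_config=qaic_config,
)

# Export adds the sampler inputs/outputs (top_ks, temperatures, next_tokens, ...)
# built by get_sampling_inputs_and_outputs; compile then builds the program.
model.export()
model.compile(prefill_seq_len=128, ctx_len=2048, full_batch_size=4)
```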
38 changes: 38 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -267,6 +267,7 @@
QEffWhisperPositionalEmbedding,
)
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
from QEfficient.transformers.sampler.sampler import sampler_forward
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward

SPD_TARGET = "target"
Expand Down Expand Up @@ -456,6 +457,43 @@ def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -
return model, transformed


class SamplerTransform:
"""
Add nodes at the output of any generic QEffForCausalLM model to enable the
sampling of next tokens on the device (instead of the host) and return the
next tokens and/or probability distributions.

Note: To achieve this, the generic QEffForCausalLM model must provide the
logits as output.

``Mandatory`` Args:
:model (nn.Module): PyTorch model.

Returns:
:model (nn.Module): PyTorch model.
:transformed (bool): whether transformation was applied successfully.
"""

# supported architectures
_module_mapping = {
# Llama
QEffLlamaForCausalLM,
}

@classmethod
def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
transformed = False
if qaic_config is None or not qaic_config.get("include_sampler", False):
return model, transformed
elif (model_class := model.__class__) in cls._module_mapping:
model.old_forward = model.forward
model.forward = MethodType(sampler_forward, model)
transformed = True
else:
raise NotImplementedError(f"Model class {model_class} does not support on device sampling.")
return model, transformed


class VlmKVOffloadTransform(ModuleMappingTransform):
# supported architectures
_module_mapping = {
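
The SamplerTransform above rebinds the model's forward with types.MethodType while keeping the original bound method reachable as old_forward. A self-contained sketch of that rebinding pattern, using a toy module and a stand-in sampling forward (both hypothetical, not QEfficient code), looks like this:

```python
# Self-contained illustration of the forward-swap pattern used by
# SamplerTransform. ToyLM and sampling_forward are stand-ins, not QEfficient code.
from types import MethodType

import torch
import torch.nn as nn


class ToyLM(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.proj = nn.Linear(hidden, vocab_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Original forward: returns raw logits.
        return self.proj(hidden_states)


def sampling_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Reuse the original forward, then add a (greedy) sampling step on top,
    # mirroring how a sampler forward builds on the base model's logits.
    logits = self.old_forward(hidden_states)
    return logits.argmax(dim=-1)


model = ToyLM()
model.old_forward = model.forward                      # keep a handle to the original
model.forward = MethodType(sampling_forward, model)    # rebind forward to the sampler

tokens = model(torch.randn(2, 4, 8))                   # now returns sampled token ids
print(tokens.shape)                                    # torch.Size([2, 4])
```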
6 changes: 6 additions & 0 deletions QEfficient/transformers/sampler/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------