
On Device Sampling #350

Draft pull request: quic-sanising wants to merge 27 commits into base: main from on-device-sampling.

Commits (27):
718d763
Initial commit
quic-sanising Apr 8, 2025
b8d099e
Reformat code
quic-sanising Apr 8, 2025
544c0dd
Fix bug
quic-sanising Apr 8, 2025
0b4d0a9
Add Gumbel-Max trick based random sampling
quic-sanising Apr 8, 2025
24efc93
Bring up to date
quic-sanising Apr 8, 2025
2af43c6
Use Gumbel-Max Trick based Random Sampling as default
quic-sanising Apr 8, 2025
3eca771
Clip k to max value
quic-sanising Apr 8, 2025
b0e9162
Add docstring for sampling parameters
quic-sanising Apr 8, 2025
0486e42
Fix bug
quic-sanising Apr 8, 2025
e7dda72
Add support for continuous batching
quic-sanising Apr 8, 2025
f94c657
Fix ONNX error for batch_size 1 treated as a Constant
quic-sanising Apr 8, 2025
fa026a4
Undo docstring deletion
quic-sanising Apr 8, 2025
eff2007
Remove device and unncessary reshapes
quic-sanising Apr 8, 2025
ebfbaea
Revert batch_size to 1
quic-sanising Apr 8, 2025
83d33ac
Remove vocab_size from dynamic axes
quic-sanising Apr 8, 2025
fc3dc82
Change condition
quic-sanising Apr 8, 2025
abbaf53
Change size of each sampling parameter to (batch_size, 1)
quic-sanising Apr 8, 2025
f5f5e2d
Reformat code
quic-sanising Apr 8, 2025
05c0bf0
Fix bug
quic-sanising Apr 8, 2025
3b63ecb
Allow chunked prompts during prefill
quic-sanising Apr 8, 2025
0b6873c
Merge remote-tracking branch 'upstream/main' into on-device-sampling
quic-sanising Apr 9, 2025
1691a08
Add missing params
quic-sanising Apr 9, 2025
02389f8
Update retain state names with past keyword
quic-sanising Apr 18, 2025
7dfdda4
Add output_names for sampler
quic-sanising Apr 18, 2025
d48d084
Optimizations (#2)
quic-sanising Apr 24, 2025
bf367a6
Merge branch 'main' into on-device-sampling
quic-sanising Apr 24, 2025
aa7206d
Fix bugs
quic-sanising Apr 24, 2025
96 changes: 86 additions & 10 deletions QEfficient/transformers/models/modeling_auto.py
@@ -38,6 +38,7 @@
CustomOpsTransform,
KVCacheModuleMethodMapperTransform,
KVCacheTransform,
SamplerTransform,
SpDTransform,
VlmKVOffloadTransform,
VlmNoKVOffloadTransform,
@@ -75,7 +76,7 @@ def __repr__(self) -> str:

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, *args, **kwargs):
Review comment (Contributor): Should we make them optional parameters?

if kwargs.get("attn_implementation", None) not in {None, "eager"}:
logger.warning('Updating attn_implementation="eager"')

@@ -85,7 +86,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = Fals
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
return cls(model, is_tlm=is_tlm)
return cls(model, is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs)

@property
def model_name(self) -> str:
@@ -1262,6 +1263,8 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
:model (nn.Module): PyTorch model
:continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.
:is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
:include_sampler (bool): Enable/Disable sampling of next tokens during decode.
:return_pdfs (bool): Return probability distributions (logits/probs) or sampled next tokens. If `is_tlm`=True, then `return_pdfs`=True always. If `is_tlm`=False, then `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model.


.. code-block:: python
@@ -1292,6 +1295,8 @@ def __init__(
model: nn.Module,
continuous_batching: bool = False,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
**kwargs,
):
model_class_name = model.__class__.__name__
@@ -1321,8 +1326,14 @@ def __init__(
if is_tlm:
# TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
self.model, transformed = SpDTransform.apply(self.model)
self.model.return_pdfs = True
Review comment (quic-hemagnih, Apr 23, 2025): Where is the code for handling the is_tlm == False condition for populating return_pdfs?

self.is_tlm = is_tlm

if include_sampler: # Sampling
self.model, transformed = SamplerTransform.apply(self.model)
self.model.return_pdfs = return_pdfs
self.include_sampler = include_sampler

@property
def model_name(self) -> str:
mname = self.model.__class__.__name__
@@ -1336,7 +1347,7 @@ def __repr__(self) -> str:
@classmethod
@with_replaced_quantizers
def from_pretrained(
cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs
cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, *args, **kwargs
):
"""
This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.
@@ -1347,6 +1358,8 @@
:pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
:continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.
:is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
:include_sampler (bool): Enable/Disable sampling of next tokens during decode.
:return_pdfs (bool): Return probability distributions (logits/probs) or sampled next tokens. If `is_tlm`=True, then `return_pdfs`=True always. If `is_tlm`=False, then `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model.
:args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM.

.. code-block:: python
@@ -1389,7 +1402,7 @@ def from_pretrained(
model, kv_offload=kv_offload
)

return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching)
return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching, include_sampler=include_sampler, return_pdfs=return_pdfs)
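Taken together, a minimal usage sketch of the new flags (the model card and compile settings below are placeholders, not values mandated by this PR):

from QEfficient import QEFFAutoModelForCausalLM

# Load with on-device sampling enabled; SamplerTransform is applied in __init__.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # assumption: any Llama-family checkpoint supported by SamplerTransform
    continuous_batching=False,
    include_sampler=True,   # sample next tokens on device during decode
    return_pdfs=False,      # return sampled next tokens rather than probability distributions
)
model.export()              # adds the sampling inputs/outputs shown in the export changes below
model.compile(prefill_seq_len=32, ctx_len=128)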

@property
def model_hash(self) -> str:
@@ -1398,6 +1411,7 @@ def model_hash(self) -> str:
mhash.update(to_hashable(self.model.config.to_diff_dict()))
mhash.update(to_hashable({"continuous_batching": self.continuous_batching}))
mhash.update(to_hashable({"is_tlm": self.is_tlm}))
mhash.update(to_hashable({"include_sampler": self.include_sampler}))
mhash.update(to_hashable(self._transform_names()))
mhash = mhash.hexdigest()[:16]
return mhash
@@ -1441,7 +1455,13 @@ def export(self, export_dir: Optional[str] = None) -> str:
0: "full_batch_size" if self.continuous_batching else "batch_size",
2: "ctx_len",
}
output_names = ["logits"]
output_names = []
if self.include_sampler:
if self.model.return_pdfs:
output_names.append("probs")
output_names.append("next_tokens")
else:
output_names.append("logits")

for i in range(self.num_layers):
for kv in ["key", "value"]:
@@ -1458,6 +1478,48 @@
example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1)
dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}

if self.include_sampler:
nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep
max_top_k_ids = constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS

example_inputs["last_accepted_output_tokens"] = torch.randint(low=0, high=self.model.config.vocab_size, size=(bs, nlk))
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "num_logits_to_keep"}

example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.bool)
dynamic_axes["past_repetition_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_repetition_penalty_buffer_RetainedState")

example_inputs["repetition_penalties"] = torch.ones((bs, 1), dtype=torch.float) * 0.5
dynamic_axes["repetition_penalties"] = {0: "batch_size"}

example_inputs["past_presence_penalty_buffer"] = torch.zeros(
fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.bool)
dynamic_axes["past_presence_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_presence_penalty_buffer_RetainedState")

example_inputs["presence_penalties"] = torch.zeros((bs, 1), dtype=torch.float) + 0.5
dynamic_axes["presence_penalties"] = {0: "batch_size"}

example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float)
dynamic_axes["temperatures"] = {0: "batch_size"}

example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
dynamic_axes["top_ks"] = {0: "batch_size"}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.80
dynamic_axes["top_ps"] = {0: "batch_size"}
Review comment (Contributor): Can we define constants for 0.80 and 0.99?


example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.99
dynamic_axes["min_ps"] = {0: "batch_size"}

example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

return self._export(
example_inputs,
output_names,
@@ -1472,12 +1534,14 @@ def build_prefill_specialization(
batch_size: int = 1,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
max_top_k_ids: Optional[int] = None,
):
spec = {
"batch_size": 1 if self.continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"num_logits_to_keep": 1 if self.is_tlm else None,
"num_logits_to_keep": 1 if self.is_tlm or self.include_sampler else None,
"max_top_k_ids": max_top_k_ids if self.include_sampler else None,
}
if self.continuous_batching:
spec["full_batch_size"] = kv_cache_batch_size
@@ -1495,14 +1559,16 @@ def build_decode_specialization(
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
max_top_k_ids: Optional[int] = None,
):
if prefill_seq_len == 1 and not self.continuous_batching:
return None # Avoid duplication with prefill
spec = {
"batch_size": full_batch_size if self.continuous_batching else batch_size,
"seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
"ctx_len": ctx_len,
"num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
"num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm or self.include_sampler else None,
"max_top_k_ids": max_top_k_ids if self.include_sampler else None,
}
if self.continuous_batching:
spec["full_batch_size"] = kv_cache_batch_size
Expand Down Expand Up @@ -1592,12 +1658,23 @@ def compile(
if prefill_only is None or prefill_only or prefill_seq_len == 1:
specializations.append(
self.build_prefill_specialization(
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
max_top_k_ids=constants.Constants.MAX_TOP_K_IDS if self.include_sampler else None,
)
)
if prefill_only is None or not prefill_only:
decode_spec = self.build_decode_specialization(
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
max_top_k_ids=constants.Constants.MAX_TOP_K_IDS if self.include_sampler else None,
)
if decode_spec:
specializations.append(decode_spec)
@@ -1626,7 +1703,6 @@ def compile(
mxint8_kv_cache=mxint8_kv_cache,
**compiler_options,
)

return qpc_path

# FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
32 changes: 31 additions & 1 deletion QEfficient/transformers/models/pytorch_transforms.py
@@ -266,9 +266,9 @@
QEffWhisperModel,
QEffWhisperPositionalEmbedding,
)
from QEfficient.transformers.sampler.sampler import sampler_forward
from QEfficient.transformers.spd.causal_lm_forward import tlm_forward


class CustomOpsTransform(ModuleMappingTransform):
_module_mapping = {
GemmaRMSNorm: GemmaCustomRMSNormAIC,
@@ -439,6 +439,36 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
return model, transformed


class SamplerTransform:
"""
``Mandatory`` Args:
:model (nn.Module): PyTorch model.

Returns:
:model (nn.Module): PyTorch model.
:transformed (bool): whether transformation was applied successfully.
"""

# supported architectures
_module_mapping = {
# Llama
QEffLlamaForCausalLM,
}

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
Review comment (Contributor): Please add a docstring.

if (model_class := model.__class__) in cls._module_mapping:
model.forward = MethodType(sampler_forward, model)
transformed = True
else:
raise NotImplementedError(
f"model class {model_class} does not yet support returning multiple logits to keep."
)

return model, transformed
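The sampler_forward bound here comes from QEfficient/transformers/sampler/sampler.py and is not part of this diff; per the commit history it uses Gumbel-Max-trick based random sampling by default. A minimal, self-contained sketch of that trick (illustrative only, without the PR's penalty, temperature, or top-k/top-p/min-p handling):

import torch

def gumbel_max_sample(probs: torch.Tensor) -> torch.Tensor:
    # Gumbel-Max trick: argmax(log p + g), with g ~ Gumbel(0, 1), is an exact draw
    # from the categorical distribution p, without calling torch.multinomial.
    uniform = torch.rand_like(probs).clamp_min(1e-20)  # U(0, 1) noise
    gumbel = -torch.log(-torch.log(uniform))           # Gumbel(0, 1) noise
    return torch.argmax(torch.log(probs.clamp_min(1e-20)) + gumbel, dim=-1, keepdim=True)

next_tokens = gumbel_max_sample(torch.softmax(torch.randn(2, 32000), dim=-1))  # shape (2, 1)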


class VlmKVOffloadTransform(ModuleMappingTransform):
# supported architectures
_module_mapping = {
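With the sampler compiled in, each decode step consumes the per-request sampling tensors introduced in the export changes above. A hedged sketch of what a caller might feed (input names and shapes follow the export code; the values are arbitrary, and the runtime plumbing itself is not part of this diff):

import numpy as np

batch_size = 1
sampling_inputs = {
    "last_accepted_output_tokens": np.zeros((batch_size, 1), dtype=np.int64),  # tokens accepted in the previous step
    "repetition_penalties": np.full((batch_size, 1), 1.2, dtype=np.float32),
    "presence_penalties": np.full((batch_size, 1), 0.0, dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 40, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.95, dtype=np.float32),
    "min_ps": np.full((batch_size, 1), 0.0, dtype=np.float32),
    "random_numbers": np.random.rand(batch_size, 1).astype(np.float32),
}
# Passed alongside the usual input_ids/position_ids; the graph then returns "next_tokens"
# (plus "probs" when return_pdfs=True) and the penalty-buffer RetainedState outputs.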