On Device Sampling #350
base: main
Changes from all commits
718d763
b8d099e
544c0dd
0b4d0a9
24efc93
2af43c6
3eca771
b0e9162
0486e42
e7dda72
f94c657
fa026a4
eff2007
ebfbaea
83d33ac
fc3dc82
abbaf53
f5f5e2d
05c0bf0
3b63ecb
0b6873c
1691a08
02389f8
7dfdda4
d48d084
bf367a6
aa7206d
@@ -38,6 +38,7 @@
    CustomOpsTransform,
    KVCacheModuleMethodMapperTransform,
    KVCacheTransform,
+   SamplerTransform,
    SpDTransform,
    VlmKVOffloadTransform,
    VlmNoKVOffloadTransform,

@@ -75,7 +76,7 @@ def __repr__(self) -> str:

    @classmethod
    @with_replaced_quantizers
-   def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):
+   def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, *args, **kwargs):
        if kwargs.get("attn_implementation", None) not in {None, "eager"}:
            logger.warning('Updating attn_implementation="eager"')

@@ -85,7 +86,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = Fals

        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

        model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-       return cls(model, is_tlm=is_tlm)
+       return cls(model, is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs)

    @property
    def model_name(self) -> str:

@@ -1262,6 +1263,8 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
        :model (nn.Module): PyTorch model
        :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
        :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, a `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
+       :include_sampler (bool): Enable/Disable sampling of next tokens during decode.
+       :return_pdfs (bool): Return probability distributions (logits/probs) or sampled next tokens. If `is_tlm`=True, then `return_pdfs`=True always. If `is_tlm`=False, then `return_pdfs`=True for a Speculative Decoding Draft Language Model and `return_pdfs`=False for a regular model.

    .. code-block:: python

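For orientation, a minimal usage sketch of the new arguments through the public entry point. The model card name and flag values below are placeholders for illustration, not defaults introduced by this PR.

```python
from QEfficient import QEFFAutoModelForCausalLM

# Load a causal LM with the on-device sampler wired in; return sampled
# next tokens instead of full probability distributions.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # placeholder model card
    include_sampler=True,
    return_pdfs=False,
)
```
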
@@ -1292,6 +1295,8 @@ def __init__(
        model: nn.Module,
        continuous_batching: bool = False,
        is_tlm: bool = False,
+       include_sampler: bool = False,
+       return_pdfs: bool = False,
        **kwargs,
    ):
        model_class_name = model.__class__.__name__

@@ -1321,8 +1326,14 @@ def __init__(
        if is_tlm:
            # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
            self.model, transformed = SpDTransform.apply(self.model)
+           self.model.return_pdfs = True

Review comment: Where is the code that handles the `is_tlm == False` condition when populating `return_pdfs`?

        self.is_tlm = is_tlm

+       if include_sampler:  # Sampling
+           self.model, transformed = SamplerTransform.apply(self.model)
+           self.model.return_pdfs = return_pdfs
+       self.include_sampler = include_sampler

    @property
    def model_name(self) -> str:
        mname = self.model.__class__.__name__

@@ -1336,7 +1347,7 @@ def __repr__(self) -> str:
    @classmethod
    @with_replaced_quantizers
    def from_pretrained(
-       cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs
+       cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, *args, **kwargs
    ):
        """
        This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.

@@ -1347,6 +1358,8 @@ def from_pretrained(
        :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
        :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
        :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, a `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
+       :include_sampler (bool): Enable/Disable sampling of next tokens during decode.
+       :return_pdfs (bool): Return probability distributions (logits/probs) or sampled next tokens. If `is_tlm`=True, then `return_pdfs`=True always. If `is_tlm`=False, then `return_pdfs`=True for a Speculative Decoding Draft Language Model and `return_pdfs`=False for a regular model.
        :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM.

    .. code-block:: python

@@ -1389,7 +1402,7 @@ def from_pretrained(
                model, kv_offload=kv_offload
            )

-       return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching)
+       return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching, include_sampler=include_sampler, return_pdfs=return_pdfs)

    @property
    def model_hash(self) -> str:

@@ -1398,6 +1411,7 @@ def model_hash(self) -> str:
        mhash.update(to_hashable(self.model.config.to_diff_dict()))
        mhash.update(to_hashable({"continuous_batching": self.continuous_batching}))
        mhash.update(to_hashable({"is_tlm": self.is_tlm}))
+       mhash.update(to_hashable({"include_sampler": self.include_sampler}))
        mhash.update(to_hashable(self._transform_names()))
        mhash = mhash.hexdigest()[:16]
        return mhash

@@ -1441,7 +1455,13 @@ def export(self, export_dir: Optional[str] = None) -> str:
            0: "full_batch_size" if self.continuous_batching else "batch_size",
            2: "ctx_len",
        }
-       output_names = ["logits"]
+       output_names = []
+       if self.include_sampler:
+           if self.model.return_pdfs:
+               output_names.append("probs")
+           output_names.append("next_tokens")
+       else:
+           output_names.append("logits")

        for i in range(self.num_layers):
            for kv in ["key", "value"]:

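To spell out what the new branch produces, here is a standalone restatement of the output-name selection (the KV-cache `*_RetainedState` outputs added by the loop above are omitted). This is a sketch mirroring the branch, not code from the PR.

```python
def graph_output_names(include_sampler: bool, return_pdfs: bool) -> list:
    """Mirror of the export-time branch above; KV-cache retained-state outputs omitted."""
    names = []
    if include_sampler:
        if return_pdfs:
            names.append("probs")
        names.append("next_tokens")
    else:
        names.append("logits")
    return names


assert graph_output_names(False, False) == ["logits"]               # plain model
assert graph_output_names(True, False) == ["next_tokens"]           # sampler, tokens only
assert graph_output_names(True, True) == ["probs", "next_tokens"]   # sampler + pdfs (e.g. draft LM)
```
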
@@ -1458,6 +1478,48 @@ def export(self, export_dir: Optional[str] = None) -> str:
            example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1)
            dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}

+       if self.include_sampler:
+           nlk = constants.ONNX_EXPORT_EXAMPLE_NLK  # Number of Logits to Keep
+           max_top_k_ids = constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS
+
+           example_inputs["last_accepted_output_tokens"] = torch.randint(low=0, high=self.model.config.vocab_size, size=(bs, nlk))
+           dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "num_logits_to_keep"}
+
+           example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
+               fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.bool)
+           dynamic_axes["past_repetition_penalty_buffer"] = {
+               0: "full_batch_size" if self.continuous_batching else "batch_size",
+           }
+           output_names.append("past_repetition_penalty_buffer_RetainedState")
+
+           example_inputs["repetition_penalties"] = torch.ones((bs, 1), dtype=torch.float) * 0.5
+           dynamic_axes["repetition_penalties"] = {0: "batch_size"}
+
+           example_inputs["past_presence_penalty_buffer"] = torch.zeros(
+               fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.bool)
+           dynamic_axes["past_presence_penalty_buffer"] = {
+               0: "full_batch_size" if self.continuous_batching else "batch_size",
+           }
+           output_names.append("past_presence_penalty_buffer_RetainedState")
+
+           example_inputs["presence_penalties"] = torch.zeros((bs, 1), dtype=torch.float) + 0.5
+           dynamic_axes["presence_penalties"] = {0: "batch_size"}
+
+           example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float)
+           dynamic_axes["temperatures"] = {0: "batch_size"}
+
+           example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
+           dynamic_axes["top_ks"] = {0: "batch_size"}
+
+           example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.80
+           dynamic_axes["top_ps"] = {0: "batch_size"}

Review comment: Can we define constants for 0.80 and 0.99?

+           example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.99
+           dynamic_axes["min_ps"] = {0: "batch_size"}
+
+           example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
+           dynamic_axes["random_numbers"] = {0: "batch_size"}

        return self._export(
            example_inputs,
            output_names,

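The block above only registers example inputs and dynamic axes for ONNX export; at run time a caller has to supply tensors with the same names and per-row shapes. A hedged sketch of what one batch of sampling parameters could look like — the batch size, vocabulary size, and all values are illustrative, and the actual feeding path into the compiled QPC is outside this diff.

```python
import torch

bs, vocab_size = 2, 32000  # illustrative; vocab_size comes from the model config

sampling_inputs = {
    "last_accepted_output_tokens": torch.zeros((bs, 1), dtype=torch.long),
    "past_repetition_penalty_buffer": torch.zeros((bs, vocab_size), dtype=torch.bool),
    "repetition_penalties": torch.full((bs, 1), 1.2),
    "past_presence_penalty_buffer": torch.zeros((bs, vocab_size), dtype=torch.bool),
    "presence_penalties": torch.full((bs, 1), 0.3),
    "temperatures": torch.full((bs, 1), 0.7),
    "top_ks": torch.full((bs, 1), 40, dtype=torch.int32),
    "top_ps": torch.full((bs, 1), 0.9),
    "min_ps": torch.full((bs, 1), 0.05),
    "random_numbers": torch.rand((bs, 1)),
}
```
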
@@ -1472,12 +1534,14 @@ def build_prefill_specialization(
        batch_size: int = 1,
        kv_cache_batch_size: Optional[int] = None,
        full_batch_size: Optional[int] = None,
+       max_top_k_ids: Optional[int] = None,
    ):
        spec = {
            "batch_size": 1 if self.continuous_batching else batch_size,
            "seq_len": prefill_seq_len,
            "ctx_len": ctx_len,
-           "num_logits_to_keep": 1 if self.is_tlm else None,
+           "num_logits_to_keep": 1 if self.is_tlm or self.include_sampler else None,
+           "max_top_k_ids": max_top_k_ids if self.include_sampler else None,
        }
        if self.continuous_batching:
            spec["full_batch_size"] = kv_cache_batch_size

@@ -1495,14 +1559,16 @@ def build_decode_specialization(
        kv_cache_batch_size: Optional[int] = None,
        full_batch_size: Optional[int] = None,
        num_speculative_tokens: Optional[int] = None,
+       max_top_k_ids: Optional[int] = None,
    ):
        if prefill_seq_len == 1 and not self.continuous_batching:
            return None  # Avoid duplication with prefill
        spec = {
            "batch_size": full_batch_size if self.continuous_batching else batch_size,
            "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
            "ctx_len": ctx_len,
-           "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
+           "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm or self.include_sampler else None,
+           "max_top_k_ids": max_top_k_ids if self.include_sampler else None,
        }
        if self.continuous_batching:
            spec["full_batch_size"] = kv_cache_batch_size

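To make the effect on compilation concrete, here is a sketch of the prefill specialization for a plain decoder (`is_tlm=False`) compiled with `include_sampler=True`, assuming `prefill_seq_len=32`, `ctx_len=128`, `batch_size=1`, and no continuous batching; the `MAX_TOP_K_IDS` value is a placeholder, since `constants.Constants.MAX_TOP_K_IDS` is not shown in this diff. The decode specialization gains the same `max_top_k_ids` entry and additionally depends on `num_speculative_tokens`.

```python
MAX_TOP_K_IDS = 512  # placeholder for constants.Constants.MAX_TOP_K_IDS

prefill_spec = {
    "batch_size": 1,
    "seq_len": 32,
    "ctx_len": 128,
    "num_logits_to_keep": 1,        # now set because include_sampler=True, even though is_tlm=False
    "max_top_k_ids": MAX_TOP_K_IDS,
}
```
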
@@ -1592,12 +1658,23 @@ def compile(
        if prefill_only is None or prefill_only or prefill_seq_len == 1:
            specializations.append(
                self.build_prefill_specialization(
-                   prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
+                   prefill_seq_len=prefill_seq_len,
+                   ctx_len=ctx_len,
+                   batch_size=batch_size,
+                   kv_cache_batch_size=kv_cache_batch_size,
+                   full_batch_size=full_batch_size,
+                   max_top_k_ids=constants.Constants.MAX_TOP_K_IDS if self.include_sampler else None,
                )
            )
        if prefill_only is None or not prefill_only:
            decode_spec = self.build_decode_specialization(
-               prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+               prefill_seq_len=prefill_seq_len,
+               ctx_len=ctx_len,
+               batch_size=batch_size,
+               kv_cache_batch_size=kv_cache_batch_size,
+               full_batch_size=full_batch_size,
+               num_speculative_tokens=num_speculative_tokens,
+               max_top_k_ids=constants.Constants.MAX_TOP_K_IDS if self.include_sampler else None,
            )
            if decode_spec:
                specializations.append(decode_spec)

@@ -1626,7 +1703,6 @@ def compile(
            mxint8_kv_cache=mxint8_kv_cache,
            **compiler_options,
        )
-
        return qpc_path

    # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate

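Continuing the earlier loading sketch, compilation itself should not need any new arguments: the sampler-related specialization entries are filled in automatically when the model was constructed with `include_sampler=True`. The compile arguments and values below are the pre-existing QEfficient ones and are illustrative only.

```python
# `model` was created with include_sampler=True in the earlier sketch.
qpc_path = model.compile(
    prefill_seq_len=32,
    ctx_len=128,
    num_cores=16,
)
```
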
@@ -266,9 +266,9 @@
    QEffWhisperModel,
    QEffWhisperPositionalEmbedding,
)
+from QEfficient.transformers.sampler.sampler import sampler_forward
from QEfficient.transformers.spd.causal_lm_forward import tlm_forward


class CustomOpsTransform(ModuleMappingTransform):
    _module_mapping = {
        GemmaRMSNorm: GemmaCustomRMSNormAIC,

@@ -439,6 +439,36 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        return model, transformed


+class SamplerTransform:
+    """
+    ``Mandatory`` Args:
+        :model (nn.Module): PyTorch model.
+
+    Returns:
+        :model (nn.Module): PyTorch model.
+        :transformed (bool): whether transformation was applied successfully.
+    """
+
+    # supported architectures
+    _module_mapping = {
+        # Llama
+        QEffLlamaForCausalLM,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        transformed = False

Review comment: Please add a docstring.

+        if (model_class := model.__class__) in cls._module_mapping:
+            model.forward = MethodType(sampler_forward, model)
+            transformed = True
+        else:
+            raise NotImplementedError(
+                f"model class {model_class} does not yet support returning multiple logits to keep."
+            )
+
+        return model, transformed


class VlmKVOffloadTransform(ModuleMappingTransform):
    # supported architectures
    _module_mapping = {

Review comment: Should we make them optional parameters?
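
For readers unfamiliar with the binding trick `SamplerTransform` relies on: the transform does not subclass or wrap the model, it rebinds `forward` to `sampler_forward` via `types.MethodType`. A toy illustration of that pattern follows; the class and the greedy sampling here are made up for the example and are not taken from the PR.

```python
from types import MethodType

import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Stand-in for a supported *ForCausalLM class."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(8, 16)

    def forward(self, x):
        return self.head(x)  # original path: return logits


def sampling_forward(self, x):
    """Stand-in for sampler_forward: return sampled token ids instead of logits."""
    logits = self.head(x)
    return torch.argmax(logits, dim=-1)  # greedy pick, just for illustration


model = TinyLM()
model.forward = MethodType(sampling_forward, model)  # same rebinding as SamplerTransform.apply
tokens = model(torch.randn(2, 8))  # now returns token ids of shape (2,), not logits
```
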