Commit 5bcedfe

QNN Compilation path Support in QEFFBaseModel class.
Signed-off-by: Shubham Agrawal <[email protected]>
1 parent fc89e8b

11 files changed: +319 −437 lines

QEfficient/base/modeling_qeff.py (+36 −17)
@@ -98,7 +98,12 @@ def compile(self, *args, **kwargs) -> Path:
         :num_cores (int): Number of cores to utilize in each device ``Defaults to 16``.
         :mxfp6_matmul (bool): Use MXFP6 to compress weights for MatMul nodes to run faster on device. ``Defaults to False``.
         :mxint8_kv_cache (bool): Use MXINT8 to compress KV-cache on device to access and update KV-cache faster. ``Defaults to False``.
-        :compiler_options: Pass any compiler option as input. Any flag that is supported by ``qaic-exec`` can be passed. Params are converted to flags as below:
+        :compiler_options: Pass any compiler option as input.
+            The following flags can be passed in compiler_options to enable the QNN compilation path:
+            :enable_qnn (bool): Enables QNN compilation. ``Defaults to False if not passed.``
+            :qnn_config (str): Path of the QNN config parameters file. ``Defaults to None if not passed.``
+            Any other parameter passed is ignored in the QNN compilation path, since overriding or extra QNN parameters are expected via the config file.
+            For the QAIC compilation path, any flag supported by ``qaic-exec`` can be passed. Params are converted to flags as below:
             - aic_num_cores=16 -> -aic-num-cores=16
             - convert_to_fp16=True -> -convert-to-fp16

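This change folds QNN selection into ``compiler_options`` instead of dedicated keyword arguments. A minimal sketch of a user-facing call under the new convention (checkpoint name and config path are illustrative, not part of this commit):

    from QEfficient import QEFFAutoModelForCausalLM

    model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
    qpc_path = model.compile(
        num_cores=16,
        mxfp6_matmul=False,
        enable_qnn=True,               # consumed from **compiler_options
        qnn_config="qnn_config.json",  # optional QNN overrides; illustrative path
    )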
@@ -217,6 +222,7 @@ def _compile(
         onnx_path: Optional[str] = None,
         compile_dir: Optional[str] = None,
         *,
+        mxint8_kv_cache: bool = False,
         specializations: Optional[List[Dict[str, int]]] = None,
         custom_io: Optional[Dict[str, str]] = None,
         mdp_ts_num_devices: int = 1,
@@ -233,10 +239,32 @@ def _compile(
         :custom_io (dict): Custom IO to specify the input and outputs in different formats than default
         :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing.
         :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
-        :compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
+        :compiler_options: Pass any compiler option as input.
+            The following flags can be passed in compiler_options to enable the QNN compilation path:
+            :enable_qnn (bool): Enables QNN compilation. ``Defaults to False if not passed.``
+            :qnn_config (str): Path of the QNN config parameters file. ``Defaults to None if not passed.``
+            Any other parameter passed is ignored in the QNN compilation path, since overriding or extra QNN parameters are expected via the config file.
+            For the QAIC compilation path, any flag supported by ``qaic-exec`` can be passed. Params are converted to flags as below:
             - aic_num_cores=16 -> -aic-num-cores=16
             - convert_to_fp16=True -> -convert-to-fp16
+
         """
+        enable_qnn = compiler_options["enable_qnn"] if "enable_qnn" in compiler_options else False
+        qnn_config = compiler_options["qnn_config"] if "qnn_config" in compiler_options else None
+
+        if enable_qnn:
+            return self._qnn_compile(
+                onnx_path,
+                compile_dir,
+                specializations=specializations,
+                custom_io=custom_io,
+                mdp_ts_num_devices=mdp_ts_num_devices,
+                num_cores=compiler_options.get("aic_num_cores", 16),
+                mxfp6_matmul=compiler_options.get("mxfp6_matmul", False),
+                mxint8_kv_cache=mxint8_kv_cache,
+                qnn_config=qnn_config,
+            )
+
         if onnx_path is None and self.onnx_path is None:
             self.export()
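Note that ``compiler_options["enable_qnn"] if "enable_qnn" in compiler_options else False`` is equivalent to the more idiomatic ``compiler_options.get("enable_qnn", False)``; either way the QNN keys are read out of the catch-all kwargs. A hypothetical view of what the dispatch above receives when a subclass forwards its arguments (all values illustrative):

    # Inside a subclass's compile(): enable_qnn and qnn_config travel in
    # **compiler_options; aic_num_cores and mxfp6_matmul are reused for the
    # QNN path via .get() with the same defaults as the QAIC path.
    self._compile(
        onnx_path=None,                # falls back to self.onnx_path / export()
        compile_dir=None,
        mxint8_kv_cache=True,          # now an explicit keyword-only argument
        specializations=[{"batch_size": 1, "seq_len": 32, "ctx_len": 128}],
        custom_io={"past_key.0": "mxint8"},
        enable_qnn=True,
        qnn_config="qnn_config.json",  # illustrative path
        aic_num_cores=16,
    )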

@@ -346,17 +374,13 @@ def _qnn_compile(
         onnx_path: Optional[str] = None,
         compile_dir: Optional[str] = None,
         *,
+        custom_io: Optional[Dict[str, str]] = None,
         specializations: Optional[List[Dict[str, int]]] = None,
-        prefill_seq_len: int = 32,
-        ctx_len: int = 128,
-        batch_size: int = 1,
-        full_batch_size: Optional[int] = None,
         mdp_ts_num_devices: int = 1,
         num_cores: int = 16,
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         qnn_config: Optional[str] = None,
-        kv_cache_batch_size: Optional[int] = None,
     ) -> str:
         """
         Interface for QNN compiler
@@ -365,16 +389,11 @@ def _qnn_compile(
         :onnx_path (str): Onnx file to compile
         :compile_dir (str): Directory path to compile the qpc. A suffix is added to the directory path to avoid reusing same qpc for different parameters.
         :specializations (list): List of specializations to compile for
-        :prefill_seq_len (int, optional): The length of the Prefill prompt should be less that ``prefill_seq_len``. ``Defaults to 32``.
-        :ctx_len (int, optional): Maximum ``ctx`` that the compiled model can remember. ``Defaults to 128``.
-        :batch_size (int, optional): Batch size. ``Defaults to 1``.
-        :full_batch_size (int, optional): Continuous batching batch size.
         :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing.
         :num_cores (int): Number of cores used to compile the model.
         :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to True``.
         :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``.
         :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
-        :kv_cache_batch_size (int): kv_cache_batch_size for Prefix Caching. ``Defaults to None.``
         """
         if onnx_path is None and self.onnx_path is None:
             self.export()
@@ -390,6 +409,9 @@ def _qnn_compile(
         if specializations is not None:
             compile_hash.update(to_hashable(specializations))

+        if custom_io is not None:
+            compile_hash.update(to_hashable(custom_io))
+
         if qnn_config is not None:
             qnn_config_values = load_json(qnn_config)
             compile_hash.update(to_hashable(qnn_config_values))
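Hashing ``custom_io`` (alongside ``specializations``) means a change in KV-cache precision now lands in a fresh qpc directory rather than silently reusing a stale binary. A self-contained sketch of the caching idea, with ``to_hashable`` stood in by a JSON dump and the digest-truncation length assumed:

    import hashlib
    import json

    def to_hashable(obj):
        # Stand-in for QEfficient's helper: stable bytes for dicts/lists.
        return json.dumps(obj, sort_keys=True).encode()

    compile_hash = hashlib.sha256()
    compile_hash.update(to_hashable([{"batch_size": 1, "seq_len": 32, "ctx_len": 128}]))  # specializations
    compile_hash.update(to_hashable({"past_key.0": "mxint8"}))  # custom_io, newly hashed by this commit
    print(compile_hash.hexdigest()[:16])  # assumed suffix appended to the compile dir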
@@ -426,15 +448,12 @@
             qpc_base_path=compile_dir,
             num_cores=num_cores,
             device_group=list(range(mdp_ts_num_devices)),
-            batch_size=batch_size,
-            prompt_len=prefill_seq_len,
-            ctx_len=ctx_len,
             mxfp6=mxfp6_matmul,
             mxint8=mxint8_kv_cache,
-            full_batch_size=full_batch_size,
             qnn_config=qnn_config,
             qnn_binary_dir=qpc_path,
-            kv_cache_batch_size=kv_cache_batch_size,
+            specializations=specializations,
+            custom_io=custom_io,
         )

         self.qpc_path = qpc_path

QEfficient/compile/qnn_compiler.py (+10 −22)
@@ -7,11 +7,14 @@

 import os
 import shutil
-from typing import List, Optional
+from typing import Dict, List, Optional

 from QEfficient.utils._utils import create_json, execute_command, load_json
 from QEfficient.utils.constants import QnnConstants
-from QEfficient.utils.generate_qnn_network_specialization_config import fetch_nodes_info, generate_data_format_config
+from QEfficient.utils.generate_qnn_network_specialization_config import (
+    generate_data_format_config,
+    generate_qnn_specialization,
+)
 from QEfficient.utils.logging_utils import logger

@@ -31,9 +34,6 @@ def __init__(
         device_group: Optional[List[int]] = None,
         compiler_enable_depth_first: bool = False,
         compiler_max_out_channel_split: int = -1,
-        batch_size: int = 1,
-        prompt_len: int = 32,
-        ctx_len: int = 128,
         compiler_mxfp6_matmul_weights: bool = True,
         qnn_target: str = QnnConstants.TARGET,
         qnn_config_path: Optional[str] = None,
@@ -48,9 +48,6 @@ def __init__(
         self.device_group = device_group
         self.compiler_enable_depth_first = compiler_enable_depth_first
         self.compiler_max_out_channel_split = compiler_max_out_channel_split
-        self.batch_size = batch_size
-        self.prompt_len = prompt_len
-        self.ctx_len = ctx_len
         self.compiler_mxfp6_matmul_weights = compiler_mxfp6_matmul_weights
         self.qnn_config_path = qnn_config_path
         self.qnn_binary_dir = qnn_binary_dir
@@ -327,16 +324,15 @@ def compile(
         device_group: Optional[List[int]] = None,
         aic_enable_depth_first: bool = False,
         mos: int = -1,
-        batch_size: int = 1,
-        prompt_len: int = 32,
-        ctx_len: int = 128,
         mxfp6: bool = True,
         mxint8: bool = False,
         allow_mxint8_mdp_io: Optional[bool] = False,
         full_batch_size=None,
         qnn_config: Optional[str] = None,
         qnn_binary_dir: Optional[str] = None,
         kv_cache_batch_size: Optional[int] = None,
+        custom_io: Optional[Dict[str, str]] = None,
+        specializations: Optional[List[Dict[str, int]]] = None,
         **kwargs,
     ) -> str:
         """
@@ -377,16 +373,11 @@
         # TODO To make custom_io_config.yaml configurable as not all models need it.
         custom_io_file_path = os.path.join(qpc_base_path, "custom_io_config.yaml")

-        kv_precision = "uint8" if mxint8 else "float16"
-        fetch_nodes_info(
+        generate_qnn_specialization(
             onnx_graph_path=onnx_path,
-            batch_size=batch_size,
-            sequence_length=prompt_len,
-            context_length=ctx_len,
+            specializations=specializations,
+            custom_io=custom_io,
             file_path=custom_io_file_path,
-            full_batch_size=full_batch_size,
-            kv_precision=kv_precision,
-            kv_cache_batch_size=kv_cache_batch_size,
         )

         if not os.path.isfile(custom_io_file_path):
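``generate_qnn_specialization`` replaces ``fetch_nodes_info``: rather than rebuilding shapes from ``batch_size``/``prompt_len``/``ctx_len``, it consumes the same ``specializations`` and ``custom_io`` structures the QAIC path already uses. A sketch of an invocation under that assumption (paths and shape values illustrative):

    generate_qnn_specialization(
        onnx_graph_path="model.onnx",  # illustrative path
        specializations=[
            {"batch_size": 1, "seq_len": 32, "ctx_len": 128},  # prefill
            {"batch_size": 1, "seq_len": 1, "ctx_len": 128},   # decode
        ],
        custom_io={"past_key.0": "mxint8", "past_key.0_RetainedState": "mxint8"},
        file_path="custom_io_config.yaml",  # written under qpc_base_path
    )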
@@ -403,9 +394,6 @@
             custom_io_path=custom_io_file_path,
             compiler_enable_depth_first=aic_enable_depth_first,
             compiler_max_out_channel_split=mos,
-            batch_size=batch_size,
-            prompt_len=prompt_len,
-            ctx_len=ctx_len,
             compiler_mxfp6_matmul_weights=mxfp6,
             qnn_binary_dir=qnn_binary_dir,
             mxint8=mxint8,

QEfficient/peft/auto.py (+1)
@@ -251,6 +251,7 @@ def compile(
             custom_io=custom_io,
             mdp_ts_num_devices=num_devices,
             aic_num_cores=num_cores,
+            mxint8_kv_cache=mxint8_kv_cache,
             **compiler_options,
         )

QEfficient/transformers/models/modeling_auto.py (+30 −73)
@@ -598,21 +598,12 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        enable_qnn: bool = False,
-        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
-        if (
-            any(
-                param is not None
-                for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens, qnn_config]
-            )
-            or enable_qnn
-        ):
+        if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
             raise ValueError(
-                f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens', and 'qnn_config' to be None, and 'enable_qnn' to be False but got: "
+                f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: "
                 f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, "
-                f"enable_qnn={enable_qnn}, qnn_config={qnn_config}"
             )

         output_names = self.model.get_output_names(kv_offload=True)
@@ -651,6 +642,7 @@
             mdp_ts_num_devices=num_devices,
             aic_num_cores=num_cores,
             custom_io=custom_io_vision,
+            mxint8_kv_cache=mxint8_kv_cache,
             **compiler_options,
         )

@@ -675,6 +667,7 @@
             mdp_ts_num_devices=num_devices,
             aic_num_cores=num_cores,
             custom_io=custom_io_lang,
+            mxint8_kv_cache=mxint8_kv_cache,
             **compiler_options,
         )
         return self.qpc_path
@@ -915,21 +908,12 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        enable_qnn: bool = False,
-        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
-        if (
-            any(
-                param is not None
-                for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens, qnn_config]
-            )
-            or enable_qnn
-        ):
+        if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
             raise ValueError(
-                f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens', and 'qnn_config' to be None, and 'enable_qnn' to be False but got: "
+                f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: "
                 f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, "
-                f"enable_qnn={enable_qnn}, qnn_config={qnn_config}"
             )

         output_names = self.model.get_output_names()
@@ -967,6 +951,7 @@
             custom_io=custom_io,
             mdp_ts_num_devices=num_devices,
             aic_num_cores=num_cores,
+            mxint8_kv_cache=mxint8_kv_cache,
             **compiler_options,
         )
         return self.qpc_path
@@ -1476,8 +1461,6 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        enable_qnn: bool = False,
-        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1499,8 +1482,6 @@
         :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
         :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``.
         :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
-        :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
-        :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``

         Returns:
             :str: Path of the compiled ``qpc`` package.
@@ -1562,48 +1543,29 @@
             decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
             specializations.append(decode_specialization)

-        if enable_qnn:
-            if compiler_options:
-                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
-
-            qpc_path = self._qnn_compile(
-                onnx_path,
-                compile_dir,
-                specializations=specializations,
-                prefill_seq_len=prefill_seq_len,
-                ctx_len=ctx_len,
-                batch_size=batch_size,
-                full_batch_size=full_batch_size,
-                mdp_ts_num_devices=num_devices,
-                num_cores=num_cores,
-                mxfp6_matmul=mxfp6_matmul,
-                mxint8_kv_cache=mxint8_kv_cache,
-                qnn_config=qnn_config,
-                kv_cache_batch_size=kv_cache_batch_size,
-            )
-        else:
-            # Custom IO
-            custom_io = {}
-            kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
-            for suffix in ["", "_RetainedState"]:
-                for i in range(self.num_layers):
-                    for kv in ["key", "value"]:
-                        custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
-
-            qpc_path = self._compile(
-                onnx_path,
-                compile_dir,
-                compile_only=True,
-                retained_state=True,
-                specializations=specializations,
-                convert_to_fp16=True,
-                mxfp6_matmul=mxfp6_matmul,
-                custom_io=custom_io,
-                mdp_ts_num_devices=num_devices,
-                num_speculative_tokens=num_speculative_tokens,
-                aic_num_cores=num_cores,
-                **compiler_options,
-            )
+        # Custom IO
+        custom_io = {}
+        kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+        for suffix in ["", "_RetainedState"]:
+            for i in range(self.num_layers):
+                for kv in ["key", "value"]:
+                    custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
+
+        qpc_path = self._compile(
+            onnx_path,
+            compile_dir,
+            compile_only=True,
+            retained_state=True,
+            specializations=specializations,
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            custom_io=custom_io,
+            mdp_ts_num_devices=num_devices,
+            num_speculative_tokens=num_speculative_tokens,
+            aic_num_cores=num_cores,
+            mxint8_kv_cache=mxint8_kv_cache,
+            **compiler_options,
+        )
         return qpc_path

     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
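To make the Custom IO block in the hunk above concrete: for a hypothetical 2-layer model compiled with ``mxint8_kv_cache=True``, the loop produces the following dict, which ``_compile`` now both hashes and forwards to ``_qnn_compile`` when ``enable_qnn`` is set:

    custom_io = {
        "past_key.0": "mxint8",
        "past_value.0": "mxint8",
        "past_key.1": "mxint8",
        "past_value.1": "mxint8",
        "past_key.0_RetainedState": "mxint8",
        "past_value.0_RetainedState": "mxint8",
        "past_key.1_RetainedState": "mxint8",
        "past_value.1_RetainedState": "mxint8",
    }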
@@ -1747,8 +1709,6 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        enable_qnn: bool = False,
-        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1790,9 +1750,6 @@
         if num_speculative_tokens:
             logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")

-        if enable_qnn or qnn_config:
-            logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq")
-
         return self._compile(
             onnx_path,
             compile_dir,
