🐛 Describe the bug
When trying to compile a model with the QNN partitioner for the GPU or DSP backend, I get the following error:
[ERROR] [Qnn ExecuTorch]: Cannot Open QNN library libQnnDsp.so, with error: libQnnDsp.so: cannot open shared object file: No such file or directory
E 00:00:00.195781 executorch:QnnImplementation.cpp:183] Fail to start backend
E 00:00:00.195786 executorch:QnnManager.cpp:272] Fail to load Qnn library
Segmentation fault (core dumped)
When I run the HTP backend, this error does not occur. I have also followed the AI Engine tutorial on the documentation page, so I have set up $QNN_SDK_ROOT and completed the build.
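For reference, this is a quick check of whether the loader can even see those backend libraries (a rough sketch only; the lib/x86_64-linux-clang subdirectory and the LD_LIBRARY_PATH lookup are my assumptions about the SDK layout on an x86_64 host):

import os

# Rough sketch: check that the libraries named in the error exist in the SDK
# and that their directory is on the loader search path. The
# lib/x86_64-linux-clang layout is an assumption for an x86_64 Linux host.
sdk_root = os.environ.get("QNN_SDK_ROOT", "")
lib_dir = os.path.join(sdk_root, "lib", "x86_64-linux-clang")

for lib in ("libQnnHtp.so", "libQnnGpu.so", "libQnnDsp.so"):
    path = os.path.join(lib_dir, lib)
    print(f"{lib}: {'found' if os.path.isfile(path) else 'missing'} ({path})")

# The backends are opened by file name at runtime, so this directory also has
# to be on LD_LIBRARY_PATH for the load to succeed.
print("lib_dir on LD_LIBRARY_PATH:",
      lib_dir in os.environ.get("LD_LIBRARY_PATH", "").split(":"))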
This is the function I use to create the .pte binary files:
from enum import IntEnum, unique
from typing import Optional
# Replaced generate_htp_compiler_spec in executorch/backends/qualcomm/utils/utils.py with below
'''
def generate_htp_compiler_spec(
    use_fp16: bool,
    use_dlbc: bool = False,
    use_multi_contexts: bool = False,
    performance_mode: QnnExecuTorchHtpPerformanceMode = QnnExecuTorchHtpPerformanceMode.kHtpBurst,
) -> QnnExecuTorchBackendOptions:
    """
    Helper function generating backend options for QNN HTP.

    Args:
        use_fp16: If true, the model is compiled to the QNN HTP fp16 runtime.
            Note that not all SoCs support QNN HTP fp16. Only premium-tier SoCs
            such as Snapdragon 8 Gen 1 or newer support HTP fp16.
        use_dlbc: Deep Learning Bandwidth Compression allows inputs to be
            compressed, so that the required processing bandwidth can be lowered.
        use_multi_contexts: When multiple contexts are generated inside the same
            pte, it is possible to reserve a single spill-fill allocation that
            can be re-used across all the splits.

    Returns:
        QnnExecuTorchBackendOptions: backend options for QNN HTP.
    """
    htp_options = QnnExecuTorchHtpBackendOptions()
    htp_options.precision = (
        QnnExecuTorchHtpPrecision.kHtpFp16
        if use_fp16
        else QnnExecuTorchHtpPrecision.kHtpQuantized
    )
    # This is not actually an option that affects the compiled blob,
    # but there is no other place to pass it for the execution stage.
    # TODO: enable voting mechanism in runtime and make this an option
    htp_options.performance_mode = performance_mode
    htp_options.use_multi_contexts = use_multi_contexts
    htp_options.use_dlbc = use_dlbc
    return QnnExecuTorchBackendOptions(
        backend_type=QnnExecuTorchBackendType.kHtpBackend,
        htp_options=htp_options,
    )


def generate_gpu_compiler_spec() -> QnnExecuTorchBackendOptions:
    """
    Helper function generating backend options for QNN GPU.

    Returns:
        QnnExecuTorchBackendOptions: backend options for QNN GPU.
    """
    return QnnExecuTorchBackendOptions(
        backend_type=QnnExecuTorchBackendType.kGpuBackend
    )


def generate_dsp_compiler_spec() -> QnnExecuTorchBackendOptions:
    """
    Helper function generating backend options for QNN DSP.

    Returns:
        QnnExecuTorchBackendOptions: backend options for QNN DSP.
    """
    return QnnExecuTorchBackendOptions(
        backend_type=QnnExecuTorchBackendType.kDspBackend
    )
'''
import torch
import os
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a4w_qnn_ptq_config,
    get_default_16bit_qnn_ptq_config,
    get_default_8bit_qnn_ptq_config,
    QnnQuantizer,
    QuantDtype,
)
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
    QcomChipset,
)
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    generate_htp_compiler_spec,
    generate_gpu_compiler_spec,
    generate_dsp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@unique
class QnnExecuTorchHtpPerformanceMode(IntEnum):
    kHtpDefault = 0
    kHtpSustainedHighPerformance = 1
    kHtpBurst = 2
    kHtpHighPerformance = 3
    kHtpPowerSaver = 4
    kHtpLowPowerSaver = 5
    kHtpHighPowerSaver = 6
    kHtpLowBalanced = 7
    kHtpBalanced = 8
def make_quantizer(
    quant_dtype: Optional[QuantDtype],
    custom_annotations=(),
    per_channel_conv=True,
    per_channel_linear=False,
    act_observer=MovingAverageMinMaxObserver,
):
    quantizer = QnnQuantizer()
    quantizer.add_custom_quant_annotations(custom_annotations)
    quantizer.set_per_channel_conv_quant(per_channel_conv)
    quantizer.set_per_channel_linear_quant(per_channel_linear)

    if quant_dtype == QuantDtype.use_8a8w:
        quantizer.set_bit8_op_quant_config(
            get_default_8bit_qnn_ptq_config(act_observer=act_observer)
        )
    elif quant_dtype == QuantDtype.use_16a16w:
        quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
        quantizer.set_bit16_op_quant_config(
            get_default_16bit_qnn_ptq_config(act_observer=act_observer)
        )
    elif quant_dtype == QuantDtype.use_16a4w:
        quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
        quantizer.set_bit16_op_quant_config(
            get_16a4w_qnn_ptq_config(act_observer=act_observer)
        )
        quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4")
    else:
        raise AssertionError(f"No support for QuantDtype {quant_dtype}.")

    return quantizer
def build_executorch_binary(
    model,  # noqa: B006
    inputs,  # noqa: B006
    soc_model,
    file_name,
    chip="CPU",
    performance_mode=QnnExecuTorchHtpPerformanceMode.kHtpBurst,
    custom_annotations=(),
    skip_node_id_set=None,
    skip_node_op_set=None,
    quant_dtype: Optional[QuantDtype] = None,
    per_channel_linear=False,  # TODO: remove this once QNN fully supports linear
    shared_buffer=False,
    metadata=None,
    act_observer=MovingAverageMinMaxObserver,
    dump_intermediate_outputs=False,
):
    if quant_dtype is not None:
        quantizer = make_quantizer(
            quant_dtype=quant_dtype,
            custom_annotations=custom_annotations,
            per_channel_conv=True,
            per_channel_linear=per_channel_linear,
            act_observer=act_observer,
        )
        captured_model = torch.export.export(model, inputs).module()
        annotated_model = prepare_pt2e(captured_model, quantizer)
        print("Quantizing the model...")
        # calibration
        annotated_model(*inputs)
        quantized_model = convert_pt2e(annotated_model)
        edge_prog = capture_program(quantized_model, inputs)
    else:
        edge_prog = capture_program(model, inputs)

    if chip == "GPU":
        backend_options = generate_gpu_compiler_spec()
    elif chip == "HTP":
        backend_options = generate_htp_compiler_spec(
            use_fp16=False if quant_dtype else True,
            performance_mode=performance_mode,
        )
    elif chip == "DSP":
        backend_options = generate_dsp_compiler_spec()
    else:
        raise ValueError(f"Unknown chip: {chip}")

    qnn_partitioner = QnnPartitioner(
        generate_qnn_executorch_compiler_spec(
            soc_model=getattr(QcomChipset, soc_model),
            backend_options=backend_options,
            shared_buffer=shared_buffer,
            dump_intermediate_outputs=dump_intermediate_outputs,
        ),
        skip_node_id_set,
        skip_node_op_set,
    )

    executorch_config = ExecutorchBackendConfig(
        # For shared buffer, the user must pass the memory address allocated
        # by RPC memory to the executor runner, so we don't want the runtime
        # memory manager to pre-allocate graph inputs/outputs.
        memory_planning_pass=MemoryPlanningPass(
            alloc_graph_input=not shared_buffer,
            alloc_graph_output=not shared_buffer,
        ),
    )

    if metadata is None:
        exported_program = to_backend(edge_prog.exported_program, qnn_partitioner)
        exported_program.graph_module.graph.print_tabular()
        exec_prog = to_edge(exported_program).to_executorch(config=executorch_config)
        with open(f"{file_name}.pte", "wb") as file:
            file.write(exec_prog.buffer)
    else:
        edge_prog_mgr = EdgeProgramManager(
            edge_programs={"forward": edge_prog.exported_program},
            constant_methods=metadata,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
        edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner)
        exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
        with open(f"{file_name}.pte", "wb") as file:
            file.write(exec_prog_mgr.buffer)
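For completeness, this is roughly how the function above gets called to hit the error (a sketch only; SimpleModel, the input shape, and the SM8550 soc_model string are placeholders standing in for my actual model and target):

# Rough sketch of how build_executorch_binary is called. SimpleModel, the
# input shape, and soc_model="SM8550" are placeholders, not my real setup.
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 16),)

# chip="HTP" compiles fine; chip="GPU" or chip="DSP" fails with the
# libQnnGpu.so / libQnnDsp.so load error shown above.
build_executorch_binary(
    SimpleModel().eval(),
    example_inputs,
    soc_model="SM8550",  # placeholder QcomChipset name
    file_name="simple_model_dsp",
    chip="DSP",
)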
Versions
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: N/A
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
Nvidia driver version: 560.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 7745HX with Radeon Graphics
CPU family: 25
Model: 97
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
BogoMIPS: 7186.32
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr arat npt nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm
Virtualization: AMD-V
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 32 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] No relevant packages
[conda] No relevant packages