
Commit 7cfea38

Code for SDK configs Inclusion (#298)
This creates a config JSON file containing all the details about compilation and SDK versions. Currently, this code is hooked into `QEFFAutoModelForCausalLM.compile`. The config looks like the following:

```json
{
  "huggingface_config": {
    "vocab_size": 50257,
    "n_positions": 1024,
    "n_embd": 768,
    "n_layer": 12,
    "n_head": 12,
    "n_inner": null,
    "activation_function": "gelu_new",
    "resid_pdrop": 0.1,
    "embd_pdrop": 0.1,
    "attn_pdrop": 0.1,
    "layer_norm_epsilon": 1e-05,
    "initializer_range": 0.02,
    "summary_type": "cls_index",
    "summary_use_proj": true,
    "summary_activation": null,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": true,
    "scale_attn_weights": true,
    "use_cache": true,
    "scale_attn_by_inverse_layer_idx": false,
    "reorder_and_upcast_attn": false,
    "bos_token_id": 50256,
    "eos_token_id": 50256,
    "return_dict": true,
    "output_hidden_states": false,
    "output_attentions": false,
    "torchscript": false,
    "torch_dtype": null,
    "use_bfloat16": false,
    "tf_legacy_loss": false,
    "pruned_heads": {},
    "tie_word_embeddings": true,
    "chunk_size_feed_forward": 0,
    "is_encoder_decoder": false,
    "is_decoder": false,
    "cross_attention_hidden_size": null,
    "add_cross_attention": false,
    "tie_encoder_decoder": false,
    "max_length": 20,
    "min_length": 0,
    "do_sample": false,
    "early_stopping": false,
    "num_beams": 1,
    "num_beam_groups": 1,
    "diversity_penalty": 0.0,
    "temperature": 1.0,
    "top_k": 50,
    "top_p": 1.0,
    "typical_p": 1.0,
    "repetition_penalty": 1.0,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 0,
    "encoder_no_repeat_ngram_size": 0,
    "bad_words_ids": null,
    "num_return_sequences": 1,
    "output_scores": false,
    "return_dict_in_generate": false,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "remove_invalid_values": false,
    "exponential_decay_length_penalty": null,
    "suppress_tokens": null,
    "begin_suppress_tokens": null,
    "architectures": ["GPT2LMHeadModel"],
    "finetuning_task": null,
    "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
    "label2id": {"LABEL_0": 0, "LABEL_1": 1},
    "tokenizer_class": null,
    "prefix": null,
    "pad_token_id": null,
    "sep_token_id": null,
    "decoder_start_token_id": null,
    "task_specific_params": {"text-generation": {"do_sample": true, "max_length": 50}},
    "problem_type": null,
    "_name_or_path": "gpt2",
    "_commit_hash": "607a30d783dfa663caf39e06633721c8d4cfcd7e",
    "_attn_implementation_internal": "eager",
    "transformers_version": null,
    "model_type": "gpt2",
    "n_ctx": 1024
  },
  "qpc_config": {
    "QEff_config": {
      "pytorch_transforms": [
        "AwqToMatmulNbitsTransform",
        "GPTQToMatmulNbitsTransform",
        "CustomOpsTransform",
        "KVCacheTransform"
      ],
      "onnx_transforms": ["FP16ClipTransform", "SplitTensorsTransform"],
      "onnx_path": "/root/.cache/qeff_models/GPT2LMHeadModel-36f0eca92731bb47/GPT2LMHeadModel.onnx"
    },
    "aic_compiler_config": {
      "apps_sdk_version": "1.20.0",
      "compile_dir": "/root/.cache/qeff_models/GPT2LMHeadModel-36f0eca92731bb47",
      "specializations_file_path": "/root/.cache/qeff_models/GPT2LMHeadModel-36f0eca92731bb47/specializations.json",
      "prefill_seq_len": 32,
      "ctx_len": 128,
      "batch_size": 1,
      "full_batch_size": null,
      "num_devices": 1,
      "num_cores": 16,
      "mxfp6_matmul": false,
      "mxint8_kv_cache": false,
      "num_speculative_tokens": null
    },
    "qnn_config": {
      "enable_qnn": true,
      "qnn_config_path": "QEfficient/compile/qnn_config.json",
      "product": "QAIRT",
      "os": {"Ubuntu": 22.04, "Windows": 11},
      "sdk_flavor": ["aic"],
      "version": "2.31.0",
      "build_id": "250109072054_3882",
      "qnn_backend_api_version": "2.18.0",
      "tensorflow": "2.10.1",
      "tflite": "2.3.0",
      "torch": "1.13.1",
      "onnx": "1.16.1",
      "onnxruntime": "1.17.1",
      "onnxsimplifier": "0.4.36",
      "android-ndk": "r26c",
      "platform": "AIC.1.20.0.14"
    }
  }
}
```

Note: The code structure may change.

---------

Signed-off-by: Abukhoyer Shaik <[email protected]>
Co-authored-by: Abukhoyer Shaik <[email protected]>
1 parent 0ea70b4 commit 7cfea38

18 files changed (+185, -6 lines)

QEfficient/base/modeling_qeff.py (+5, -1)

```diff
@@ -23,7 +23,7 @@
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
 from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.utils import constants
+from QEfficient.utils import constants, dump_qconfig
 from QEfficient.utils._utils import load_json
 from QEfficient.utils.cache import QEFF_HOME, to_hashable

@@ -212,6 +212,7 @@ def _export(
         self.onnx_path = onnx_path
         return onnx_path

+    @dump_qconfig
     def _compile(
         self,
         onnx_path: Optional[str] = None,
@@ -337,8 +338,10 @@ def _compile(
         )

         self.qpc_path = qpc_path
+
         return qpc_path

+    @dump_qconfig
     def _qnn_compile(
         self,
         onnx_path: Optional[str] = None,
@@ -436,4 +439,5 @@ def _qnn_compile(
         )

         self.qpc_path = qpc_path
+
         return qpc_path
```
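The decorator added here is a post-hook: it lets `_compile` and `_qnn_compile` run to completion (populating `self.qpc_path` and `self.onnx_path`), then serializes the configs. A simplified standalone sketch of the pattern (not the actual `dump_qconfig` from `_utils.py`, which appears in full below):

```python
import functools
import json
import os

def dump_qconfig_sketch(func):
    """Simplified post-hook: run the wrapped compile, then dump a qconfig."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)  # compile runs first, sets self.qpc_path
        qconfig = {"onnx_path": str(self.onnx_path), "compile_options": kwargs}
        with open(os.path.join(os.path.dirname(self.qpc_path), "qconfig.json"), "w") as f:
            json.dump(qconfig, f, indent=4)
        return result
    return wrapper
```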

QEfficient/peft/auto.py (+4)

```diff
@@ -107,6 +107,10 @@ def model_hash(self) -> str:
         mhash = mhash.hexdigest()[:16]
         return mhash

+    @property
+    def get_model_config(self) -> dict:
+        return self.model.get_base_model().config.__dict__
+
     def load_adapter(self, model_id: str, adapter_name: str):
         """Loads a new adapter from huggingface hub or local path

```
QEfficient/peft/lora/auto.py (+4)

```diff
@@ -90,6 +90,10 @@ def model_hash(self) -> str:
         mhash = mhash.hexdigest()[:16]
         return mhash

+    @property
+    def get_model_config(self) -> dict:
+        return self.model.model.config.__dict__
+
     def download_adapter(
         self,
         adapter_model_id: str,
```

QEfficient/transformers/models/modeling_auto.py (+24)

```diff
@@ -229,6 +229,10 @@ def model_hash(self) -> str:
         mhash = mhash.hexdigest()[:16]
         return mhash

+    @property
+    def get_model_config(self) -> dict:
+        return self.model.config.__dict__
+
     def export(self, export_dir: Optional[str] = None) -> str:
         """
         Exports the model to ``ONNX`` format using ``torch.onnx.export``.
@@ -447,6 +451,10 @@ def model_name(self) -> str:
         mname = mname[4:]
         return mname

+    @property
+    def get_model_config(self) -> dict:
+        return self.model.model.vision_model.config.__dict__
+

 class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
     _pytorch_transforms = [
@@ -506,6 +514,10 @@ def model_name(self) -> str:
         mname = mname[4:]
         return mname

+    @property
+    def get_model_config(self) -> dict:
+        return self.model.language_model.config.__dict__
+

 class _QEffAutoModelForImageTextToTextDualQPC:
     _hf_auto_class = AutoModelForImageTextToText
@@ -1128,6 +1140,10 @@ def model_name(self) -> str:
         mname = mname[4:]
         return mname

+    @property
+    def get_model_config(self) -> dict:
+        return self.model.config.__dict__
+

 class QEFFAutoModelForImageTextToText:
     """
@@ -1332,6 +1348,10 @@ def model_hash(self) -> str:
         mhash = mhash.hexdigest()[:16]
         return mhash

+    @property
+    def get_model_config(self) -> dict:
+        return self.model.config.__dict__
+
     def export(self, export_dir: Optional[str] = None) -> str:
         """
         Exports the model to ``ONNX`` format using ``torch.onnx.export``.
@@ -1642,6 +1662,10 @@ def model_hash(self) -> str:
         mhash = mhash.hexdigest()[:16]
         return mhash

+    @property
+    def get_model_config(self) -> dict:
+        return self.model.config.__dict__
+
     def export(self, export_dir: Optional[str] = None) -> str:
         """
         Exports the model to ``ONNX`` format using ``torch.onnx.export``.
```
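Each wrapper class above exposes `get_model_config` so that `dump_qconfig` can pull the underlying HuggingFace config uniformly, regardless of how deeply the model is nested. A hypothetical illustration of the pattern (`WrapperSketch` is not a real QEfficient class):

```python
from transformers import AutoConfig

class WrapperSketch:
    """Hypothetical stand-in for a QEff wrapper exposing its HF config."""

    def __init__(self, model_name: str):
        self.config = AutoConfig.from_pretrained(model_name)

    @property
    def get_model_config(self) -> dict:
        # __dict__ gives a plain dict that make_serializable() can walk.
        return self.config.__dict__

print(WrapperSketch("gpt2").get_model_config["vocab_size"])  # 50257
```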

QEfficient/utils/__init__.py (+1)

```diff
@@ -11,6 +11,7 @@
 )
 from QEfficient.utils._utils import (  # noqa: F401
     check_and_assign_cache_dir,
+    dump_qconfig,
     get_num_layers_from_config,
     get_onnx_dir_name,
     get_padding_shape_from_config,
```
QEfficient/utils/_utils.py (+113, -1)

```diff
@@ -8,16 +8,18 @@
 import json
 import os
 import subprocess
+import xml.etree.ElementTree as ET
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union

 import requests
 import torch
+import yaml
 from huggingface_hub import login, snapshot_download
 from requests.exceptions import HTTPError
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

-from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants
+from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants, QnnConstants
 from QEfficient.utils.logging_utils import logger


@@ -442,3 +444,113 @@ class IOInfo:

     def __repr__(self):
         return f"input_name:{self.name}\tdatatype:{self.datatype}\tshape:{self.shape}"
+
+
+def dump_qconfig(func):
+    def wrapper(self, *args, **kwargs):
+        result = func(self, *args, **kwargs)
+        create_and_dump_qconfigs(
+            self.qpc_path,
+            self.onnx_path,
+            self.get_model_config,
+            [cls.__name__ for cls in self._pytorch_transforms],
+            [cls.__name__ for cls in self._onnx_transforms],
+            kwargs.get("specializations"),
+            kwargs.get("mdp_ts_num_devices", 1),
+            kwargs.get("num_speculative_tokens"),
+            **{
+                k: v
+                for k, v in kwargs.items()
+                if k not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io"]
+            },
+        )
+        return result
+
+    return wrapper
+
+
+def create_and_dump_qconfigs(
+    qpc_path,
+    onnx_path,
+    huggingface_config,
+    pytorch_transforms,
+    onnx_transforms,
+    specializations,
+    mdp_ts_num_devices,
+    num_speculative_tokens,
+    **compiler_options,
+):
+    """
+    This Method creates a JSON file which contains all the configs for a model.
+    Such as huggingface configs, QEff transforms, QAIC sdk version, QNN sdk, compilation dir, qpc dir and
+    many other compilation options.
+    """
+    qnn_config = compiler_options["qnn_config"] if "qnn_config" in compiler_options else None
+    enable_qnn = True if "qnn_config" in compiler_options else None
+
+    qconfig_file_path = os.path.join(os.path.dirname(qpc_path), "qconfig.json")
+    onnx_path = str(onnx_path)
+    specializations_file_path = str(os.path.join(os.path.dirname(qpc_path), "specializations.json"))
+    compile_dir = str(os.path.dirname(qpc_path))
+    qnn_config_path = (
+        (qnn_config if qnn_config is not None else "QEfficient/compile/qnn_config.json") if enable_qnn else None
+    )
+
+    # Extract QAIC SDK Apps Version from SDK XML file
+    tree = ET.parse(Constants.SDK_APPS_XML)
+    root = tree.getroot()
+    qaic_version = root.find(".//base_version").text
+
+    # Extract QNN SDK details from YAML file if the environment variable is set
+    qnn_sdk_details = None
+    qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME)
+    if qnn_sdk_path:
+        qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML)
+        with open(qnn_sdk_yaml_path, "r") as file:
+            qnn_sdk_details = yaml.safe_load(file)
+
+    # Ensure all objects in the configs dictionary are JSON serializable
+    def make_serializable(obj):
+        if isinstance(obj, (int, float, str, bool, type(None))):
+            return obj
+        elif isinstance(obj, (list, tuple)):
+            return [make_serializable(item) for item in obj]
+        elif isinstance(obj, dict):
+            return {key: make_serializable(value) for key, value in obj.items()}
+        elif hasattr(obj, "__dict__"):
+            return make_serializable(vars(obj))
+        return str(obj)
+
+    qconfigs = {
+        "huggingface_config": make_serializable(huggingface_config),
+        "qpc_config": {
+            "QEff_config": {
+                "pytorch_transforms": make_serializable(pytorch_transforms),
+                "onnx_transforms": make_serializable(onnx_transforms),
+                "onnx_path": onnx_path,
+            },
+        },
+    }
+
+    aic_compiler_config = {
+        "apps_sdk_version": qaic_version,
+        "compile_dir": compile_dir,
+        "specializations_file_path": specializations_file_path,
+        "specializations": make_serializable(specializations),
+        "mdp_ts_num_devices": mdp_ts_num_devices,
+        "num_speculative_tokens": num_speculative_tokens,
+        **compiler_options,
+    }
+    qnn_config = {
+        "enable_qnn": enable_qnn,
+        "qnn_config_path": qnn_config_path,
+    }
+    # Put AIC or qnn details.
+    if enable_qnn:
+        qconfigs["qpc_config"]["qnn_config"] = qnn_config
+        if qnn_sdk_details:
+            qconfigs["qpc_config"]["qnn_config"].update(qnn_sdk_details)
+    else:
+        qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config
+
+    create_json(qconfig_file_path, qconfigs)
```
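Worth noting is the `make_serializable` helper above, which recursively converts arbitrary config objects into JSON-friendly values, falling back to `vars()` for objects and `str()` for everything else. A quick self-contained illustration (the `Dummy` class is hypothetical):

```python
from dataclasses import dataclass

def make_serializable(obj):
    # Same logic as in _utils.py above.
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, (list, tuple)):
        return [make_serializable(item) for item in obj]
    elif isinstance(obj, dict):
        return {key: make_serializable(value) for key, value in obj.items()}
    elif hasattr(obj, "__dict__"):
        return make_serializable(vars(obj))
    return str(obj)

@dataclass
class Dummy:
    """Hypothetical stand-in for a nested config object."""
    name: str = "gpt2"
    heads: int = 12

print(make_serializable({"cfg": Dummy(), "shape": (1, 32), "x": 3 + 4j}))
# {'cfg': {'name': 'gpt2', 'heads': 12}, 'shape': [1, 32], 'x': '(3+4j)'}
```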

QEfficient/utils/constants.py (+2)

```diff
@@ -75,12 +75,14 @@ class Constants:
     MAX_QPC_LIMIT = 30
     MAX_RETRIES = 5  # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
     NUM_SPECULATIVE_TOKENS = 2
+    SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml"  # This xml file is parsed to find out the SDK version.


 @dataclass
 class QnnConstants:
     # QNN PATH to be read from environment variable.
     QNN_SDK_PATH_ENV_VAR_NAME = "QNN_SDK_ROOT"
+    QNN_SDK_YAML = "sdk.yaml"

     # QNN Compilation tools
     QAIRT_CONVERTER = "{}/bin/{}/qairt-converter"
```
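`SDK_APPS_XML` points at the versions manifest installed with the QAIC apps SDK, and `create_and_dump_qconfigs` only assumes it contains a `base_version` element somewhere in the tree. A sketch of the lookup against a made-up XML body (the real file's layout may differ):

```python
import xml.etree.ElementTree as ET

# Made-up stand-in for /opt/qti-aic/versions/apps.xml; only the
# <base_version> element is relied upon by create_and_dump_qconfigs.
xml_text = """
<versions>
  <ci_build>
    <base_version>1.20.0</base_version>
  </ci_build>
</versions>
"""

root = ET.fromstring(xml_text)
print(root.find(".//base_version").text)  # 1.20.0
```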

tests/peft/lora/test_lora_model.py (+4)

```diff
@@ -4,6 +4,8 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+import os
 from pathlib import Path
 from time import perf_counter

@@ -225,6 +227,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
     # test compile
     qeff_model.compile(prefill_seq_len=32, ctx_len=64)
     assert Path(qeff_model.qpc_path).is_dir()
+    assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))

     # test generate
     prompts = ["hello!", "hi", "hello, my name is", "hey"]
@@ -249,6 +252,7 @@ def test_auto_lora_model_for_causal_lm_cb_compile_generate(base_model_name, adap
     # test compile
     qeff_model.compile(prefill_seq_len=32, ctx_len=64, full_batch_size=2)
     assert Path(qeff_model.qpc_path).is_dir()
+    assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))

     # test generate
     prompts = ["hello!", "hi", "hello, my name is", "hey"]
```

tests/peft/test_peft_model.py (+2)

```diff
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------

+import os
 from time import perf_counter

 import numpy as np
@@ -187,3 +188,4 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con
     end = perf_counter()
     compile_time_1 = end - start
     assert compile_time_1 < 0.01 * compile_time_0
+    assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))
```

tests/qnn_tests/test_causal_lm_models_qnn.py (+6, -2)

```diff
@@ -5,6 +5,8 @@
 #
 # -----------------------------------------------------------------------------

+import os
+
 import numpy as np
 import pytest
 from transformers import AutoModelForCausalLM
@@ -98,14 +100,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     if not get_available_device_id():
         pytest.skip("No available devices to run model on Cloud AI 100")

-    _ = qeff_model.compile(
+    qpc_path = qeff_model.compile(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         num_cores=14,
         mxfp6=False,
         aic_enable_depth_first=False,
         enable_qnn=True,
     )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0]  # Because we always run for single input and single batch size
     gen_len = ort_tokens.shape[-1]
@@ -136,7 +139,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     if not get_available_device_id():
         pytest.skip("No available devices to run model on Cloud AI 100")

-    _ = qeff_model.compile(
+    qpc_path = qeff_model.compile(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         num_cores=14,
@@ -145,6 +148,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         full_batch_size=full_batch_size,
         enable_qnn=True,
     )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
     exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)

     assert all(
```

tests/text_generation/test_text_generation.py (+3)

```diff
@@ -5,6 +5,8 @@
 #
 # -----------------------------------------------------------------------------

+import os
+
 import pytest
 from transformers import AutoModelForCausalLM

@@ -101,3 +103,4 @@ def test_generate_text_stream(
     assert cloud_ai_100_output == stream_tokens, (
         f"Deviation in output observed while comparing regular execution and streamed output: {cloud_ai_100_output} != {stream_tokens}"
     )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
```
