|
8 | 8 | import json
|
9 | 9 | import os
|
10 | 10 | import subprocess
|
| 11 | +import xml.etree.ElementTree as ET |
11 | 12 | from dataclasses import dataclass
|
12 | 13 | from typing import Any, Dict, List, Optional, Tuple, Union
|
13 | 14 |
|
14 | 15 | import requests
|
15 | 16 | import torch
|
| 17 | +import yaml |
16 | 18 | from huggingface_hub import login, snapshot_download
|
17 | 19 | from requests.exceptions import HTTPError
|
18 | 20 | from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
19 | 21 |
|
20 |
| -from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants |
| 22 | +from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants, QnnConstants |
21 | 23 | from QEfficient.utils.logging_utils import logger
|
22 | 24 |
|
23 | 25 |
|
@@ -442,3 +444,113 @@ class IOInfo:
|
442 | 444 |
|
443 | 445 | def __repr__(self):
|
444 | 446 | return f"input_name:{self.name}\tdatatype:{self.datatype}\tshape:{self.shape}"
|
| 447 | + |
| 448 | + |
| 449 | +def dump_qconfig(func): |
| 450 | + def wrapper(self, *args, **kwargs): |
| 451 | + result = func(self, *args, **kwargs) |
| 452 | + create_and_dump_qconfigs( |
| 453 | + self.qpc_path, |
| 454 | + self.onnx_path, |
| 455 | + self.get_model_config, |
| 456 | + [cls.__name__ for cls in self._pytorch_transforms], |
| 457 | + [cls.__name__ for cls in self._onnx_transforms], |
| 458 | + kwargs.get("specializations"), |
| 459 | + kwargs.get("mdp_ts_num_devices", 1), |
| 460 | + kwargs.get("num_speculative_tokens"), |
| 461 | + **{ |
| 462 | + k: v |
| 463 | + for k, v in kwargs.items() |
| 464 | + if k not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io"] |
| 465 | + }, |
| 466 | + ) |
| 467 | + return result |
| 468 | + |
| 469 | + return wrapper |
| 470 | + |
| 471 | + |
| 472 | +def create_and_dump_qconfigs( |
| 473 | + qpc_path, |
| 474 | + onnx_path, |
| 475 | + huggingface_config, |
| 476 | + pytorch_transforms, |
| 477 | + onnx_transforms, |
| 478 | + specializations, |
| 479 | + mdp_ts_num_devices, |
| 480 | + num_speculative_tokens, |
| 481 | + **compiler_options, |
| 482 | +): |
| 483 | + """ |
| 484 | + This Method creates a JSON file which contains all the configs for a model. |
| 485 | + Such as huggingface configs, QEff transforms, QAIC sdk version, QNN sdk, compilation dir, qpc dir and |
| 486 | + many other compilation options. |
| 487 | + """ |
| 488 | + qnn_config = compiler_options["qnn_config"] if "qnn_config" in compiler_options else None |
| 489 | + enable_qnn = True if "qnn_config" in compiler_options else None |
| 490 | + |
| 491 | + qconfig_file_path = os.path.join(os.path.dirname(qpc_path), "qconfig.json") |
| 492 | + onnx_path = str(onnx_path) |
| 493 | + specializations_file_path = str(os.path.join(os.path.dirname(qpc_path), "specializations.json")) |
| 494 | + compile_dir = str(os.path.dirname(qpc_path)) |
| 495 | + qnn_config_path = ( |
| 496 | + (qnn_config if qnn_config is not None else "QEfficient/compile/qnn_config.json") if enable_qnn else None |
| 497 | + ) |
| 498 | + |
| 499 | + # Extract QAIC SDK Apps Version from SDK XML file |
| 500 | + tree = ET.parse(Constants.SDK_APPS_XML) |
| 501 | + root = tree.getroot() |
| 502 | + qaic_version = root.find(".//base_version").text |
| 503 | + |
| 504 | + # Extract QNN SDK details from YAML file if the environment variable is set |
| 505 | + qnn_sdk_details = None |
| 506 | + qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME) |
| 507 | + if qnn_sdk_path: |
| 508 | + qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML) |
| 509 | + with open(qnn_sdk_yaml_path, "r") as file: |
| 510 | + qnn_sdk_details = yaml.safe_load(file) |
| 511 | + |
| 512 | + # Ensure all objects in the configs dictionary are JSON serializable |
| 513 | + def make_serializable(obj): |
| 514 | + if isinstance(obj, (int, float, str, bool, type(None))): |
| 515 | + return obj |
| 516 | + elif isinstance(obj, (list, tuple)): |
| 517 | + return [make_serializable(item) for item in obj] |
| 518 | + elif isinstance(obj, dict): |
| 519 | + return {key: make_serializable(value) for key, value in obj.items()} |
| 520 | + elif hasattr(obj, "__dict__"): |
| 521 | + return make_serializable(vars(obj)) |
| 522 | + return str(obj) |
| 523 | + |
| 524 | + qconfigs = { |
| 525 | + "huggingface_config": make_serializable(huggingface_config), |
| 526 | + "qpc_config": { |
| 527 | + "QEff_config": { |
| 528 | + "pytorch_transforms": make_serializable(pytorch_transforms), |
| 529 | + "onnx_transforms": make_serializable(onnx_transforms), |
| 530 | + "onnx_path": onnx_path, |
| 531 | + }, |
| 532 | + }, |
| 533 | + } |
| 534 | + |
| 535 | + aic_compiler_config = { |
| 536 | + "apps_sdk_version": qaic_version, |
| 537 | + "compile_dir": compile_dir, |
| 538 | + "specializations_file_path": specializations_file_path, |
| 539 | + "specializations": make_serializable(specializations), |
| 540 | + "mdp_ts_num_devices": mdp_ts_num_devices, |
| 541 | + "num_speculative_tokens": num_speculative_tokens, |
| 542 | + **compiler_options, |
| 543 | + } |
| 544 | + qnn_config = { |
| 545 | + "enable_qnn": enable_qnn, |
| 546 | + "qnn_config_path": qnn_config_path, |
| 547 | + } |
| 548 | + # Put AIC or qnn details. |
| 549 | + if enable_qnn: |
| 550 | + qconfigs["qpc_config"]["qnn_config"] = qnn_config |
| 551 | + if qnn_sdk_details: |
| 552 | + qconfigs["qpc_config"]["qnn_config"].update(qnn_sdk_details) |
| 553 | + else: |
| 554 | + qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config |
| 555 | + |
| 556 | + create_json(qconfig_file_path, qconfigs) |
0 commit comments