Skip to content

Changing the hashing methodology for cache folder creation of models. #481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
92 changes: 57 additions & 35 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# ----------------------------------------------------------------------------

import hashlib
import copy
import inspect
import logging
import shutil
Expand All @@ -22,8 +22,16 @@
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json
from QEfficient.utils.cache import QEFF_HOME, to_hashable
from QEfficient.utils import (
constants,
create_json,
dump_qconfig,
filter_and_create_export_hash,
generate_mdp_partition_config,
hash_compile_params,
load_json,
)
from QEfficient.utils.cache import QEFF_HOME

logger = logging.getLogger(__name__)

Expand All @@ -45,12 +53,24 @@ class QEFFBaseModel(ABC):
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]

def __init__(self, model: torch.nn.Module) -> None:
def create_model_params(self, **kwargs) -> Dict:
model_params = copy.deepcopy(kwargs)

model_params["config"] = self.model.config.to_diff_dict()
model_params["_transform_names"] = self._transform_names()
return model_params

def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.hash_params = self.create_model_params(**kwargs)

if hasattr(self.model.config, "architectures"):
self.model_architecture = self.model.config.architectures[0]
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)

# Apply the transformations
any_transformed = False
Expand All @@ -67,10 +87,6 @@ def __init__(self, model: torch.nn.Module) -> None:
@abstractmethod
def model_name(self) -> str: ...

@property
@abstractmethod
def model_hash(self) -> str: ...

@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
"""
Expand Down Expand Up @@ -134,8 +150,18 @@ def _export(
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
"""
export_dir = Path(export_dir or (QEFF_HOME / self.model_name))
export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash)

export_dir = Path(export_dir or (QEFF_HOME / self.model_architecture / self.model_name))
export_hash, filtered_hash_params = filter_and_create_export_hash(
model_params=self.hash_params,
output_names=output_names,
dynamic_axes=dynamic_axes,
export_kwargs=export_kwargs,
onnx_transform_kwargs=onnx_transform_kwargs,
export_dir=export_dir,
)

export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
onnx_path = export_dir / f"{self.model_name}.onnx"
if onnx_path.is_file():
self.onnx_path = onnx_path
Expand Down Expand Up @@ -210,6 +236,11 @@ def _export(
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)

# Dump JSON file with hashed parameters
hashed_params_export_path = export_dir / "hashed_model_params.json"
create_json(hashed_params_export_path, filtered_hash_params)
logger.info("Hashed parameters exported successfully.")

self.onnx_path = onnx_path
return onnx_path

Expand Down Expand Up @@ -240,12 +271,10 @@ def _compile(
:mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing.
:num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
:qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. ``Defaults to None.``
:compiler_options: Pass any compiler option as input.
Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
:compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
- aic_num_cores=16 -> -aic-num-cores=16
- convert_to_fp16=True -> -convert-to-fp16
For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""
if onnx_path is None and self.onnx_path is None:
self.export()
Expand All @@ -257,11 +286,6 @@ def _compile(
raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")

if enable_qnn:
if compiler_options:
logger.warning(
f"Extra arguments to QNN compilation are supported only via qnn_config file. Ignoring {compiler_options}"
)

self.qpc_path = qnn_compile(
onnx_path=onnx_path,
qpc_base_path=compile_dir,
Expand Down Expand Up @@ -299,24 +323,17 @@ def _compile(
else:
mdp_ts_json = None

compile_hash = hashlib.sha256(to_hashable(command))

if specializations is not None:
compile_hash.update(to_hashable(specializations))

if custom_io is not None:
compile_hash.update(to_hashable(custom_io))

if num_speculative_tokens:
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))

# Hash the MDP partition config and the number of devices.
compile_hash.update(to_hashable(mdp_ts_json))
compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))

# Check if already compiled
compile_hash = compile_hash.hexdigest()[:16]
compile_hash, hashed_params = hash_compile_params(
command=command,
specializations=specializations,
custom_io=custom_io,
mdp_ts_num_devices=mdp_ts_num_devices,
mdp_ts_json=mdp_ts_json,
num_speculative_tokens=num_speculative_tokens,
)
compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)

qpc_path = compile_dir / "qpc"
qpc_path.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -354,6 +371,7 @@ def _compile(
logger.info(f"Running compiler: {' '.join(command)}")
try:
subprocess.run(command, capture_output=True, check=True)

except subprocess.CalledProcessError as e:
raise RuntimeError(
"\n".join(
Expand All @@ -367,6 +385,10 @@ def _compile(
)
)

# Dump JSON file with hashed parameters
hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
create_json(hashed_compile_params_path, hashed_params)
logger.info("Hashed parameters exported successfully.")
self.qpc_path = qpc_path

return qpc_path
Loading
Loading