Skip to content

Commit 03d9871

Browse files
authored
Added Model mapping for swiftKV model to __init__.py (#413)
`__init__.py` was calling modelling_utils leading to creation of `MODEL_CLASS_MAPPING` before init. As a result, registered swiftKV model's config class and QEff Class was not present in the `MODEL_CLASS_MAPPING` and hence giving error of unknown architecture. Updated this `__init__` to be initialized first and then all other classes by removing the modelling_utils call and adding model_class_mapping in __init__. --------- Signed-off-by: Asmita Goswami <[email protected]> Signed-off-by: Asmita Goswami <[email protected]>
1 parent 4ec01a5 commit 03d9871

File tree

4 files changed

+42
-29
lines changed

4 files changed

+42
-29
lines changed

QEfficient/__init__.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,10 @@
1212
# hf_transfer is imported (will happen on line 15 via leading imports)
1313
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
1414

15-
from transformers import AutoConfig
16-
17-
from QEfficient.transformers.modeling_utils import MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS
15+
# Placeholder for all non-transformer models registered in QEfficient
16+
import QEfficient.utils.model_registery # noqa: F401
1817
from QEfficient.utils.logging_utils import logger
1918

20-
# loop over all the model types which are not present in transformers and register them
21-
for model_type, model_cls in MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS.items():
22-
# Register the model config class based on the model type. This will be first element in the tuple
23-
AutoConfig.register(model_type, model_cls[0])
24-
25-
# Register the non transformer library Class and config class using AutoModelClass
26-
model_cls[2].register(model_cls[0], model_cls[1])
27-
2819

2920
def check_qaic_sdk():
3021
"""Check if QAIC SDK is installed"""

QEfficient/base/common.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
4141
Downloads HuggingFace model if already doesn't exist locally, returns QEFFAutoModel object based on type of model.
4242
"""
4343
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
44-
architecture = config.architectures[0] if config.architectures else None
4544

46-
class_name = MODEL_CLASS_MAPPING.get(architecture)
45+
class_name = MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
4746
if class_name:
4847
module = __import__("QEfficient.transformers.models.modeling_auto")
4948
model_class = getattr(module, class_name)
5049
else:
5150
raise NotImplementedError(
52-
f"Unknown architecture={architecture}, either use specific auto model class for loading the model or raise an issue for support!"
51+
f"Unknown architecture={config.__class__.__name__}, either use specific auto model class for loading the model or raise an issue for support!"
5352
)
5453

5554
local_model_dir = kwargs.pop("local_model_dir", None)

QEfficient/transformers/modeling_utils.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import torch
1212
import torch.nn as nn
1313
import transformers.models.auto.modeling_auto as mapping
14-
from transformers import AutoModelForCausalLM
1514
from transformers.models.codegen.modeling_codegen import (
1615
CodeGenAttention,
1716
CodeGenBlock,
@@ -91,11 +90,6 @@
9190
from QEfficient.customop import CustomRMSNormAIC
9291

9392
# Placeholder for all non-transformer models
94-
from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import (
95-
QEffLlamaSwiftKVConfig,
96-
QEffLlamaSwiftKVForCausalLM,
97-
)
98-
9993
from .models.codegen.modeling_codegen import (
10094
QEffCodeGenAttention,
10195
QeffCodeGenBlock,
@@ -279,18 +273,19 @@
279273
WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
280274
}
281275

282-
# Map of model type to config class, Modelling class and transformer model architecture class
283-
MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {
284-
"llama_swiftkv": [QEffLlamaSwiftKVConfig, QEffLlamaSwiftKVForCausalLM, AutoModelForCausalLM],
285-
}
276+
277+
def build_model_class_mapping(auto_model_class, qeff_class_name):
278+
"""
279+
Build a mapping of model config class names to QEfficient model class names.
280+
"""
281+
return {
282+
config_class.__name__: qeff_class_name for config_class, model_class in auto_model_class._model_mapping.items()
283+
}
286284

287285

288286
MODEL_CLASS_MAPPING = {
289-
**{architecture: "QEFFAutoModelForCausalLM" for architecture in mapping.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()},
290-
**{
291-
architecture: "QEFFAutoModelForImageTextToText"
292-
for architecture in mapping.MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
293-
},
287+
**build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM"),
288+
**build_model_class_mapping(mapping.AutoModelForImageTextToText, "QEFFAutoModelForImageTextToText"),
294289
}
295290

296291

QEfficient/utils/model_registery.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------
7+
8+
9+
from transformers import AutoConfig, AutoModelForCausalLM
10+
11+
# Placeholder for all non-transformer models
12+
from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import (
13+
QEffLlamaSwiftKVConfig,
14+
QEffLlamaSwiftKVForCausalLM,
15+
)
16+
17+
# Map of model type to config class, Modelling class and transformer model architecture class
18+
MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {
19+
"llama_swiftkv": [QEffLlamaSwiftKVConfig, QEffLlamaSwiftKVForCausalLM, AutoModelForCausalLM],
20+
}
21+
22+
# loop over all the model types which are not present in transformers and register them
23+
for model_type, model_cls in MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS.items():
24+
# Register the model config class based on the model type. This will be first element in the tuple
25+
AutoConfig.register(model_type, model_cls[0])
26+
27+
# Register the non transformer library Class and config class using AutoModelClass
28+
model_cls[2].register(model_cls[0], model_cls[1])

0 commit comments

Comments
 (0)