
Commit 8d99a93
Addressing comments

Signed-off-by: Asmita Goswami <[email protected]>
1 parent: 463dea5

File tree: 6 files changed (+47, -48 lines)

QEfficient/base/common.py
+2 -9

@@ -14,18 +14,11 @@
 
 from typing import Any
 
-import transformers.models.auto.modeling_auto as mapping
+from QEfficient.transformers.modeling_utils import model_class_mapping
 from transformers import AutoConfig
 
 from QEfficient.base.modeling_qeff import QEFFBaseModel
 
-MODEL_CLASS_MAPPING = {}
-for architecture in mapping.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
-    MODEL_CLASS_MAPPING[architecture] = "QEFFAutoModelForCausalLM"
-
-for architecture in mapping.MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
-    MODEL_CLASS_MAPPING[architecture] = "QEFFAutoModelForImageTextToText"
-
 
 class QEFFCommonLoader:
     """
@@ -48,7 +41,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
         architecture = config.architectures[0] if config.architectures else None
 
-        class_name = MODEL_CLASS_MAPPING.get(architecture)
+        class_name = model_class_mapping.get(architecture)
         if class_name:
             module = __import__("QEfficient.transformers.models.modeling_auto")
             model_class = getattr(module, class_name)
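For context, QEFFCommonLoader.from_pretrained resolves the wrapper class purely by name: it looks the detected architecture up in the shared mapping and then imports the class dynamically. A minimal, self-contained sketch of that lookup-and-import pattern, using importlib and a stand-in mapping instead of the real QEfficient modules:

import importlib

# Stand-in for model_class_mapping; the real dict is built in
# QEfficient/transformers/modeling_utils.py from the transformers auto-mappings.
model_class_mapping = {"LlamaForCausalLM": "QEFFAutoModelForCausalLM"}

def resolve_wrapper(architecture: str):
    class_name = model_class_mapping.get(architecture)
    if class_name is None:
        raise NotImplementedError(f"No QEFF wrapper registered for {architecture}")  # error handling here is illustrative
    module = importlib.import_module("QEfficient.transformers.models.modeling_auto")
    return getattr(module, class_name)

Note that the diff's __import__("QEfficient.transformers.models.modeling_auto") returns the top-level QEfficient package, so its getattr lookup relies on the class being re-exported there; importlib.import_module returns the leaf module instead. Either works for this sketch.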

QEfficient/cloud/infer.py
+24 -33

@@ -16,7 +16,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
 
 from QEfficient.base.common import QEFFCommonLoader
-from QEfficient.utils import check_and_assign_cache_dir, load_hf_tokenizer
+from QEfficient.utils import check_and_assign_cache_dir, load_hf_tokenizer, constants
 from QEfficient.utils.logging_utils import logger
 
 
@@ -41,7 +41,6 @@ def main(
     allow_mxint8_mdp_io: bool = False,
     enable_qnn: Optional[bool] = False,
     qnn_config: Optional[str] = None,
-    img_size: Optional[int] = None,
     **kwargs,
 ) -> None:
     """
@@ -89,9 +88,6 @@ def main(
     if args.mxint8:
         logger.warning("mxint8 is going to be deprecated in a future release, use -mxint8_kv_cache instead.")
 
-    image_path = kwargs.pop("image_path", None)
-    image_url = kwargs.pop("image_url", None)
-
     qeff_model = QEFFCommonLoader.from_pretrained(
         pretrained_model_name_or_path=model_name,
         cache_dir=cache_dir,
@@ -100,6 +96,16 @@
         local_model_dir=local_model_dir,
     )
 
+    image_path = kwargs.pop("image_path", None)
+    image_url = kwargs.pop("image_url", None)
+
+    config = qeff_model.model.config
+    architecture = config.architectures[0] if config.architectures else None
+    if architecture not in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
+        img_size = kwargs.pop("img_size", None)
+        if img_size or image_path or image_url:
+            logger.warning(f"Skipping image arguments as they are not valid for {architecture}")
+
     #########
     # Compile
     #########
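The block added above decides up front whether the loaded model is an image-text-to-text model; for any other architecture the image-related arguments (img_size, image_path, image_url) are popped and dropped with a warning instead of being forwarded to compile(). A small illustrative check of the gate itself (the model name is only an example of a text-only checkpoint; the mapping import is the same one infer.py already uses):

from transformers import AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

config = AutoConfig.from_pretrained("gpt2")  # any text-only model
architecture = config.architectures[0] if config.architectures else None
print(architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values())  # False, so image args would be skipped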
@@ -117,38 +123,21 @@ def main(
         allow_mxint8_mdp_io=allow_mxint8_mdp_io,
         enable_qnn=enable_qnn,
         qnn_config=qnn_config,
-        img_size=img_size,
         **kwargs,
     )
 
     #########
     # Execute
     #########
-    config = qeff_model.model.config
-    architecture = config.architectures[0] if config.architectures else None
-
     if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
         processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
 
-        raw_image = None
-        if image_url is not None:
-            raw_image = Image.open(requests.get(image_url, stream=True).raw)
-        elif image_path is not None:
-            raw_image = Image.open(image_path)
-        else:
-            raise FileNotFoundError(
-                'Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"'
-            )
+        if not (image_url or image_path):
+            raise ValueError('Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"')
+        raw_image = Image.open(requests.get(image_url, stream=True).raw) if image_url else Image.open(image_path)
 
-        conversation = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image"},
-                    {"type": "text", "text": prompt[0]},  # Currently accepting only 1 prompt
-                ],
-            },
-        ]
+        conversation = constants.Constants.conversation
+        conversation[0]["content"][1].update({"text": prompt[0]})  # Currently accepting only 1 prompt
 
         # Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids.
         input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
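With the chat template moved into constants.py, filling in the user prompt becomes a one-line update on the shared structure. A short sketch of what that produces (the list literal mirrors the new Constants.conversation; the prompt string is only an example):

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text"},
        ],
    }
]
conversation[0]["content"][1].update({"text": "Describe this image."})
# conversation[0]["content"][1] is now {"type": "text", "text": "Describe this image."}

Since Constants.conversation is a class-level list, the update() in infer.py mutates that shared object in place rather than a per-call copy.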
@@ -277,19 +266,21 @@
         "--enable_qnn",
         "--enable-qnn",
         action="store_true",
+        nargs="?",
+        const=True,
+        type=str,
         default=False,
         help="Enables QNN. Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\
             If not provided, the default configuration will be used.\
             Sample Config: QEfficient/compile/qnn_config.json",
     )
-    parser.add_argument(
-        "--qnn_config",
-        nargs="?",
-        type=str,
-    )
-    parser.add_argument("--img-size", "--img_size", default=None, type=int, required=False, help="Size of Image")
 
     args, compiler_options = parser.parse_known_args()
+
+    if isinstance(args.enable_qnn, str):
+        args.qnn_config = args.enable_qnn
+        args.enable_qnn = True
+
     compiler_options_dict = {}
     for i in range(0, len(compiler_options)):
         if compiler_options[i].startswith("--"):
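With nargs="?" and const=True, a single --enable_qnn flag can now optionally carry a config path, and the isinstance check after parse_known_args splits it back into the boolean plus qnn_config. A self-contained sketch of just that parsing pattern (it omits the other infer.py arguments, and leaves out action="store_true" because argparse's store_true action does not accept nargs/const/type):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable_qnn", "--enable-qnn", nargs="?", const=True, type=str, default=False)

args = parser.parse_args(["--enable_qnn", "QEfficient/compile/qnn_config.json"])
qnn_config = None
if isinstance(args.enable_qnn, str):
    qnn_config = args.enable_qnn  # a config path was supplied on the command line
    args.enable_qnn = True

# "--enable_qnn"                  -> enable_qnn is True (const), qnn_config stays None
# "--enable_qnn path/to/cfg.json" -> enable_qnn is True, qnn_config is "path/to/cfg.json"
# flag omitted                    -> enable_qnn is False (default)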

QEfficient/transformers/modeling_utils.py
+11

@@ -8,6 +8,8 @@
 from collections import namedtuple
 from typing import Dict, Optional, Tuple, Type
 
+import transformers.models.auto.modeling_auto as mapping
+
 import torch
 import torch.nn as nn
 from transformers.models.codegen.modeling_codegen import (
@@ -272,6 +274,15 @@
 }
 
 
+model_class_mapping = {
+    **{architecture: "QEFFAutoModelForCausalLM" for architecture in mapping.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()},
+    **{
+        architecture: "QEFFAutoModelForImageTextToText"
+        for architecture in mapping.MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
+    },
+}
+
+
 def _prepare_cross_attention_mask(
     cross_attention_mask: torch.Tensor,
     num_vision_tokens: int,
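The mapping is built once at import time by merging two dict comprehensions with ** unpacking: every causal-LM architecture maps to QEFFAutoModelForCausalLM and every image-text-to-text architecture to QEFFAutoModelForImageTextToText. A toy illustration of the same construction (the architecture names below are examples, not the full transformers mappings):

# Stand-ins for the .values() of the two transformers auto-mappings
causal_lm_archs = ["LlamaForCausalLM", "GPT2LMHeadModel"]
image_text_archs = ["MllamaForConditionalGeneration"]

model_class_mapping = {
    **{arch: "QEFFAutoModelForCausalLM" for arch in causal_lm_archs},
    **{arch: "QEFFAutoModelForImageTextToText" for arch in image_text_archs},
}

print(model_class_mapping["MllamaForConditionalGeneration"])  # QEFFAutoModelForImageTextToText
print(model_class_mapping.get("SomeUnknownArch"))             # None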

QEfficient/transformers/models/modeling_auto.py
-5

@@ -615,8 +615,6 @@ def compile(
         )
 
         output_names = self.model.get_output_names(kv_offload=True)
-        vision_onnx_path = compiler_options.get("vision_onnx_path", None)
-        lang_onnx_path = compiler_options.get("lang_onnx_path", None)
 
         specializations, compiler_options = self.model.get_specializations(
             batch_size=batch_size,
@@ -1567,9 +1565,6 @@ def compile(
         decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
         specializations.append(decode_specialization)
 
-        if compiler_options.pop("img_size", None):
-            logger.warning(f"Skipping img_size as it is not a valid argument for {self.model.config.architectures[0]}.")
-
         if enable_qnn:
             if compiler_options:
                 logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")

QEfficient/utils/_utils.py
+1 -1

@@ -504,7 +504,7 @@ def create_and_dump_qconfigs(
     # Extract QNN SDK details from YAML file if the environment variable is set
     qnn_sdk_details = None
    qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME)
-    if qnn_sdk_path:
+    if enable_qnn and qnn_sdk_path:
         qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML)
         with open(qnn_sdk_yaml_path, "r") as file:
             qnn_sdk_details = yaml.safe_load(file)
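The added enable_qnn condition means the QNN SDK YAML is only opened when a QNN compile was actually requested. A minimal sketch of the guarded lookup; "QNN_SDK_ROOT" and "sdk.yaml" are placeholders for QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME and QnnConstants.QNN_SDK_YAML, whose actual values are not shown in this diff:

import os

import yaml

def read_qnn_sdk_details(enable_qnn: bool):
    qnn_sdk_details = None
    qnn_sdk_path = os.getenv("QNN_SDK_ROOT")  # placeholder env var name
    if enable_qnn and qnn_sdk_path:  # skip the file read entirely for non-QNN runs
        with open(os.path.join(qnn_sdk_path, "sdk.yaml"), "r") as file:  # placeholder file name
            qnn_sdk_details = yaml.safe_load(file)
    return qnn_sdk_details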

QEfficient/utils/constants.py
+9

@@ -76,6 +76,15 @@ class Constants:
     MAX_RETRIES = 5  # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
     NUM_SPECULATIVE_TOKENS = 2
     SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml"  # This xml file is parsed to find out the SDK version.
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text"},
+            ],
+        }
+    ]
 
 
 @dataclass
