Skip to content

Commit ad06845

Browse files
committed
Addressing comments
Signed-off-by: Asmita Goswami <[email protected]>
1 parent 1b3043d commit ad06845

File tree

3 files changed

+21
-28
lines changed

3 files changed

+21
-28
lines changed

QEfficient/base/common.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,19 @@
1313
"""
1414

1515
import importlib
16-
from collections import OrderedDict
1716
from typing import Any
1817

1918
import transformers.models.auto.modeling_auto as mapping
2019
from transformers import AutoConfig
2120

2221
from QEfficient.base.modeling_qeff import QEFFBaseModel
2322

24-
MODEL_CLASS_MAPPING = OrderedDict(
25-
[
26-
(tuple(mapping.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()), "QEFFAutoModelForCausalLM"),
27-
(tuple(mapping.MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()), "QEFFAutoModelForImageTextToText"),
28-
]
29-
)
23+
# Map each HF model architecture name (e.g. "LlamaForCausalLM") to the
# QEfficient auto-class (declared in QEfficient.transformers.models.modeling_auto)
# that knows how to load it. Built once at import time from the transformers
# auto-model name registries; causal-LM entries first, then image-text-to-text
# entries layered on top (later update wins on any duplicate key).
MODEL_CLASS_MAPPING = {
    architecture: "QEFFAutoModelForCausalLM"
    for architecture in mapping.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
}
MODEL_CLASS_MAPPING.update(
    (architecture, "QEFFAutoModelForImageTextToText")
    for architecture in mapping.MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
)
3029

3130

3231
class QEFFCommonLoader:
@@ -50,13 +49,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
5049
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
5150
architecture = config.architectures[0] if config.architectures else None
5251

53-
model_class = None
54-
for key_tuple, class_name in MODEL_CLASS_MAPPING.items():
55-
if architecture in key_tuple:
56-
module = importlib.import_module("QEfficient.transformers.models.modeling_auto")
57-
model_class = getattr(module, class_name)
58-
break
59-
if model_class is None:
52+
class_name = MODEL_CLASS_MAPPING.get(architecture)
53+
if class_name:
54+
module = importlib.import_module("QEfficient.transformers.models.modeling_auto")
55+
model_class = getattr(module, class_name)
56+
else:
6057
raise NotImplementedError(
6158
f"Unknown architecture={architecture}, either use specific auto model class for loading the model or raise an issue for support!"
6259
)

QEfficient/cloud/infer.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import requests
1414
from PIL import Image
15-
from transformers import AutoConfig, AutoProcessor, TextStreamer
15+
from transformers import AutoProcessor, TextStreamer
1616
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
1717

1818
from QEfficient.base.common import QEFFCommonLoader
@@ -121,16 +121,10 @@ def main(
121121
**kwargs,
122122
)
123123

124-
tokenizer = load_hf_tokenizer(
125-
pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
126-
cache_dir=cache_dir,
127-
hf_token=hf_token,
128-
)
129-
130124
#########
131125
# Execute
132126
#########
133-
config = AutoConfig.from_pretrained(model_name)
127+
config = qeff_model.model.config
134128
architecture = config.architectures[0] if config.architectures else None
135129

136130
if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
@@ -166,13 +160,19 @@ def main(
166160
add_special_tokens=False,
167161
)
168162
streamer = TextStreamer(processor.tokenizer)
169-
_ = qeff_model.generate(
163+
output = qeff_model.generate(
170164
inputs=split_inputs,
171165
streamer=streamer,
172166
device_ids=device_group,
173167
generation_len=generation_len,
174168
)
169+
print(output)
175170
else:
171+
tokenizer = load_hf_tokenizer(
172+
pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
173+
cache_dir=cache_dir,
174+
hf_token=hf_token,
175+
)
176176
_ = qeff_model.generate(
177177
tokenizer,
178178
prompts=prompt,

QEfficient/transformers/models/modeling_auto.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -823,8 +823,6 @@ def kv_offload_generate(
823823
prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time
824824
),
825825
)
826-
827-
print(exec_info)
828826
return exec_info
829827

830828

@@ -1116,8 +1114,6 @@ def cloud_ai_100_generate(
11161114
prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time
11171115
),
11181116
)
1119-
1120-
print(exec_info)
11211117
return exec_info
11221118

11231119
@property
@@ -1492,7 +1488,7 @@ def compile(
14921488
specializations.append(decode_specialization)
14931489

14941490
if compiler_options.pop("img_size", None):
1495-
logger.warning("img_size is not a valid argument for Text-to-Text Model.")
1491+
logger.warning(f"Skipping img_size as it is not a valid argument for {self.model.config.architectures[0]}.")
14961492

14971493
if enable_qnn:
14981494
if compiler_options:

0 commit comments

Comments
 (0)