
Commit a55e33b

asmigosw, shubhagr-quic, and abukhoy authored
Enabled Infer CLI for VLM (#287)
Added support for enabling VLMs via the CLI. Sample command:

```bash
python -m QEfficient.cloud.infer --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --batch_size 1 --prompt_len 32 --ctx_len 512 --num_cores 16 --device_group [0] --prompt "Describe the image?" --mos 1 --allocator_dealloc_delay 1 --image_url https://i.etsystatic.com/8155076/r/il/0825c2/1594869823/il_fullxfull.1594869823_5x0w.jpg
```

Signed-off-by: Shubham Agrawal <[email protected]>
Signed-off-by: Asmita Goswami <[email protected]>
Signed-off-by: Abukhoyer Shaik <[email protected]>
Co-authored-by: shubhagr-quic <[email protected]>
Co-authored-by: Abukhoyer Shaik <[email protected]>
1 parent 54a9b6f commit a55e33b
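
For reference, a rough Python-level equivalent of what the new CLI path does for a vision-language model, assembled from the diff below. This is a hedged sketch, not a documented API example: the compile step is elided, and the argument values simply mirror the sample command above.

```python
# Hedged sketch of the VLM flow the CLI now follows (load -> compile -> execute);
# the compile call is elided because its options are not part of this commit excerpt.
from QEfficient.base.common import QEFFCommonLoader
from QEfficient.cloud.infer import execute_vlm_model

model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
qeff_model = QEFFCommonLoader.from_pretrained(model_name)  # dispatches on config.architectures[0]

# ... qeff_model is compiled for Cloud AI 100 here in the real CLI ...

execute_vlm_model(
    qeff_model=qeff_model,
    model_name=model_name,
    image_url="https://i.etsystatic.com/8155076/r/il/0825c2/1594869823/il_fullxfull.1594869823_5x0w.jpg",
    image_path=None,
    prompt=["Describe the image?"],  # execute_vlm_model reads prompt[0]
    device_group=[0],
    generation_len=128,  # illustrative value; the parameter is optional
)
```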

File tree: 8 files changed, +189 −27 lines

QEfficient/base/common.py (+5 −4)

```diff
@@ -16,10 +16,9 @@
 from typing import Any
 
 from transformers import AutoConfig
-from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING
 from QEfficient.utils import login_and_download_hf_lm
 
 
@@ -44,8 +43,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
         architecture = config.architectures[0] if config.architectures else None
 
-        if architecture in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
-            model_class = QEFFAutoModelForCausalLM
+        class_name = MODEL_CLASS_MAPPING.get(architecture)
+        if class_name:
+            module = __import__("QEfficient.transformers.models.modeling_auto")
+            model_class = getattr(module, class_name)
         else:
             raise NotImplementedError(
                 f"Unknown architecture={architecture}, either use specific auto model class for loading the model or raise an issue for support!"
```

QEfficient/cloud/infer.py (+127 −21)

```diff
@@ -10,11 +10,86 @@
 import sys
 from typing import List, Optional
 
+import requests
+from PIL import Image
+from transformers import PreTrainedModel, TextStreamer
+from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+
 from QEfficient.base.common import QEFFCommonLoader
-from QEfficient.utils import check_and_assign_cache_dir, load_hf_tokenizer
+from QEfficient.utils import check_and_assign_cache_dir, load_hf_processor, load_hf_tokenizer
 from QEfficient.utils.logging_utils import logger
 
 
+# TODO: Remove after adding support for VLM's compile and execute
+def execute_vlm_model(
+    qeff_model: PreTrainedModel,
+    model_name: str,
+    image_url: str,
+    image_path: str,
+    prompt: Optional[str] = None,  # type: ignore
+    device_group: Optional[List[int]] = None,
+    local_model_dir: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    hf_token: Optional[str] = None,
+    generation_len: Optional[int] = None,
+):
+    """
+    This method generates output by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
+    ``Mandatory`` Args:
+        :qeff_model (PreTrainedModel): QEfficient model object.
+        :model_name (str): Hugging Face Model Card name, Example: ``llava-hf/llava-1.5-7b-hf``
+        :image_url (str): Image URL to be used for inference. ``Defaults to None.``
+        :image_path (str): Image path to be used for inference. ``Defaults to None.``
+    ``Optional`` Args:
+        :prompt (str): Sample prompt for the model text generation. ``Defaults to None.``
+        :device_group (List[int]): Device Ids to be used for compilation. If ``len(device_group) > 1``, multiple Card setup is enabled. ``Defaults to None.``
+        :local_model_dir (str): Path to custom model weights and config files. ``Defaults to None.``
+        :cache_dir (str): Cache dir where downloaded HuggingFace files are stored. ``Defaults to None.``
+        :hf_token (str): HuggingFace login token to access private repos. ``Defaults to None.``
+        :generation_len (int): Number of tokens to be generated. ``Defaults to None.``
+    Returns:
+        :dict: Output from the ``AI_100`` runtime.
+    """
+    if not (image_url or image_path):
+        raise ValueError('Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"')
+    raw_image = Image.open(requests.get(image_url, stream=True).raw) if image_url else Image.open(image_path)
+
+    processor = load_hf_processor(
+        pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
+        cache_dir=cache_dir,
+        hf_token=hf_token,
+    )
+
+    # Added for QEff version 1.20 supported VLM models (mllama and llava)
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": prompt[0]},
+            ],
+        }
+    ]
+
+    # Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids.
+    input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+
+    split_inputs = processor(
+        text=input_text,
+        images=raw_image,
+        return_tensors="pt",
+        add_special_tokens=False,
+    )
+    streamer = TextStreamer(processor.tokenizer)
+    output = qeff_model.generate(
+        inputs=split_inputs,
+        streamer=streamer,
+        device_ids=device_group,
+        generation_len=generation_len,
+    )
+    return output
+
+
 def main(
     model_name: str,
     num_cores: int,
@@ -65,18 +140,16 @@ def main(
     :allow_mxint8_mdp_io (bool): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
     :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
     :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+    :kwargs: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
+        -allocator_dealloc_delay=1 -> -allocator-dealloc-delay=1
+        -qpc_crc=True -> -qpc-crc
 
     .. code-block:: bash
 
         python -m QEfficient.cloud.infer OPTIONS
 
     """
     cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
-    tokenizer = load_hf_tokenizer(
-        pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
-        cache_dir=cache_dir,
-        hf_token=hf_token,
-    )
 
     if "--mxfp6" in sys.argv:
         if args.mxfp6:
@@ -93,6 +166,17 @@ def main(
         local_model_dir=local_model_dir,
     )
 
+    image_path = kwargs.pop("image_path", None)
+    image_url = kwargs.pop("image_url", None)
+
+    config = qeff_model.model.config
+    architecture = config.architectures[0] if config.architectures else None
+
+    if architecture not in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() and (
+        kwargs.pop("img_size", None) or image_path or image_url
+    ):
+        logger.warning(f"Skipping image arguments as they are not valid for {architecture}")
+
     #########
     # Compile
     #########
@@ -116,14 +200,34 @@ def main(
     #########
     # Execute
     #########
-    _ = qeff_model.generate(
-        tokenizer,
-        prompts=prompt,
-        device_id=device_group,
-        prompt=prompt,
-        prompts_txt_file_path=prompts_txt_file_path,
-        generation_len=generation_len,
-    )
+    if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
+        exec_info = execute_vlm_model(
+            qeff_model=qeff_model,
+            model_name=model_name,
+            prompt=prompt,
+            image_url=image_url,
+            image_path=image_path,
+            device_group=device_group,
+            local_model_dir=local_model_dir,
+            cache_dir=cache_dir,
+            hf_token=hf_token,
+            generation_len=generation_len,
+        )
+        print(exec_info)
+    else:
+        tokenizer = load_hf_tokenizer(
+            pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
+            cache_dir=cache_dir,
+            hf_token=hf_token,
+        )
+        _ = qeff_model.generate(
+            tokenizer,
+            prompts=prompt,
+            device_id=device_group,
+            prompt=prompt,
+            prompts_txt_file_path=prompts_txt_file_path,
+            generation_len=generation_len,
+        )
 
 
 if __name__ == "__main__":
@@ -219,23 +323,25 @@ def main(
     parser.add_argument(
         "--enable_qnn",
        "--enable-qnn",
-        action="store_true",
+        nargs="?",
+        const=True,
+        type=str,
        default=False,
        help="Enables QNN. Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\
            If not provided, the default configuration will be used.\
            Sample Config: QEfficient/compile/qnn_config.json",
    )
-    parser.add_argument(
-        "qnn_config",
-        nargs="?",
-        type=str,
-    )
 
     args, compiler_options = parser.parse_known_args()
+
+    if isinstance(args.enable_qnn, str):
+        args.qnn_config = args.enable_qnn
+        args.enable_qnn = True
+
     compiler_options_dict = {}
     for i in range(0, len(compiler_options)):
         if compiler_options[i].startswith("--"):
-            key = compiler_options[i].lstrip("-")
+            key = compiler_options[i].lstrip("-").replace("-", "_")
             value = (
                 compiler_options[i + 1]
                 if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-")
```
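
Two parser changes above are easy to misread: `--enable_qnn` now accepts an optional value (`nargs="?"` with `const=True`), so a bare flag enables QNN with defaults while `--enable_qnn CONFIG_FILE` also sets the config path; and leftover `--key value` pairs are normalized into a kwargs dict whose underscored keys later map to dashed `qaic-exec` flags. A standalone sketch of that parsing behavior, not the exact `infer.py` code:

```python
# Standalone sketch of the new argument handling; mirrors the diff above but is
# not the exact infer.py implementation.
import argparse

parser = argparse.ArgumentParser()
# "--enable_qnn" alone -> True; "--enable_qnn path/to/qnn_config.json" -> that path.
parser.add_argument("--enable_qnn", "--enable-qnn", nargs="?", const=True, type=str, default=False)

args, compiler_options = parser.parse_known_args(
    ["--enable_qnn", "my_qnn_config.json", "--allocator_dealloc_delay", "1", "--qpc_crc", "True"]
)

qnn_config = None
if isinstance(args.enable_qnn, str):
    qnn_config = args.enable_qnn  # a config path was supplied with the flag
    args.enable_qnn = True

# Leftover options become kwargs; dashes are normalized to underscores so that
# e.g. allocator_dealloc_delay later maps to the -allocator-dealloc-delay flag.
compiler_options_dict = {}
for i, token in enumerate(compiler_options):
    if token.startswith("--"):
        key = token.lstrip("-").replace("-", "_")
        value = (
            compiler_options[i + 1]
            if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-")
            else True  # value-less flags treated as booleans here (assumption; the else branch is not shown in the hunk)
        )
        compiler_options_dict[key] = value

print(args.enable_qnn, qnn_config, compiler_options_dict)
# Expected: True my_qnn_config.json {'allocator_dealloc_delay': '1', 'qpc_crc': 'True'}
```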

QEfficient/transformers/modeling_utils.py (+10)

```diff
@@ -10,6 +10,7 @@
 
 import torch
 import torch.nn as nn
+import transformers.models.auto.modeling_auto as mapping
 from transformers import AutoModelForCausalLM
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -284,6 +285,15 @@
 }
 
 
+MODEL_CLASS_MAPPING = {
+    **{architecture: "QEFFAutoModelForCausalLM" for architecture in mapping.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()},
+    **{
+        architecture: "QEFFAutoModelForImageTextToText"
+        for architecture in mapping.MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
+    },
+}
+
+
 def _prepare_cross_attention_mask(
     cross_attention_mask: torch.Tensor,
     num_vision_tokens: int,
```
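
The new mapping is keyed by Hugging Face architecture names and valued with QEff auto-class names as strings. A quick illustration of the lookups it enables; the exact keys come from the installed `transformers` mappings, so treat the architecture strings below as examples rather than a guaranteed list.

```python
# Illustrative lookups against MODEL_CLASS_MAPPING; exact keys depend on the
# installed transformers version, which supplies the underlying mappings.
from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING

print(MODEL_CLASS_MAPPING.get("LlamaForCausalLM"))
# -> "QEFFAutoModelForCausalLM" (causal-LM architectures)

print(MODEL_CLASS_MAPPING.get("MllamaForConditionalGeneration"))
# -> "QEFFAutoModelForImageTextToText" (image-text-to-text architectures such as mllama/llava)

print(MODEL_CLASS_MAPPING.get("NotARealArchitecture"))
# -> None, which QEFFCommonLoader turns into a NotImplementedError
```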

QEfficient/transformers/models/modeling_auto.py (+3)

```diff
@@ -1257,6 +1257,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")
 
+        if kwargs.pop("continuous_batching", None):
+            NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(model, kv_offload=kv_offload, **kwargs)
```
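
Assuming this hunk belongs to `QEFFAutoModelForImageTextToText.from_pretrained` (the class that `MODEL_CLASS_MAPPING` resolves for image-text-to-text architectures), a hedged usage sketch of the guarded path; the model card is the one used by the new CLI test.

```python
# Usage sketch only; not part of the commit. The import path is the module that
# MODEL_CLASS_MAPPING resolves class names against.
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText

# Normal load; attn_implementation="eager" and low_cpu_mem_usage=False are forced by the code above.
model = QEFFAutoModelForImageTextToText.from_pretrained("llava-hf/llava-1.5-7b-hf", kv_offload=True)

# continuous_batching=True reaches the new guard. Note that the NotImplementedError
# above is constructed but not raised, so as written the option is silently dropped.
model = QEFFAutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", continuous_batching=True
)
```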

QEfficient/utils/__init__.py (+1)

```diff
@@ -17,6 +17,7 @@
     get_padding_shape_from_config,
     get_qpc_dir_path,
     hf_download,
+    load_hf_processor,
     load_hf_tokenizer,
     login_and_download_hf_lm,
     onnx_exists,
```

QEfficient/utils/_utils.py (+1 −1)

```diff
@@ -510,7 +510,7 @@ def create_and_dump_qconfigs(
     # Extract QNN SDK details from YAML file if the environment variable is set
     qnn_sdk_details = None
     qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME)
-    if qnn_sdk_path:
+    if enable_qnn and qnn_sdk_path:
         qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML)
         with open(qnn_sdk_yaml_path, "r") as file:
             qnn_sdk_details = yaml.safe_load(file)
```

scripts/Jenkinsfile (+1 −1)

```diff
@@ -69,7 +69,7 @@ pipeline {
         }
         stage('CLI Tests') {
             steps {
-                timeout(time: 15, unit: 'MINUTES') {
+                timeout(time: 60, unit: 'MINUTES') {
                     sh '''
                     sudo docker exec ${BUILD_TAG} bash -c "
                     source /qnn_sdk/bin/envsetup.sh &&
```

tests/cloud/test_infer_vlm.py (+41, new file)

```diff
@@ -0,0 +1,41 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import pytest
+
+from QEfficient.cloud.infer import main as infer
+
+
+@pytest.mark.on_qaic
+@pytest.mark.cli
+@pytest.mark.multimodal
+@pytest.mark.usefixtures("clean_up_after_test")
+def test_vlm_cli(setup, mocker):
+    ms = setup
+    # Taking some values from the setup fixture and assigning others based on the model's requirements.
+    # For example, mxint8 is not required for VLM models, so assigning False.
+    infer(
+        model_name="llava-hf/llava-1.5-7b-hf",
+        num_cores=ms.num_cores,
+        prompt="Describe the image.",
+        prompts_txt_file_path=None,
+        aic_enable_depth_first=ms.aic_enable_depth_first,
+        mos=ms.mos,
+        batch_size=1,
+        full_batch_size=None,
+        prompt_len=1024,
+        ctx_len=2048,
+        generation_len=20,
+        mxfp6=False,
+        mxint8=False,
+        local_model_dir=None,
+        cache_dir=None,
+        hf_token=ms.hf_token,
+        enable_qnn=False,
+        qnn_config=None,
+        image_url="https://i.etsystatic.com/8155076/r/il/0825c2/1594869823/il_fullxfull.1594869823_5x0w.jpg",
+    )
```
