import sys
from typing import List, Optional

+import requests
+from PIL import Image
+from transformers import AutoConfig, AutoProcessor, TextStreamer
+from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+
from QEfficient.base.common import QEFFCommonLoader
from QEfficient.utils import check_and_assign_cache_dir, load_hf_tokenizer
from QEfficient.utils.logging_utils import logger
@@ -36,6 +41,7 @@ def main(
    allow_mxint8_mdp_io: bool = False,
    enable_qnn: Optional[bool] = False,
    qnn_config: Optional[str] = None,
+    img_size: Optional[int] = None,
    **kwargs,
) -> None:
    """
@@ -65,18 +71,16 @@ def main(
    :allow_mxint8_mdp_io (bool): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
    :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
    :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+    :kwargs: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
+        -allocator_dealloc_delay=1 -> -allocator-dealloc-delay=1
+        -qpc_crc=True -> -qpc-crc

    .. code-block:: bash

        python -m QEfficient.cloud.infer OPTIONS

    """
    cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
-    tokenizer = load_hf_tokenizer(
-        pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
-        cache_dir=cache_dir,
-        hf_token=hf_token,
-    )

    if "--mxfp6" in sys.argv:
        if args.mxfp6:
@@ -85,6 +89,9 @@ def main(
        if args.mxint8:
            logger.warning("mxint8 is going to be deprecated in a future release, use -mxint8_kv_cache instead.")

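+    # Pop optional image inputs from kwargs so they are not forwarded as compiler flags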
+    image_path = kwargs.pop("image_path", None)
+    image_url = kwargs.pop("image_url", None)
+
    qeff_model = QEFFCommonLoader.from_pretrained(
        pretrained_model_name_or_path=model_name,
        cache_dir=cache_dir,
@@ -110,20 +117,70 @@ def main(
        allow_mxint8_mdp_io=allow_mxint8_mdp_io,
        enable_qnn=enable_qnn,
        qnn_config=qnn_config,
+        img_size=img_size,
        **kwargs,
    )

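+    # Load the HF tokenizer for the text-only generation path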
+    tokenizer = load_hf_tokenizer(
+        pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
+        cache_dir=cache_dir,
+        hf_token=hf_token,
+    )
+
    #########
    # Execute
    #########
-    _ = qeff_model.generate(
-        tokenizer,
-        prompts=prompt,
-        device_id=device_group,
-        prompt=prompt,
-        prompts_txt_file_path=prompts_txt_file_path,
-        generation_len=generation_len,
-    )
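+    # Check the HF config to see whether the architecture is image-text-to-text (VLM)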
+    config = AutoConfig.from_pretrained(model_name)
+    architecture = config.architectures[0] if config.architectures else None
+
+    if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
+        processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
+
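+        # Load the image from a remote URL or a local file path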
+        raw_image = None
+        if image_url is not None:
+            raw_image = Image.open(requests.get(image_url, stream=True).raw)
+        elif image_path is not None:
+            raw_image = Image.open(image_path)
+        else:
+            raise FileNotFoundError(
+                'Neither an image URL nor an image path was provided; pass either "image_url" or "image_path"'
+            )
+
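+        # Build a single-turn chat: one image placeholder plus the first text prompt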
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": prompt[0]},  # Currently only one prompt is accepted
+                ],
+            },
+        ]
+
+        # Render the conversation into the model's chat-prompt string (tokenize=False returns text, not token ids)
+        input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+
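+        # Tokenize the prompt and preprocess the image into model-ready tensors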
+        split_inputs = processor(
+            text=input_text,
+            images=raw_image,
+            return_tensors="pt",
+            add_special_tokens=False,
+        )
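+        # Stream decoded tokens to stdout as generation proceeds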
+        streamer = TextStreamer(processor.tokenizer)
+        _ = qeff_model.generate(
+            inputs=split_inputs,
+            streamer=streamer,
+            device_ids=device_group,
+            generation_len=generation_len,
+        )
+    else:
+        _ = qeff_model.generate(
+            tokenizer,
+            prompts=prompt,
+            device_id=device_group,
+            prompt=prompt,
+            prompts_txt_file_path=prompts_txt_file_path,
+            generation_len=generation_len,
+        )


if __name__ == "__main__":
@@ -226,10 +283,11 @@ def main(
             Sample Config: QEfficient/compile/qnn_config.json",
    )
    parser.add_argument(
-        "qnn_config",
+        "--qnn_config",
        nargs="?",
        type=str,
    )
+    parser.add_argument("--img-size", "--img_size", default=None, type=int, required=False, help="Size of Image")

    args, compiler_options = parser.parse_known_args()
    compiler_options_dict = {}
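
A minimal usage sketch of the new image path (checkpoint name and values are hypothetical; image_url/image_path reach main() through **kwargs):

    main(
        model_name="llava-hf/llava-1.5-7b-hf",   # hypothetical VLM checkpoint
        prompt=["Describe this image"],          # prompt[0] is used, so a list is assumed
        img_size=336,                            # hypothetical image size
        image_url="https://example.com/demo.jpg",  # or image_path="demo.jpg"
    )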