 import time
 import numpy as np
 import array
-# import torch
-# from torch.nn.functional import pad
-# from vllm import LLM, AsyncLLMEngine, AsyncEngineArgs, SamplingParams
-# from vllm.inputs import TokensPrompt
-# from transformers import AutoProcessor
-# import pickle
 import time
 import threading
-# import tqdm
 import queue

 import logging
-# from typing import TYPE_CHECKING, Optional, List
-# from pathlib import Path

 import mlperf_loadgen as lg
 from dataset import Dataset
-import sys
-import subprocess
-import requests
-from contextlib import suppress
-import signal
-from openai import OpenAI, AsyncOpenAI
+from openai import AsyncOpenAI
 from typing import Any, Dict


 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger("Qwen2.5-VL-7B")

 # ---------- Config ----------
-MODEL = os.environ.get("VLLM_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")
 HOST = os.environ.get("VLLM_HOST", "vllm")
 PORT = int(os.environ.get("VLLM_PORT", "8000"))
-
-# Extra vLLM server args if you need them (GPU/CPU flags, trust-remote-code, tensor-parallel-size, etc.)
-EXTRA_ARGS = os.environ.get("VLLM_EXTRA_ARGS", "--trust-remote-code").split()
-
 BASE_URL = f"http://{HOST}:{PORT}/v1"
-HEALTH_URLS = [
-    f"http://{HOST}:{PORT}/health",  # preferred if available
-    f"http://{HOST}:{PORT}/v1/models",  # fallback readiness check
-]
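+# The harness now assumes an already-running vLLM OpenAI-compatible server at
+# BASE_URL; the default host "vllm" is presumably a compose/cluster service name.
+# Example (assumed invocation): start it separately with
+#   vllm serve Qwen/Qwen2.5-VL-7B-Instruct --port 8000
+# and point VLLM_HOST/VLLM_PORT here accordingly.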

 class SUT:
     def __init__(
@@ -59,9 +36,8 @@ def __init__(
         # session was killed partway through
         workers=1,
         tensor_parallel_size=8,
-        _load_model=False
+        scenario="offline"
     ):
-        self.proc = None
         self.model_path = model_path or f"Qwen/Qwen2.5-VL-7B-Instruct"

         if not batch_size:
@@ -83,27 +59,27 @@ def __init__(
             self.data_object.UnloadSamplesFromRam,
         )

-        if _load_model: self.load_model()
-        gen_kwargs = {
+        self.num_workers = workers
+        self.params = {
             "temperature": 0.0,
-            "top_p": 1,
-            "top_k": 1,
-            "seed": 42,
             "max_tokens": 1024,
         }
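+        # Greedy decoding capped at 1024 output tokens; these settings feed both the
+        # offline SamplingParams below and the server-scenario chat.completions call.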
-        self.max_tokens = 1024
-        self.temperature = 0.0
-        # self.sampling_params = SamplingParams(**gen_kwargs)
-        self.sampling_params = gen_kwargs
-        # self.sampling_params.all_stop_token_ids.add(self.model.get_tokenizer().eos_token_id)
+
+        if scenario == "offline":
+            from vllm import SamplingParams
+            from transformers import AutoProcessor

-        self.num_workers = workers
-        self.worker_threads = [None] * self.num_workers
-        self.query_queue = queue.Queue()
+            self.load_model()
+            self.sampling_params = SamplingParams(**self.params)
+            self.processor = AutoProcessor.from_pretrained(self.model_path)
+            self.request_id_counter = 0
+
+            self.worker_threads = [None] * self.num_workers
+            self.query_queue = queue.Queue()

-        self.use_cached_outputs = use_cached_outputs
-        self.sample_counter = 0
-        self.sample_counter_lock = threading.Lock()
+            self.use_cached_outputs = use_cached_outputs
+            self.sample_counter = 0
+            self.sample_counter_lock = threading.Lock()
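+            # Offline scenario keeps everything in-process: worker threads drain
+            # query_queue and call the local vLLM engine. The server scenario skips
+            # this block and lets SUTServer stream requests over HTTP instead.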

     def start(self):
         # Create worker threads
@@ -122,24 +98,37 @@ def stop(self):
     def process_queries(self):
         """Processor of the queued queries. User may choose to add batching logic"""
         while True:
-            qitem = self.query_queue.get()
-            if qitem is None:
+            qitems = self.query_queue.get()
+            if qitems is None:
                 break

-            query_ids = [q.index for q in qitem]
+            query_ids = [q.index for q in qitems]

             tik1 = time.time()

-            input_ids_tensor = [
-                self.data_object.input_ids[q.index] for q in qitem]
-            # input_text_tensor = [
-            #     self.data_object.input[q.index] for q in qitem]
-            # for in_text in input_text_tensor:
-            #     log.info(f"Input: {in_text}")
-
+            prompts = []
+            for item in qitems:
+                question = self.data_object.prompts[item.index]
+
+                placeholders = [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64img}"}} for b64img in self.data_object.images[item.index]]
+                messages = [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": [*placeholders, {"type": "text", "text": question}]},
+                ]
+
+                prompt = self.processor.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
+                )
+                prompts.append({
+                    "prompt": prompt,
+                    "multi_modal_data": {"image": self.data_object.images[item.index]}
+                })
+
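+            # Each entry pairs the rendered chat-template prompt with the images under
+            # "multi_modal_data", the prompt-dict form vLLM's generate() accepts for
+            # multimodal inputs.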
             tik2 = time.time()
             outputs = self.model.generate(
-                prompt_token_ids=input_ids_tensor, sampling_params=self.sampling_params
+                prompts=prompts, sampling_params=self.sampling_params
             )
             pred_output_tokens = []
             for output in outputs:
@@ -151,14 +140,14 @@ def process_queries(self):
                 pred_output_tokens,
                 query_id_list=query_ids,
             )
-            for i in range(len(qitem)):
+            for i in range(len(qitems)):
                 n_tokens = processed_output[i].shape[0]
                 response_array = array.array(
                     "B", processed_output[i].tobytes())
                 bi = response_array.buffer_info()
                 response = [
                     lg.QuerySampleResponse(
-                        qitem[i].id,
+                        qitems[i].id,
                         bi[0],
                         bi[1],
                         n_tokens)]
@@ -167,7 +156,7 @@ def process_queries(self):
             tok = time.time()

             with self.sample_counter_lock:
-                self.sample_counter += len(qitem)
+                self.sample_counter += len(qitems)
                 log.info(f"Samples run: {self.sample_counter}")
                 if tik1:
                     log.info(f"\tBatchMaker time: {tik2 - tik1}")
@@ -176,12 +165,13 @@ def process_queries(self):
                     log.info(f"\t==== Total time: {tok - tik1}")

     def load_model(self):
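+        # Deferred import: vllm is only needed when the offline scenario runs the
+        # engine in-process.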
+        from vllm import LLM
         log.info("Loading model...")
-        # self.model = LLM(
-        #     self.model_path,
-        #     dtype=self.dtype,
-        #     tensor_parallel_size=self.tensor_parallel_size,
-        # )
+        self.model = LLM(
+            self.model_path,
+            dtype=self.dtype,
+            tensor_parallel_size=self.tensor_parallel_size,
+        )
         log.info("Loaded model")

     def get_sut(self):
@@ -209,74 +199,6 @@ def issue_queries(self, query_samples):
     def flush_queries(self):
         pass

-    def start_vllm_server(self, model: str, host: str, port: int, extra_args: list[str]) -> subprocess.Popen:
-        """
-        Launch vLLM's OpenAI-compatible server as a subprocess.
-        Returns a Popen handle.
-        """
-        cmd = [
-            sys.executable, "-m", "vllm.entrypoints.openai.api_server",
-            "--model", model,
-            "--host", host,
-            "--port", str(port),
-        ] + extra_args
-
-        # Inherit stdout/stderr so you can see logs in your terminal.
-        # (If you prefer, redirect to a file or PIPE.)
-        print("Launching vLLM server:\n", " ".join(cmd))
-        proc = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr)
-        return proc
-
-    def wait_until_ready(self, timeout_s: int = 300, poll_interval_s: float = 0.5) -> None:
-        """
-        Poll one or more health endpoints until HTTP 200, or raise on timeout.
-        """
-        start = time.time()
-        last_err = None
-        while time.time() - start < timeout_s:
-            for url in HEALTH_URLS:
-                try:
-                    r = requests.get(url, timeout=2)
-                    if r.status_code == 200:
-                        # For /v1/models, ensure JSON is present (indicates API is fully up)
-                        if url.endswith("/v1/models"):
-                            with suppress(Exception):
-                                _ = r.json()
-                        print(f"Server ready at {url}")
-                        return
-                except Exception as e:
-                    last_err = e
-            time.sleep(poll_interval_s)
-
-        raise TimeoutError(f"vLLM server didn't become ready within {timeout_s}s. Last error: {last_err}")
-
-
-    def terminate_process(self, proc: subprocess.Popen, grace_s: int = 15) -> None:
-        """
-        Try to stop the server gracefully, then force-kill if needed.
-        Cross-platform friendly.
-        """
-        if proc.poll() is not None:
-            return  # already exited
-
-        try:
-            # POSIX: try SIGINT first (clean shutdown), then SIGTERM, then SIGKILL.
-            proc.send_signal(signal.SIGINT)
-            try:
-                proc.wait(timeout=grace_s)
-                return
-            except subprocess.TimeoutExpired:
-                pass
-
-            proc.terminate()
-            try:
-                proc.wait(timeout=5)
-            except subprocess.TimeoutExpired:
-                proc.kill()
-        finally:
-            with suppress(Exception):
-                proc.wait(timeout=2)
-

     def __del__(self):
         pass
@@ -285,51 +207,32 @@ def __del__(self):
 class SUTServer(SUT):
     def __init__(
         self,
-        # ... same arguments as before
         model_path=None,
         dtype="bfloat16",
         total_sample_count=13368,
         dataset_path=None,
         batch_size=None,
         workers=1,
-        tensor_parallel_size=8
+        tensor_parallel_size=8,
+        scenario="offline"
     ):
-        # We call a modified super().__init__ that doesn't load the model yet
-        # because model loading needs to be async.
-        # This is a bit of a simplification; you might need to adjust the base SUT init.
-        # For this example, let's assume the base init can be called without loading the model.
         super().__init__(
             model_path=model_path,
+            batch_size=batch_size,
             dtype=dtype,
             total_sample_count=total_sample_count,
             dataset_path=dataset_path,
             workers=workers,
             tensor_parallel_size=tensor_parallel_size,
-            # Add a flag to skip model loading in the base class constructor
-            _load_model=False
+            scenario=scenario
         )
-        self.request_id_counter = 0
-        client = AsyncOpenAI(
+        self._client = AsyncOpenAI(
             base_url=BASE_URL,
             api_key="EMPTY"
         )
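+        # api_key is a placeholder; a vLLM OpenAI-compatible endpoint typically
+        # accepts any key unless the server was started with --api-key.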
-        self._client = client
-        # This will be the single, long-running asyncio event loop
-        self.event_loop = None
-        self.event_loop_thread = None
-
-        # We'll use an asyncio.Queue to communicate between the issue_queries thread
-        # and our main async event loop.
-        self.async_query_queue = None


     def start(self):
-        # self.proc = self.start_vllm_server(MODEL, HOST, PORT, EXTRA_ARGS)
-        # self.wait_until_ready()
-
-        # # Optional: print the models list to confirm we're talking to the right thing
-        # r = requests.get(f"{BASE_URL}/models", timeout=3)
-        # print("\n[Models]", r.json())
         pass


@@ -350,20 +253,15 @@ async def _issue_one(

         messages = [{"role": "user", "content": contents}]

-        params = dict(
-            model=self.model_path,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature
-        )
-
         async with semaphore:
             ttft_set = False

             # await the async creation; ask for a streaming iterator
             stream = await self._client.chat.completions.create(
                 stream=True,
                 messages=messages,
-                **params
+                model=self.model_path,
+                **self.params
             )
             out = []
             # iterate asynchronously
@@ -416,8 +314,4 @@ def issue_queries(self, query_samples):
         asyncio.run(self._issue_queries_async(query_samples))

     def stop(self):
-        # if self.proc is not None:
-        #     print("\nShutting down vLLM server…")
-        #     self.terminate_process(self.proc)
-        #     print("Done.")
         pass