Skip to content

Commit 25b1329

Browse files
committed
Revert "refactor vllm/inputs/data.py to use newly defined functions"
This reverts commit 5986992. Signed-off-by: Tobias Pitters <[email protected]>
1 parent 5986992 commit 25b1329

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

vllm/inputs/data.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ class SingletonInputsAdapter:
268268
def prompt(self) -> Optional[str]:
269269
inputs = self.inputs
270270

271-
if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
271+
if inputs["type"] == "token" or inputs["type"] == "multimodal":
272272
return inputs.get("prompt")
273273

274274
assert_never(inputs) # type: ignore[arg-type]
@@ -277,7 +277,7 @@ def prompt(self) -> Optional[str]:
277277
def prompt_token_ids(self) -> List[int]:
278278
inputs = self.inputs
279279

280-
if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
280+
if inputs["type"] == "token" or inputs["type"] == "multimodal":
281281
return inputs.get("prompt_token_ids", [])
282282

283283
assert_never(inputs) # type: ignore[arg-type]
@@ -286,7 +286,7 @@ def prompt_token_ids(self) -> List[int]:
286286
def token_type_ids(self) -> List[int]:
287287
inputs = self.inputs
288288

289-
if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
289+
if inputs["type"] == "token" or inputs["type"] == "multimodal":
290290
return inputs.get("token_type_ids", [])
291291

292292
assert_never(inputs) # type: ignore[arg-type]
@@ -295,7 +295,7 @@ def token_type_ids(self) -> List[int]:
295295
def prompt_embeds(self) -> Optional[torch.Tensor]:
296296
inputs = self.inputs
297297

298-
if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
298+
if inputs["type"] == "token" or inputs["type"] == "multimodal":
299299
return None
300300

301301
assert_never(inputs) # type: ignore[arg-type]
@@ -304,9 +304,10 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
304304
def multi_modal_data(self) -> "MultiModalDataDict":
305305
inputs = self.inputs
306306

307-
if is_token_inputs(inputs):
307+
if inputs["type"] == "token":
308308
return inputs.get("multi_modal_data", {})
309-
elif is_multimodal_inputs(inputs):
309+
310+
if inputs["type"] == "multimodal":
310311
return inputs.get("mm_kwargs", {})
311312

312313
assert_never(inputs) # type: ignore[arg-type]
@@ -315,9 +316,10 @@ def multi_modal_data(self) -> "MultiModalDataDict":
315316
def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
316317
inputs = self.inputs
317318

318-
if is_token_inputs(inputs):
319+
if inputs["type"] == "token":
319320
return inputs.get("multi_modal_inputs", {})
320-
elif is_multimodal_inputs(inputs):
321+
322+
if inputs["type"] == "multimodal":
321323
return inputs.get("mm_kwargs", {})
322324

323325
assert_never(inputs) # type: ignore[arg-type]
@@ -329,6 +331,7 @@ def multi_modal_hashes(self) -> List[str]:
329331
if is_token_inputs(inputs):
330332
return inputs.get("multi_modal_hashes", [])
331333
elif is_multimodal_inputs(inputs):
334+
# only the case when we use MultiModalInputsV2
332335
return inputs.get("mm_hashes", [])
333336

334337
assert_never(inputs) # type: ignore[arg-type]
@@ -337,9 +340,10 @@ def multi_modal_hashes(self) -> List[str]:
337340
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
338341
inputs = self.inputs
339342

340-
if is_token_inputs(inputs):
343+
if inputs["type"] == "token":
341344
return inputs.get("multi_modal_placeholders", {})
342-
elif is_multimodal_inputs(inputs):
345+
346+
if inputs["type"] == "multimodal":
343347
return inputs.get("mm_placeholders", {})
344348

345349
assert_never(inputs) # type: ignore[arg-type]
@@ -348,9 +352,10 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
348352
def mm_processor_kwargs(self) -> Dict[str, Any]:
349353
inputs = self.inputs
350354

351-
if is_token_inputs(inputs):
355+
if inputs["type"] == "token":
352356
return inputs.get("mm_processor_kwargs", {})
353-
elif is_multimodal_inputs(inputs):
357+
358+
if inputs["type"] == "multimodal":
354359
return {}
355360

356361
assert_never(inputs) # type: ignore[arg-type]

0 commit comments

Comments
 (0)