Skip to content

Commit b6087a6

Browse files
authored
[mypy] Pass type checking in vllm/inputs (#11680)
Signed-off-by: Tobias Pitters <[email protected]>
1 parent 23c1b10 commit b6087a6

File tree

4 files changed

+16
-14
lines changed

4 files changed

+16
-14
lines changed

tools/mypy.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ run_mypy vllm/compilation
2323
run_mypy vllm/distributed
2424
run_mypy vllm/engine
2525
run_mypy vllm/executor
26+
run_mypy vllm/inputs
2627
run_mypy vllm/lora
2728
run_mypy vllm/model_executor
2829
run_mypy vllm/plugins

vllm/inputs/data.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def prompt(self) -> Optional[str]:
250250
if inputs["type"] == "token" or inputs["type"] == "multimodal":
251251
return inputs.get("prompt")
252252

253-
assert_never(inputs)
253+
assert_never(inputs) # type: ignore[arg-type]
254254

255255
@cached_property
256256
def prompt_token_ids(self) -> List[int]:
@@ -259,7 +259,7 @@ def prompt_token_ids(self) -> List[int]:
259259
if inputs["type"] == "token" or inputs["type"] == "multimodal":
260260
return inputs.get("prompt_token_ids", [])
261261

262-
assert_never(inputs)
262+
assert_never(inputs) # type: ignore[arg-type]
263263

264264
@cached_property
265265
def token_type_ids(self) -> List[int]:
@@ -268,7 +268,7 @@ def token_type_ids(self) -> List[int]:
268268
if inputs["type"] == "token" or inputs["type"] == "multimodal":
269269
return inputs.get("token_type_ids", [])
270270

271-
assert_never(inputs)
271+
assert_never(inputs) # type: ignore[arg-type]
272272

273273
@cached_property
274274
def prompt_embeds(self) -> Optional[torch.Tensor]:
@@ -277,7 +277,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
277277
if inputs["type"] == "token" or inputs["type"] == "multimodal":
278278
return None
279279

280-
assert_never(inputs)
280+
assert_never(inputs) # type: ignore[arg-type]
281281

282282
@cached_property
283283
def multi_modal_data(self) -> "MultiModalDataDict":
@@ -289,7 +289,7 @@ def multi_modal_data(self) -> "MultiModalDataDict":
289289
if inputs["type"] == "multimodal":
290290
return inputs.get("mm_kwargs", {})
291291

292-
assert_never(inputs)
292+
assert_never(inputs) # type: ignore[arg-type]
293293

294294
@cached_property
295295
def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
@@ -301,7 +301,7 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
301301
if inputs["type"] == "multimodal":
302302
return inputs.get("mm_kwargs", {})
303303

304-
assert_never(inputs)
304+
assert_never(inputs) # type: ignore[arg-type]
305305

306306
@cached_property
307307
def multi_modal_hashes(self) -> List[str]:
@@ -311,9 +311,10 @@ def multi_modal_hashes(self) -> List[str]:
311311
return inputs.get("multi_modal_hashes", [])
312312

313313
if inputs["type"] == "multimodal":
314-
return inputs.get("mm_hashes", [])
314+
# only the case when we use MultiModalInputsV2
315+
return inputs.get("mm_hashes", []) # type: ignore[return-value]
315316

316-
assert_never(inputs)
317+
assert_never(inputs) # type: ignore[arg-type]
317318

318319
@cached_property
319320
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
@@ -325,7 +326,7 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
325326
if inputs["type"] == "multimodal":
326327
return inputs.get("mm_placeholders", {})
327328

328-
assert_never(inputs)
329+
assert_never(inputs) # type: ignore[arg-type]
329330

330331
@cached_property
331332
def mm_processor_kwargs(self) -> Dict[str, Any]:
@@ -337,7 +338,7 @@ def mm_processor_kwargs(self) -> Dict[str, Any]:
337338
if inputs["type"] == "multimodal":
338339
return {}
339340

340-
assert_never(inputs)
341+
assert_never(inputs) # type: ignore[arg-type]
341342

342343

343344
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]

vllm/inputs/preprocess.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def _build_enc_dec_llm_inputs(
436436
or encoder_inputs["type"] == "multimodal"):
437437
pass
438438
else:
439-
assert_never(encoder_inputs)
439+
assert_never(encoder_inputs) # type: ignore[arg-type]
440440

441441
if decoder_inputs is None:
442442
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
@@ -452,7 +452,7 @@ def _build_enc_dec_llm_inputs(
452452
raise ValueError("Multi-modal decoder inputs of encoder-"
453453
"decoder models are not supported yet")
454454
else:
455-
assert_never(encoder_inputs)
455+
assert_never(encoder_inputs) # type: ignore[arg-type]
456456

457457
return EncoderDecoderInputs(
458458
encoder=encoder_inputs,
@@ -569,7 +569,7 @@ def _build_decoder_only_llm_inputs(
569569
prompt_adapter_request=prompt_adapter_request,
570570
)
571571
else:
572-
assert_never(prompt_inputs)
572+
assert_never(prompt_inputs) # type: ignore[arg-type]
573573

574574
return prompt_inputs
575575

vllm/inputs/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def _ensure_mm_kwargs(
419419
# Be more strict in V2
420420
assert "mm_kwargs" in inputs
421421
else:
422-
assert_never(inputs["type"])
422+
assert_never(inputs["type"]) # type: ignore[arg-type]
423423

424424
def process_input(self, model_config: "ModelConfig",
425425
inputs: ProcessorInputs) -> ProcessorInputs:

0 commit comments

Comments
 (0)