15 changes: 9 additions & 6 deletions tests/conftest.py
@@ -804,7 +804,7 @@

def get_inputs(
self,
prompts: Union[list[str], list[torch.Tensor]],
prompts: Union[list[str], list[torch.Tensor], list[int]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
@@ -826,11 +826,14 @@
if audios is not None and (audio := audios[i]) is not None:
multi_modal_data["audio"] = audio

text_prompt_kwargs = {
("prompt" if isinstance(prompt, str) else "prompt_embeds"):
prompt,
"multi_modal_data": multi_modal_data or None
}
text_prompt_kwargs = {"multi_modal_data": multi_modal_data or None}
if isinstance(prompt, str):
text_prompt_kwargs["prompt"] = prompt

Check failure on line 831 in tests/conftest.py (GitHub Actions / pre-commit, reported by several mypy runs):
Incompatible types in assignment (expression has type "str", target has type "Optional[dict[str, Union[Any, list[Any]]]]") [assignment]
elif isinstance(prompt, list) and isinstance(prompt[0], int):
text_prompt_kwargs["prompt_token_ids"] = prompt

Check failure on line 833 in tests/conftest.py (GitHub Actions / pre-commit, reported by several mypy runs):
Incompatible types in assignment (expression has type "list[Any]", target has type "Optional[dict[str, Union[Any, list[Any]]]]") [assignment]
else:
text_prompt_kwargs["prompt_embeds"] = prompt

inputs.append(TextPrompt(**text_prompt_kwargs))

return inputs
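
A note on the pre-commit failures flagged above: mypy infers the value type of text_prompt_kwargs from its first entry ("multi_modal_data"), so the later string, token-id, and tensor assignments are rejected. A minimal, hypothetical sketch of one way to satisfy the checker (not the code in this PR, and build_text_prompt_kwargs is an invented name) is to annotate the dict explicitly:

from typing import Any, Optional, Union

import torch


def build_text_prompt_kwargs(
    prompt: Union[str, list[int], torch.Tensor],
    multi_modal_data: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    # Annotating the dict as dict[str, Any] stops mypy from narrowing the value
    # type to that of the first entry, which is what triggers the errors above.
    kwargs: dict[str, Any] = {"multi_modal_data": multi_modal_data or None}
    if isinstance(prompt, str):
        kwargs["prompt"] = prompt
    elif isinstance(prompt, list) and isinstance(prompt[0], int):
        kwargs["prompt_token_ids"] = prompt
    else:
        kwargs["prompt_embeds"] = prompt
    return kwargs

In the diff itself the dict is built inline inside get_inputs; the helper above only illustrates the annotation.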
3 changes: 0 additions & 3 deletions tests/entrypoints/openai/test_transcription_validation.py
@@ -47,9 +47,6 @@ async def test_basic_audio(mary_had_lamb, model_name):
if model_name.startswith("mistralai"):
server_args += MISTRAL_FORMAT_ARGS

# TODO(PATRICK) - REMOVE AFTER RELEASE
return # skip for now

# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
115 changes: 115 additions & 0 deletions tests/models/multimodal/generation/test_voxtral.py
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import pytest
import pytest_asyncio
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
TextChunk, UserMessage)

from vllm.transformers_utils.tokenizer import MistralTokenizer

from ....conftest import AudioTestAssets
from ....utils import RemoteOpenAIServer
from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test

MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507"
MISTRAL_FORMAT_ARGS = [
"--tokenizer_mode", "mistral", "--config_format", "mistral",
"--load_format", "mistral"
]


@pytest.fixture()
def server(request, audio_assets: AudioTestAssets):
args = [
"--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"audio": len(audio_assets)}),
] + MISTRAL_FORMAT_ARGS

with RemoteOpenAIServer(MODEL_NAME,
args,
env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
"30"}) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


def _get_prompt(audio_assets, question):
tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME)

audios = [
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
for i in range(len(audio_assets))
]
audio_chunks = [
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
]

text_chunk = TextChunk(text=question)
messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()]

return tokenizer.apply_chat_template(messages=messages)


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_with_multiple_audios(vllm_runner,
audio_assets: AudioTestAssets, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)
run_multi_audio_test(
vllm_runner,
[(vllm_prompt, [audio.audio_and_sample_rate
for audio in audio_assets])],
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tokenizer_mode="mistral",
)


@pytest.mark.asyncio
async def test_online_serving(client, audio_assets: AudioTestAssets):
"""Exercises online serving with/without chunked prefill enabled."""

def asset_to_chunk(asset):
audio = Audio.from_file(str(asset.get_local_path()), strict=False)
audio.format = "wav"
audio_dict = AudioChunk.from_audio(audio).to_openai()
return audio_dict

audio_chunks = [asset_to_chunk(asset) for asset in audio_assets]
messages = [{
"role":
"user",
"content": [
*audio_chunks,
{
"type":
"text",
"text":
f"What's happening in these {len(audio_assets)} audio clips?"
},
],
}]

chat_completion = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=10)

assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
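
For context on why get_inputs in tests/conftest.py now accepts list[int]: MistralTokenizer.apply_chat_template returns token ids rather than a string, so the Voxtral prompt built by _get_prompt above reaches vLLM as prompt_token_ids. A rough offline sketch (an assumption-laden illustration, not part of the PR) that reuses MODEL_NAME, _get_prompt, MULTI_AUDIO_PROMPT and the audio_assets fixture from the test above:

from vllm import LLM, SamplingParams

# Rough sketch: feed the tokenized Voxtral prompt and raw audio to the offline API.
llm = LLM(model=MODEL_NAME, tokenizer_mode="mistral",
          config_format="mistral", load_format="mistral", enforce_eager=True)

prompt_token_ids = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)  # list[int]
audio_data = [asset.audio_and_sample_rate for asset in audio_assets]

outputs = llm.generate(
    {"prompt_token_ids": prompt_token_ids,
     "multi_modal_data": {"audio": audio_data}},
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)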
2 changes: 1 addition & 1 deletion tests/models/registry.py
@@ -440,7 +440,7 @@ def check_available_online(
tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", is_available_online=False, tokenizer_mode="mistral"), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", tokenizer_mode="mistral"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501

# [Cross-encoder]