Skip to content
Merged

Voxtral #20970

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
e886051
WIP
patrickvonplaten Jul 4, 2025
0ed1c04
WIP
patrickvonplaten Jul 4, 2025
868aa9d
WIP
patrickvonplaten Jul 6, 2025
2f072ff
WIP
patrickvonplaten Jul 6, 2025
966f21b
WIP
patrickvonplaten Jul 6, 2025
955ea03
WIP
patrickvonplaten Jul 7, 2025
a0e6ccd
WIP
patrickvonplaten Jul 7, 2025
175ca9b
WIP
patrickvonplaten Jul 7, 2025
9b63b9e
Merge branch 'vllm-project:main' into add_voxtral
patrickvonplaten Jul 7, 2025
ac7317f
WIP
patrickvonplaten Jul 7, 2025
61cfc09
WIP
patrickvonplaten Jul 8, 2025
af188d8
WIP
patrickvonplaten Jul 8, 2025
9cd95c2
WIP
patrickvonplaten Jul 8, 2025
5727fbc
WIP
patrickvonplaten Jul 8, 2025
218424d
WIP
patrickvonplaten Jul 14, 2025
2942922
WIP
patrickvonplaten Jul 14, 2025
d48b2e0
WIP
patrickvonplaten Jul 14, 2025
c0dc455
WIP
patrickvonplaten Jul 14, 2025
9d60d04
WIP
patrickvonplaten Jul 14, 2025
c2dae7a
WIP
patrickvonplaten Jul 14, 2025
4c95480
WIP
patrickvonplaten Jul 14, 2025
8498ad4
WIP
patrickvonplaten Jul 14, 2025
0d4c6d9
WIP
patrickvonplaten Jul 14, 2025
ad23449
Your commit message
patrickvonplaten Jul 15, 2025
41b4a72
clean
patrickvonplaten Jul 15, 2025
db19ce6
clean
patrickvonplaten Jul 15, 2025
07bf91b
WIP
patrickvonplaten Jul 15, 2025
eee65ad
WIP
patrickvonplaten Jul 15, 2025
7ca0086
Your commit message
patrickvonplaten Jul 15, 2025
1058be3
clean
patrickvonplaten Jul 15, 2025
4c4b4f4
clean
patrickvonplaten Jul 15, 2025
6a9312e
Apply suggestions from code review
patrickvonplaten Jul 15, 2025
0659d13
up
patrickvonplaten Jul 15, 2025
5eebbd5
:wqallMerge branch 'add_voxtral' of https://github.com/patrickvonplat…
patrickvonplaten Jul 15, 2025
b77d3ad
clean
patrickvonplaten Jul 15, 2025
baa6129
clean
patrickvonplaten Jul 15, 2025
11574fa
clean
patrickvonplaten Jul 15, 2025
9b43fc5
WIP
patrickvonplaten Jul 15, 2025
4d03361
up
patrickvonplaten Jul 15, 2025
8dc5387
WIP
patrickvonplaten Jul 15, 2025
9ce248c
WIP
patrickvonplaten Jul 15, 2025
30b30a7
clean
patrickvonplaten Jul 15, 2025
d2cfe28
Merge branch 'main' of https://github.com/patrickvonplaten/vllm into …
patrickvonplaten Jul 15, 2025
311316d
Update tests/entrypoints/openai/test_transcription_validation.py
patrickvonplaten Jul 15, 2025
6a699ec
Update tests/models/registry.py
patrickvonplaten Jul 15, 2025
a446aee
Revert "Update tests/models/registry.py"
patrickvonplaten Jul 15, 2025
995a9ba
WIP
patrickvonplaten Jul 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 75 additions & 10 deletions examples/offline_inference/audio_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import os
from dataclasses import asdict
from typing import NamedTuple, Optional
from typing import Any, NamedTuple, Optional

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
Expand All @@ -30,7 +30,9 @@

class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompt: str
prompt: Optional[str] = None
prompt_token_ids: Optional[dict[str, list[int]]] = None
multi_modal_data: Optional[dict[str, Any]] = None
stop_token_ids: Optional[list[int]] = None
lora_requests: Optional[list[LoRARequest]] = None

Expand All @@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4.


# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine arguments and a pre-tokenized request for Voxtral.

    The chat request is assembled and tokenized with ``mistral_common``
    instead of being passed as a text prompt, so the returned request data
    carries ``prompt_token_ids`` and its own ``multi_modal_data`` rather
    than a ``prompt`` string.

    Args:
        question: Text placed after the audio chunks in the user message.
        audio_count: Number of entries read from ``audio_assets`` (indices
            ``0 .. audio_count - 1``); also used as the per-prompt audio
            limit passed to the engine.

    Returns:
        A ``ModelRequestData`` with ``engine_args``, ``prompt_token_ids``
        and ``multi_modal_data`` set.
    """
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.messages import (
        AudioChunk,
        RawAudio,
        TextChunk,
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_id = "mistralai/Voxtral-Mini-3B-2507"
    mistral_tokenizer = MistralTokenizer.from_hf_hub(model_id)

    # Tokenizer, config and weights are all requested in the "mistral" format.
    engine_args = EngineArgs(
        model=model_id,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    question_chunk = TextChunk(text=question)

    # Load every requested asset first, then wrap each one as an audio chunk.
    loaded_audios = []
    for idx in range(audio_count):
        asset_path = str(audio_assets[idx].get_local_path())
        loaded_audios.append(Audio.from_file(asset_path, strict=False))

    message_content = []
    for loaded in loaded_audios:
        message_content.append(
            AudioChunk(input_audio=RawAudio.from_audio(loaded)))
    # The question goes last, after all audio chunks.
    message_content.append(question_chunk)

    request = ChatCompletionRequest(
        messages=[UserMessage(content=message_content)],
        model=model_id,
    )
    encoded = mistral_tokenizer.encode_chat_completion(request)

    # The engine receives the audios as returned by the tokenizer, each as an
    # (array, sampling rate) pair.
    audio_inputs = []
    for clip in encoded.audios:
        audio_inputs.append((clip.audio_array, clip.sampling_rate))

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=encoded.tokens,
        multi_modal_data={"audio": audio_inputs},
    )


# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the settings in this example are somewhat different than what is
Expand Down Expand Up @@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:


model_example_map = {
"voxtral": run_voxtral,
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
Expand Down Expand Up @@ -311,16 +368,24 @@ def main(args):
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)

mm_data = {}
if audio_count > 0:
mm_data = {
"audio": [
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
]
}
mm_data = req_data.multi_modal_data
if not mm_data:
mm_data = {}
if audio_count > 0:
mm_data = {
"audio": [
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
]
}

assert args.num_prompts > 0
inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
inputs = {"multi_modal_data": mm_data}

if req_data.prompt:
inputs["prompt"] = req_data.prompt
else:
inputs["prompt_token_ids"] = req_data.prompt_token_ids

if args.num_prompts > 1:
# Batch inference
inputs = [inputs] * args.num_prompts
Expand Down
2 changes: 1 addition & 1 deletion requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata; python_version < '3.10'
mistral_common[opencv] >= 1.6.2
mistral_common[opencv] >= 1.8.0
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
Expand Down
2 changes: 1 addition & 1 deletion requirements/nightly_torch_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.6.2 # required for pixtral test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.in
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ torchvision==0.22.0
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.7.0 # required for pixtral test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
Expand Down
8 changes: 7 additions & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.7.0
mistral-common==1.8.0
# via -r requirements/test.in
more-itertools==10.5.0
# via lm-eval
Expand Down Expand Up @@ -518,6 +518,8 @@ pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
# via lm-eval
pycountry==24.6.1
# via pydantic-extra-types
pycparser==2.22
# via cffi
pycryptodomex==3.22.0
Expand All @@ -528,9 +530,12 @@ pydantic==2.11.5
# datamodel-code-generator
# mistral-common
# mteb
# pydantic-extra-types
# ray
pydantic-core==2.33.2
# via pydantic
pydantic-extra-types==2.10.5
# via mistral-common
pygments==2.18.0
# via rich
pyparsing==3.2.0
Expand Down Expand Up @@ -835,6 +840,7 @@ typing-extensions==4.12.2
# pqdm
# pydantic
# pydantic-core
# pydantic-extra-types
# torch
# typer
# typing-inspection
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,8 @@ def _read_requirements(filename: str) -> list[str]:
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing
"audio": ["librosa", "soundfile",
"mistral_common[audio]"], # Required for audio processing
"video": [] # Kept for backwards compatibility
},
cmdclass=cmdclass,
Expand Down
28 changes: 23 additions & 5 deletions tests/entrypoints/openai/test_transcription_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

from ...utils import RemoteOpenAIServer

# Extra server CLI flags selecting the "mistral" tokenizer mode, config format
# and load format; appended to the server args for models whose name starts
# with "mistralai".
MISTRAL_FORMAT_ARGS = [
"--tokenizer_mode", "mistral", "--config_format", "mistral",
"--load_format", "mistral"
]


@pytest.fixture
def mary_had_lamb():
Expand All @@ -33,9 +38,18 @@ def winning_call():


@pytest.mark.asyncio
async def test_basic_audio(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
@pytest.mark.parametrize(
"model_name",
["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name):
server_args = ["--enforce-eager"]

if model_name.startswith("mistralai"):
server_args += MISTRAL_FORMAT_ARGS

# TODO(PATRICK) - REMOVE AFTER RELEASE
return # skip for now

# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
Expand Down Expand Up @@ -65,10 +79,13 @@ async def test_bad_requests(mary_had_lamb):


@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
async def test_long_audio_request(mary_had_lamb, model_name):
server_args = ["--enforce-eager"]

if model_name.startswith("openai"):
return

mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
# Add small silence after each audio for repeatability in the split process
Expand All @@ -87,7 +104,8 @@ async def test_long_audio_request(mary_had_lamb):
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert out.count("Mary had a little lamb") == 10
counts = out.count("Mary had a little lamb")
assert counts == 10, counts


@pytest.mark.asyncio
Expand Down
3 changes: 2 additions & 1 deletion tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def check_available_online(
tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", is_available_online=False, tokenizer_mode="mistral"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501

# [Cross-encoder]
Expand Down Expand Up @@ -513,4 +514,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
raise ValueError(f"No example model defined for {model_id}")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ async def _preprocess_speech_to_text(
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=lang,
task_type=self.task_type,
request_prompt=request.prompt)
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,8 @@ class SupportsTranscription(Protocol):

@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, language: str,
stt_config: SpeechToTextConfig,
model_config: ModelConfig, language: str,
task_type: str,
request_prompt: str) -> PromptType:
"""Get the prompt for the ASR model.
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
Expand Down
Loading