Skip to content

Commit f021b97

Browse files
authored
[V1] Support Mistral3 in V1 (#15950)
Signed-off-by: mgoin <[email protected]>
1 parent 1cab43c commit f021b97

File tree

2 files changed: 10 additions (+10) and 7 deletions (-7).

docs/source/models/supported_models.md

+1 -1 (1 addition, 1 deletion)
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,7 @@ See [this page](#generative-models) for more information on how to use generativ
888888
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
889889
*
890890
* ✅︎
891-
*
891+
* ✅︎
892892
- * `MllamaForConditionalGeneration`
893893
* Llama 3.2
894894
* T + I<sup>+</sup>

vllm/model_executor/models/mistral3.py

+9 -6 (9 additions, 6 deletions)
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
3232
from vllm.sequence import IntermediateTensors
3333

34-
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
35-
SupportsV0Only)
34+
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
3635
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
3736
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
3837
maybe_prefix, merge_multimodal_embeddings)
39-
from .vision import get_vision_encoder_info, select_patch_features
38+
from .vision import (get_vision_encoder_info, scatter_patch_features,
39+
select_patch_features)
4040

4141

4242
class Mistral3ImagePixelInputs(TypedDict):
@@ -425,7 +425,7 @@ def init_vision_tower_for_llava(
425425
info=_build_mistral3_info,
426426
dummy_inputs=Mistral3DummyInputsBuilder)
427427
class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
428-
SupportsPP, SupportsV0Only):
428+
SupportsPP):
429429

430430
packed_modules_mapping = {
431431
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -518,7 +518,7 @@ def _parse_and_validate_image_input(
518518
return Mistral3ImagePixelInputs(
519519
type="pixel_values_pixtral",
520520
pixel_values=flatten_bn(pixel_values),
521-
embed_is_patch=embed_is_patch,
521+
embed_is_patch=flatten_bn(embed_is_patch),
522522
)
523523

524524
def _process_image_input(
@@ -557,7 +557,10 @@ def get_multimodal_embeddings(
557557

558558
vision_embeddings = self._process_image_input(image_input)
559559

560-
return vision_embeddings
560+
return scatter_patch_features(
561+
vision_embeddings,
562+
image_input["embed_is_patch"],
563+
)
561564

562565
def get_input_embeddings(
563566
self,

0 commit comments

Comments
 (0)