
Commit a4fe135

Merge pull request #239 from mistralai/add_patch_merger

Add support for Mistral Small 3.1

2 parents de6f646 + af0d803

File tree: 6 files changed, +304 -57 lines

README.md

Lines changed: 92 additions & 6 deletions
@@ -15,6 +15,7 @@ Blog Mathstral 7B: [https://mistral.ai/news/mathstral/](https://mistral.ai/news/mathstral/)
 Blog Nemo: [https://mistral.ai/news/mistral-nemo/](https://mistral.ai/news/mistral-nemo/) \
 Blog Mistral Large 2: [https://mistral.ai/news/mistral-large-2407/](https://mistral.ai/news/mistral-large-2407/) \
 Blog Pixtral 12B: [https://mistral.ai/news/pixtral-12b/](https://mistral.ai/news/pixtral-12b/)
+Blog Mistral Small 3.1: [https://mistral.ai/news/mistral-small-3-1/](https://mistral.ai/news/mistral-small-3-1/)
 
 Discord: [https://discord.com/invite/mistralai](https://discord.com/invite/mistralai)\
 Documentation: [https://docs.mistral.ai/](https://docs.mistral.ai/)\
@@ -39,6 +40,8 @@ cd $HOME/mistral-inference && poetry install .
 
 ## Model download
 
+### Direct links
+
 | Name | Download | md5sum |
 |-------------|-------|-------|
 | 7B Instruct | https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar | `80b71fcb6416085bcb4efad86dfb4d52` |
@@ -54,16 +57,27 @@ cd $HOME/mistral-inference && poetry install .
 | Nemo Instruct | https://models.mistralcdn.com/mistral-nemo-2407/mistral-nemo-instruct-2407.tar | `296fbdf911cb88e6f0be74cd04827fe7` |
 | Mistral Large 2 | https://models.mistralcdn.com/mistral-large-2407/mistral-large-instruct-2407.tar | `fc602155f9e39151fba81fcaab2fa7c4` |
 
-Note: 
+Note:
 - **Important**:
   - `mixtral-8x22B-Instruct-v0.3.tar` is exactly the same as [Mixtral-8x22B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), only stored in `.safetensors` format
   - `mixtral-8x22B-v0.3.tar` is the same as [Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1), but has an extended vocabulary of 32768 tokens.
   - `codestral-22B-v0.1.tar` has a custom non-commercial license, called [Mistral AI Non-Production (MNPL) License](https://mistral.ai/licenses/MNPL-0.1.md)
   - `mistral-large-instruct-2407.tar` has a custom non-commercial license, called [Mistral AI Research (MRL) License](https://mistral.ai/licenses/MRL-0.1.md)
-- All of the listed models above support function calling. For example, Mistral 7B Base/Instruct v3 is a minor update to Mistral 7B Base/Instruct v2, with the addition of function calling capabilities. 
-- The "coming soon" models will include function calling as well. 
+- All of the listed models above support function calling. For example, Mistral 7B Base/Instruct v3 is a minor update to Mistral 7B Base/Instruct v2, with the addition of function calling capabilities.
+- The "coming soon" models will include function calling as well.
 - You can download the previous versions of our models from our [docs](https://docs.mistral.ai/getting-started/open_weight_models/#downloading).
 
+### From Hugging Face Hub
+
+| Name | ID | URL |
+|-------------|-------|-------|
+| Pixtral Large Instruct | mistralai/Pixtral-Large-Instruct-2411 | https://huggingface.co/mistralai/Pixtral-Large-Instruct-2411 |
+| Pixtral 12B Base | mistralai/Pixtral-12B-Base-2409 | https://huggingface.co/mistralai/Pixtral-12B-Base-2409 |
+| Pixtral 12B | mistralai/Pixtral-12B-2409 | https://huggingface.co/mistralai/Pixtral-12B-2409 |
+| Mistral Small 3.1 24B Base | mistralai/Mistral-Small-3.1-24B-Base-2503 | https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503 |
+| Mistral Small 3.1 24B Instruct | mistralai/Mistral-Small-3.1-24B-Instruct-2503 | https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503 |
+
+
 ### Usage
 
 **News!!!**: Mistral Large 2 is out. Read more about its capabilities [here](https://mistral.ai/news/mistral-large-2407/).
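
For the direct links above, a downloaded archive can be checked against the md5sum column. A minimal sketch (not part of the diff), using the 7B Instruct row; any other row works the same way:

```python
# Sketch: download one of the direct-link archives and verify its checksum
# against the md5sum column of the table above.
import hashlib
import urllib.request

url = "https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar"
expected_md5 = "80b71fcb6416085bcb4efad86dfb4d52"

filename, _ = urllib.request.urlretrieve(url, "mistral-7B-Instruct-v0.3.tar")

md5 = hashlib.md5()
with open(filename, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        md5.update(chunk)
assert md5.hexdigest() == expected_md5, "checksum mismatch - re-download the archive"
```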
@@ -83,7 +97,7 @@ mkdir -p $12B_DIR
 tar -xf mistral-nemo-instruct-2407.tar -C $12B_DIR
 ```
 
-or 
+or
 
 ```sh
 export M8x7B_DIR=$MISTRAL_MODEL/8x7b_instruct
@@ -92,6 +106,27 @@ mkdir -p $M8x7B_DIR
 tar -xf Mixtral-8x7B-v0.1-Instruct.tar -C $M8x7B_DIR
 ```
 
+For weights hosted on the Hugging Face Hub, here is an example of how to download [Mistral Small 3.1 24B Instruct](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503):
+
+```python
+from pathlib import Path
+from huggingface_hub import snapshot_download
+
+mistral_models_path = Path.home().joinpath("mistral_models")
+
+model_path = mistral_models_path / "mistral-small-3.1-instruct"
+model_path.mkdir(parents=True, exist_ok=True)
+
+repo_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+
+snapshot_download(
+    repo_id=repo_id,
+    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
+    local_dir=model_path,
+)
+```
+
 ## Usage
 
 The following sections give an overview of how to run the model from the Command-line interface (CLI) or directly within Python.
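
The three `allow_patterns` entries above are exactly what the rest of this page relies on: `params.json` holds the model arguments, `consolidated.safetensors` the weights (loaded by `Transformer.from_folder`, see the transformer.py hunks below), and `tekken.json` the tokenizer consumed by `MistralTokenizer.from_file`.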
@@ -170,7 +205,7 @@ To use [Codestral-Mamba](https://mistral.ai/news/codestral-mamba/) as a coding a
 Make sure `$7B_CODESTRAL_MAMBA` is set to a valid path to the downloaded codestral-mamba folder, e.g. `$HOME/mistral_models/mamba-codestral-7B-v0.1`.
 
 You then need to additionally install the following packages:
-
+
 ```
 pip install packaging mamba-ssm causal-conv1d transformers
 ```
@@ -194,6 +229,19 @@ If you prompt it with *"Albert likes to surf every week. Each surfing session la
 
 You can then continue chatting afterwards, *e.g.* with *"How much would he spend in a year?"*.
 
+- **Chat with Mistral Small 3.1 24B Instruct**
+
+To use [Mistral Small 3.1 24B Instruct](https://mistral.ai/news/mistral-small-3-1/) as an assistant, you can run the following command using `mistral-chat`.
+Make sure `$MISTRAL_SMALL_3_1_INSTRUCT` is set to a valid path to the downloaded Mistral Small folder, e.g. `$HOME/mistral_models/mistral-small-3.1-instruct`.
+
+```sh
+mistral-chat $MISTRAL_SMALL_3_1_INSTRUCT --instruct --max_tokens 256
+```
+
+If you prompt it with *"The above image presents an image of which park ? Please give the hints to identify the park."* together with the image URL *https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png*, the model should answer that it is Yosemite National Park and give hints to identify it.
+
+You can then continue chatting afterwards, *e.g.* with *"What is the name of the lake in the image?"*. The model should respond that it is not a lake but a river.
+
 ### Python
 
 - *Instruction Following*:
@@ -222,6 +270,44 @@ result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
 print(result)
 ```
 
+- *Multimodal Instruction Following*:
+
+```python
+from pathlib import Path
+
+from mistral_common.protocol.instruct.messages import ImageURLChunk, TextChunk
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_inference.generate import generate
+from mistral_inference.transformer import Transformer
+
+model_path = Path.home().joinpath("mistral_models") / "mistral-small-3.1-instruct"  # change to your extracted model path
+
+tokenizer = MistralTokenizer.from_file(model_path / "tekken.json")
+model = Transformer.from_folder(model_path)
+
+url = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
+prompt = "The above image presents an image of which park ? Please give the hints to identify the park."
+
+user_content = [ImageURLChunk(image_url=url), TextChunk(text=prompt)]
+
+tokens, images = tokenizer.instruct_tokenizer.encode_user_content(user_content, False)
+
+out_tokens, _ = generate(
+    [tokens],
+    model,
+    images=[images],
+    max_tokens=256,
+    temperature=0.15,
+    eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
+)
+result = tokenizer.decode(out_tokens[0])
+
+print("Prompt:", prompt)
+print("Completion:", result)
+```
+
 - *Function Calling*:
 
 ```py
@@ -298,7 +384,7 @@ print(middle)
 
 ### One-file-ref
 
-If you want a self-contained implementation, look at `one_file_ref.py`, or run it with 
+If you want a self-contained implementation, look at `one_file_ref.py`, or run it with
 
 ```
 python -m one_file_ref $M7B_DIR

src/mistral_inference/args.py

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,8 @@
 from mistral_inference.lora import LoraArgs
 from mistral_inference.moe import MoeArgs
 
+PATCH_MERGE = "patch_merge"
+
 
 @dataclass
 class VisionEncoderArgs:
@@ -18,6 +20,10 @@ class VisionEncoderArgs:
     num_attention_heads: int
     rope_theta: float = 1e4  # for rope-2D
     image_token_id: int = 10
+    adapter_bias: bool = True
+    spatial_merge_size: int = 1
+    add_pre_mm_projector_layer_norm: bool = False
+    mm_projector_id: str = ""
 
 
 @dataclass
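
All four new fields default to the previous behavior (bias on, merge size 1, no pre-projector norm, empty projector id), so configs predating this commit load unchanged. A self-contained sketch, not the repo's code, of how the flags gate the optional modules wired up in `Transformer.__init__` below:

```python
# Hypothetical illustration: only the four fields added by this commit are
# modeled here; the real VisionEncoderArgs has more (required) fields.
from dataclasses import dataclass

PATCH_MERGE = "patch_merge"  # same constant as in args.py


@dataclass
class VisionArgsSketch:
    adapter_bias: bool = True                      # bias of the VisionLanguageAdapter
    spatial_merge_size: int = 1                    # side length of the merged patch block
    add_pre_mm_projector_layer_norm: bool = False  # RMSNorm before the projector
    mm_projector_id: str = ""                      # "patch_merge" enables PatchMerger


args = VisionArgsSketch(spatial_merge_size=2, mm_projector_id=PATCH_MERGE)
use_patch_merger = args.mm_projector_id == PATCH_MERGE
assert use_patch_merger and args.spatial_merge_size == 2
```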

src/mistral_inference/main.py

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@ def interactive(
             length_tensor = torch.tensor([len(tokens)], dtype=torch.int)
         else:
             length_tensor = torch.tensor([0], dtype=torch.int)
+            images = []
 
         if is_torchrun():
             dist.broadcast(length_tensor, src=0)
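
A one-line fix: the `else` branch previously left `images` undefined, presumably causing an `UnboundLocalError` on the ranks that read no prompt once the multimodal path consumes the variable later in `interactive()`.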

src/mistral_inference/transformer.py

Lines changed: 56 additions & 21 deletions
@@ -9,13 +9,13 @@
 import torch
 from torch import nn
 
-from mistral_inference.args import TransformerArgs
+from mistral_inference.args import PATCH_MERGE, TransformerArgs
 from mistral_inference.cache import BufferCache, CacheInputMetadata
 from mistral_inference.lora import LoRALoaderMixin
 from mistral_inference.model import ModelBase
 from mistral_inference.rope import precompute_freqs_cis
 from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
-from mistral_inference.vision_encoder import VisionLanguageAdapter, VisionTransformer
+from mistral_inference.vision_encoder import PatchMerger, VisionLanguageAdapter, VisionTransformer
 
 
 @dataclass
@@ -58,9 +58,22 @@ def __init__(
 
         self.vision_encoder: Optional[VisionTransformer] = None
         self.vision_language_adapter: Optional[VisionLanguageAdapter] = None
+
         if args.vision_encoder is not None:
             self.vision_encoder = VisionTransformer(args.vision_encoder)
-            self.vision_language_adapter = VisionLanguageAdapter(args.vision_encoder.hidden_size, args.dim)
+            self.vision_language_adapter = VisionLanguageAdapter(
+                args.vision_encoder.hidden_size, args.dim, args.vision_encoder.adapter_bias
+            )
+
+            if args.vision_encoder.add_pre_mm_projector_layer_norm:
+                self.pre_mm_projector_norm = RMSNorm(args.vision_encoder.hidden_size, eps=1e-5)
+
+            if args.vision_encoder.mm_projector_id == PATCH_MERGE:
+                self.patch_merger = PatchMerger(
+                    vision_encoder_dim=args.vision_encoder.hidden_size,
+                    spatial_merge_size=args.vision_encoder.spatial_merge_size,
+                )
+
         if pipeline_rank == num_pipeline_ranks - 1:
             self.norm = RMSNorm(args.dim, eps=args.norm_eps)
             self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
@@ -106,7 +119,7 @@ def freqs_cis(self) -> torch.Tensor:
             self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
         return self._precomputed_freqs_cis
 
-    def embed_vision_language_features(self, input_ids: torch.Tensor, images: List[torch.tensor]) -> torch.Tensor:  # type: ignore[valid-type]
+    def embed_vision_language_features(self, input_ids: torch.Tensor, images: List[torch.Tensor]) -> torch.Tensor:
         assert self.tok_embeddings is not None
         assert self.vision_encoder is not None
         assert self.vision_language_adapter is not None
@@ -115,16 +128,28 @@ def embed_vision_language_features(self, input_ids: torch.Tensor, images: List[t
         text_locations = input_ids != self.args.vision_encoder.image_token_id
         image_locations = input_ids == self.args.vision_encoder.image_token_id
         text_features = self.tok_embeddings(input_ids[text_locations])
-        image_features = self.vision_language_adapter(self.vision_encoder(images))
 
-        seq_len = input_ids.shape[0]
+        image_features = self.vision_encoder(images)
+
+        if self.args.vision_encoder.add_pre_mm_projector_layer_norm:
+            image_features = self.pre_mm_projector_norm(image_features)
+
+        if self.args.vision_encoder.mm_projector_id == PATCH_MERGE:
+            patch_size = self.args.vision_encoder.patch_size
+            img_patch_dims = [(img.shape[1] // patch_size, img.shape[2] // patch_size) for img in images]
+            image_features = self.patch_merger(image_features, image_sizes=img_patch_dims)
+
+        image_features = self.vision_language_adapter(image_features)
+
         N_txt, D_txt = text_features.shape
         N_img, D_img = image_features.shape
 
+        seq_len = input_ids.shape[0]
+
         assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
-        assert (
-            seq_len == N_txt + N_img
-        ), f"seq_len {seq_len} should be equal to N_txt + N_img {(N_txt, N_img, image_locations.sum().item())}"
+        assert seq_len == N_txt + N_img, (
+            f"seq_len {seq_len} should be equal to N_txt + N_img {(N_txt, N_img, image_locations.sum().item())}"
+        )
 
         combined_features = torch.empty(
             (seq_len, D_txt),
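
The `PatchMerger` itself lives in `vision_encoder.py` and is not part of this diff. As a rough, self-contained sketch of the technique (a common formulation, not necessarily the repo's exact implementation): each image's patch grid is split into `spatial_merge_size` x `spatial_merge_size` neighborhoods, the embeddings in a neighborhood are concatenated along the channel dimension, and a linear layer maps the result back to the vision width, cutting the image token count by a factor of `spatial_merge_size` squared. This matches the call signature used above, `self.patch_merger(image_features, image_sizes=img_patch_dims)`:

```python
# A standalone sketch of spatial patch merging; assumes each grid dim is
# divisible by the merge size.
import torch
from torch import nn


class PatchMergerSketch(nn.Module):
    """Merge s x s neighborhoods of patch embeddings into single tokens."""

    def __init__(self, vision_encoder_dim: int, spatial_merge_size: int) -> None:
        super().__init__()
        self.s = spatial_merge_size
        # the s*s concatenated patch vectors are projected back to one vector
        self.merging_layer = nn.Linear(
            vision_encoder_dim * self.s**2, vision_encoder_dim, bias=False
        )

    def forward(self, x: torch.Tensor, image_sizes: list[tuple[int, int]]) -> torch.Tensor:
        # x: patch embeddings of all images flattened together, shape (sum(h*w), D);
        # image_sizes: per-image patch-grid dims (h, w)
        out, offset = [], 0
        for h, w in image_sizes:
            grid = x[offset : offset + h * w].view(h, w, -1)
            offset += h * w
            # regroup into (h/s) * (w/s) neighborhoods of s*s patches each
            grid = grid.view(h // self.s, self.s, w // self.s, self.s, -1)
            grid = grid.permute(0, 2, 1, 3, 4).reshape((h // self.s) * (w // self.s), -1)
            out.append(grid)
        return self.merging_layer(torch.cat(out, dim=0))


# e.g. a 4x6 patch grid with D=8 collapses to 6 merged tokens when s=2
merger = PatchMergerSketch(vision_encoder_dim=8, spatial_merge_size=2)
tokens = merger(torch.randn(24, 8), image_sizes=[(4, 6)])
assert tokens.shape == (6, 8)
```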
@@ -147,9 +172,9 @@ def forward_partial(
         If doing pipeline parallelism, this will return the activations of the last layer of this stage.
         For the last stage, this will return the normalized final embeddings.
         """
-        assert (
-            len(seqlens) <= self.args.max_batch_size
-        ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
+        assert len(seqlens) <= self.args.max_batch_size, (
+            f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
+        )
         (num_toks,) = input_ids.shape
         assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
 
@@ -251,9 +276,19 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, as
                     self.pipeline_rank,
                 )
                 skipped.add(k)
-            elif k.startswith("vision_encoder") or k.startswith("vision_language_adapter"):
-                assert not self.pipeline_rank
-                state_to_load[k] = v
+            elif any(
+                k.startswith(key)
+                for key in ["vision_encoder", "vision_language_adapter", "patch_merger", "pre_mm_projector_norm"]
+            ):
+                if self.pipeline_rank == 0:
+                    state_to_load[k] = v
+                else:
+                    logging.debug(
+                        "Skipping parameter %s at pipeline rank %d",
+                        k,
+                        self.pipeline_rank,
+                    )
+                    skipped.add(k)
             else:
                 raise ValueError(f"Unexpected key {k}")
         assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys()))
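
The practical effect: under pipeline parallelism the whole vision stack, now including the optional `patch_merger` and `pre_mm_projector_norm` weights, is loaded on pipeline rank 0 only. Other ranks record those keys as skipped instead of failing the old `assert not self.pipeline_rank`, which also keeps the bookkeeping assert on the last line satisfied.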
@@ -286,12 +321,12 @@ def from_folder(
         pt_model_file = Path(folder) / "consolidated.00.pth"
         safetensors_model_file = Path(folder) / "consolidated.safetensors"
 
-        assert (
-            pt_model_file.exists() or safetensors_model_file.exists()
-        ), f"Make sure either {pt_model_file} or {safetensors_model_file} exists"
-        assert not (
-            pt_model_file.exists() and safetensors_model_file.exists()
-        ), f"Both {pt_model_file} and {safetensors_model_file} cannot exist"
+        assert pt_model_file.exists() or safetensors_model_file.exists(), (
+            f"Make sure either {pt_model_file} or {safetensors_model_file} exists"
+        )
+        assert not (pt_model_file.exists() and safetensors_model_file.exists()), (
+            f"Both {pt_model_file} and {safetensors_model_file} cannot exist"
+        )
 
         if pt_model_file.exists():
             loaded = torch.load(str(pt_model_file), mmap=True)
