Commit 26a52a1

add mamba
Parent: e3a64e4

14 files changed, 720 additions and 614 deletions

README.md (+2 −2)

@@ -155,7 +155,7 @@ You can continue chatting afterwards, *e.g.* with *"Translate it to Python"*.
 - *Instruction Following*:
 
 ```py
-from mistral_inference.model import Transformer
+from mistral_inference.transformer import Transformer
 from mistral_inference.generate import generate
 
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
@@ -228,7 +228,7 @@ pip install --upgrade mistral-common
 You can simulate a code completion in-filling as follows.
 
 ```py
-from mistral_inference.model import Transformer
+from mistral_inference.transformer import Transformer
 from mistral_inference.generate import generate
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from mistral_common.tokens.instruct.request import FIMRequest

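Editor's note: the only functional change in these README snippets is the import path — `Transformer` now lives in `mistral_inference.transformer` rather than `mistral_inference.model`. A minimal sketch of the updated import in use; the checkpoint directory is a placeholder, and the keyword arguments mirror the `from_folder` signature shown in the reference files below, so treat them as assumptions rather than the packaged API:

```py
from pathlib import Path

import torch

# New import path introduced by this commit
from mistral_inference.transformer import Transformer

# Placeholder path to a downloaded checkpoint folder containing params.json
# and the consolidated weights
model_path = Path("mistral_models/7B-Instruct-v0.3")

# from_folder reads params.json into the args dataclass and loads the weights
model = Transformer.from_folder(model_path, max_batch_size=1, dtype=torch.float16)
```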
moe_one_file_ref.py (+19 −45)

@@ -22,7 +22,7 @@ class MoeArgs(Serializable):
 
 
 @dataclass
-class ModelArgs(Serializable):
+class TransformerArgs(Serializable):
     dim: int
     n_layers: int
     head_dim: int
@@ -80,7 +80,7 @@ def apply_rotary_emb(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.args = args
 
@@ -144,9 +144,7 @@ def forward(
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
         # Update cache
-        scatter_pos = positions[None, :, None, None].repeat(
-            bsz, 1, self.n_kv_heads, self.args.head_dim
-        )
+        scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
         cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk)
         cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv)
 
@@ -179,7 +177,7 @@ def forward(
 
 
 class FeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
         self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
@@ -214,9 +212,7 @@ def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs)
     def forward(self, inputs: torch.Tensor):
         inputs_squashed = inputs.view(-1, inputs.shape[-1])
         gate_logits = self.gate(inputs_squashed)
-        weights, selected_experts = torch.topk(
-            gate_logits, self.args.num_experts_per_tok
-        )
+        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
         weights = nn.functional.softmax(
             weights,
             dim=1,
@@ -225,14 +221,12 @@ def forward(self, inputs: torch.Tensor):
         results = torch.zeros_like(inputs_squashed)
         for i, expert in enumerate(self.experts):
             batch_idx, nth_expert = torch.where(selected_experts == i)
-            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
-                inputs_squashed[batch_idx]
-            )
+            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs_squashed[batch_idx])
         return results.view_as(inputs)
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
@@ -270,7 +264,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
 class Transformer(nn.Module):
     def __init__(
         self,
-        args: ModelArgs,
+        args: TransformerArgs,
         pipeline_rank: int = 0,
         num_pipeline_ranks: int = 1,
     ):
@@ -316,13 +310,9 @@ def freqs_cis(self) -> torch.Tensor:
         # from the module's dtype means we cannot register it as a buffer
         if self._precomputed_freqs_cis is None:
             theta = self.args.rope_theta or 1000000.0
-            self._precomputed_freqs_cis = precompute_freqs_cis(
-                self.args.head_dim, 128_000, theta
-            )
+            self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta)
         if self._precomputed_freqs_cis.device != self.device:
-            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(
-                device=self.device
-            )
+            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
         return self._precomputed_freqs_cis
 
     def forward(
@@ -341,9 +331,7 @@ def forward(
             assert h.shape == (bsz, seqlen, self.args.dim)
             assert h.dtype == self.dtype
         else:
-            h = torch.empty(
-                bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype
-            )
+            h = torch.empty(bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype)
             torch.distributed.recv(h, src=self.pipeline_rank - 1)
 
         mask: Optional[torch.Tensor] = None
@@ -361,9 +349,7 @@ def forward(
 
         if self.pipeline_rank < self.num_pipeline_ranks - 1:
             torch.distributed.send(h, dst=self.pipeline_rank + 1)
-            outs = torch.empty(
-                *h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype
-            )
+            outs = torch.empty(*h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype)
         else:
             assert self.output is not None
             assert self.norm is not None
@@ -422,7 +408,7 @@ def from_folder(
         dtype=torch.float16,
     ) -> "Transformer":
         with open(folder / "params.json", "r") as f:
-            model_args = ModelArgs.from_dict(json.load(f))
+            model_args = TransformerArgs.from_dict(json.load(f))
         model_args.max_batch_size = max_batch_size
         model_args.max_seq_len = max_seq_len
         if num_pipeline_ranks > 1:
@@ -457,9 +443,7 @@ def from_folder(
 
 
 def load_tokenizer(model_path: Path) -> MistralTokenizer:
-    tokenizer = [
-        f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")
-    ]
+    tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")]
     assert (
         len(tokenizer) == 1
     ), f"Multiple tokenizers {', '.join(tokenizer)} found in `model_path`, make sure to only have one tokenizer"
@@ -470,12 +454,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:
 
 
 @torch.no_grad()
-def generate(
-    prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int
-):
-    encoded_prompts = [
-        tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts
-    ]
+def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
+    encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
     prompt_lens = [len(x) for x in encoded_prompts]
     min_prompt_len = min(prompt_lens)
     max_prompt_len = max(prompt_lens)
@@ -498,23 +478,17 @@ def generate(
     # decode
     generated = []
     all_logprobs = [
-        logprobs[:, :-1, :]
-        .gather(2, input_tokens[:, 1:min_prompt_len, None])
-        .squeeze(-1),
+        logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
     ]
     for cur_pos in range(min_prompt_len, max_tokens):
         next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
         if cur_pos < input_mask.shape[1]:
-            next_token = torch.where(
-                input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token
-            )
+            next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token)
         all_logprobs.append(
             logprobs[:, -1, :].gather(1, next_token[:, None]),
         )
         generated.append(next_token[:, None])
-        logits = model.forward(
-            next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)
-        )
+        logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
         logprobs = nn.functional.log_softmax(logits, dim=-1)
 
     all_logprobs_merged = torch.cat(all_logprobs, 1)

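Editor's note: the diff above only reflows code and renames `ModelArgs` to `TransformerArgs`; the MoE routing in the reference file is unchanged: gate logits are computed per token, the top-k experts are selected, their gate weights are softmax-normalised, and each expert processes only the tokens routed to it. A self-contained toy sketch of that routing, with made-up shapes and expert count:

```py
import torch
import torch.nn as nn

# Toy sparse-MoE routing in the style of the forward pass shown above
dim, num_experts, num_experts_per_tok = 8, 4, 2
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
gate = nn.Linear(dim, num_experts, bias=False)

inputs = torch.randn(3, 5, dim)                 # (batch, seq, dim)
inputs_squashed = inputs.view(-1, dim)          # flatten tokens

gate_logits = gate(inputs_squashed)
weights, selected_experts = torch.topk(gate_logits, num_experts_per_tok)
weights = nn.functional.softmax(weights, dim=1)

results = torch.zeros_like(inputs_squashed)
for i, expert in enumerate(experts):
    # tokens (and which of their k routing slots) assigned to expert i
    batch_idx, nth_expert = torch.where(selected_experts == i)
    results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs_squashed[batch_idx])

print(results.view_as(inputs).shape)            # torch.Size([3, 5, 8])
```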
one_file_ref.py (+17 −39)

@@ -14,7 +14,7 @@
 
 
 @dataclass
-class ModelArgs(Serializable):
+class TransformerArgs(Serializable):
     dim: int
     n_layers: int
     head_dim: int
@@ -31,9 +31,7 @@ class ModelArgs(Serializable):
     max_batch_size: int = 0
 
 
-def repeat_kv(
-    keys: torch.Tensor, values: torch.Tensor, repeats: int
-) -> Tuple[torch.Tensor]:
+def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int) -> Tuple[torch.Tensor]:
     keys = torch.repeat_interleave(keys, repeats=repeats, dim=2)
     values = torch.repeat_interleave(values, repeats=repeats, dim=2)
     return keys, values
@@ -68,7 +66,7 @@ def apply_rotary_emb(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.args = args
 
@@ -118,9 +116,7 @@ def forward(
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
         # cache
-        scatter_pos = positions[None, :, None, None].repeat(
-            bsz, 1, self.n_kv_heads, self.args.head_dim
-        )
+        scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
         self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk)
         self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv)
 
@@ -152,7 +148,7 @@ def forward(
 
 
 class FeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
 
         self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
@@ -178,7 +174,7 @@ def forward(self, x):
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
@@ -210,7 +206,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
 
 
 class Transformer(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.args = args
         self.vocab_size = args.vocab_size
@@ -219,18 +215,14 @@ def __init__(self, args: ModelArgs):
 
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
 
-        self.layers = torch.nn.ModuleList(
-            [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        )
+        self.layers = torch.nn.ModuleList([TransformerBlock(args=args) for _ in range(args.n_layers)])
 
         self.norm = RMSNorm(args.dim, eps=args.norm_eps)
 
         self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
 
         theta = self.args.rope_theta or 1000000.0
-        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta).to(
-            "cuda"
-        )
+        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta).to("cuda")
 
     def forward(
         self,
@@ -259,11 +251,9 @@ def forward(
         return self.output(self.norm(h)).float()
 
     @staticmethod
-    def from_folder(
-        folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16
-    ):
+    def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16):
         with open(Path(folder) / "params.json", "r") as f:
-            model_args = ModelArgs.from_dict(json.load(f))
+            model_args = TransformerArgs.from_dict(json.load(f))
         model_args.max_batch_size = max_batch_size
 
         model = Transformer(model_args)
@@ -288,9 +278,7 @@ def from_folder(
 
 
 def load_tokenizer(model_path: Path) -> MistralTokenizer:
-    tokenizer = [
-        f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")
-    ]
+    tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")]
     assert (
         len(tokenizer) > 0
     ), f"No tokenizer found in {model_path}, make sure to place a `tokenizer.model.[v1,v2,v3]` file in {model_path}."
@@ -304,12 +292,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:
 
 
 @torch.no_grad()
-def generate(
-    prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int
-):
-    encoded_prompts = [
-        tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts
-    ]
+def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
+    encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
     prompt_lens = [len(x) for x in encoded_prompts]
     min_prompt_len = min(prompt_lens)
     max_prompt_len = max(prompt_lens)
@@ -333,24 +317,18 @@ def generate(
     # decode
    generated = []
     all_logprobs = [
-        logprobs[:, :-1, :]
-        .gather(2, input_tokens[:, 1:min_prompt_len, None])
-        .squeeze(-1),
+        logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
     ]
     cur_pos = min_prompt_len
     for _ in range(max_tokens):
         next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
         if cur_pos < input_mask.shape[1]:
-            next_token = torch.where(
-                input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token
-            )
+            next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token)
         all_logprobs.append(
             logprobs[:, -1, :].gather(1, next_token[:, None]),
         )
         generated.append(next_token[:, None])
-        logits = model.forward(
-            next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)
-        )
+        logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
         logprobs = nn.functional.log_softmax(logits, dim=-1)
         cur_pos += 1
 

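Editor's note: the single-line `scatter_pos` rewrite above is purely cosmetic; the cache update still scatters the new key/value vectors into the preallocated cache at their absolute positions along the sequence dimension. A toy reproduction of that update on made-up shapes:

```py
import torch

# Made-up shapes: batch=2, 3 new tokens, 1 KV head, head_dim=4, cache of 8 positions
bsz, seqlen, n_kv_heads, head_dim = 2, 3, 1, 4
cache_k = torch.zeros(bsz, 8, n_kv_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

# Absolute positions of the new tokens within the sequence
positions = torch.tensor([2, 3, 4])

# Broadcast positions to xk's shape so scatter_ knows where each key vector
# goes along the sequence dimension (dim=1)
scatter_pos = positions[None, :, None, None].repeat(bsz, 1, n_kv_heads, head_dim)
cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk)

assert torch.equal(cache_k[:, 2:5], xk)  # the new keys landed at positions 2..4
```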
src/mistral_inference/args.py (+49, new file)

@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+from typing import Optional
+
+from simple_parsing.helpers import Serializable
+
+from mistral_inference.lora import LoraArgs
+from mistral_inference.moe import MoeArgs
+
+
+@dataclass
+class TransformerArgs(Serializable):
+    dim: int
+    n_layers: int
+    head_dim: int
+    hidden_dim: int
+    n_heads: int
+    n_kv_heads: int
+    norm_eps: float
+    vocab_size: int
+
+    max_batch_size: int = 0
+
+    # For rotary embeddings. If not set, will be inferred
+    rope_theta: Optional[float] = None
+    # If this is set, we will use MoE layers instead of dense layers.
+    moe: Optional[MoeArgs] = None
+    # If this is set, we will load LoRA linear layers instead of linear layers.
+    lora: Optional[LoraArgs] = None
+    model_type: str = "transformer"
+
+    def __post_init__(self):
+        assert self.model_type == "transformer", self.model_type
+
+
+@dataclass
+class MambaArgs(Serializable):
+    dim: int
+    n_layers: int
+    vocab_size: int
+    n_groups: int
+    rms_norm: bool
+    residual_in_fp32: bool
+    fused_add_norm: bool
+    pad_vocab_size_multiple: int
+    tie_embeddings: bool
+    model_type: str = "mamba"
+
+    def __post_init__(self):
+        assert self.model_type == "mamba", self.model_type

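Editor's note: the new `model_type` field is what distinguishes a transformer `params.json` from a Mamba one at load time. A minimal sketch of such a dispatch, using a hypothetical `load_args` helper (the commit's own loader may differ); `from_dict` comes from `Serializable`, as already used in `from_folder` above:

```py
import json
from pathlib import Path

from mistral_inference.args import MambaArgs, TransformerArgs


def load_args(folder: Path):
    # Hypothetical helper: read params.json and pick the args class based on
    # the "model_type" field that both dataclasses now declare.
    with open(folder / "params.json", "r") as f:
        params = json.load(f)
    if params.get("model_type", "transformer") == "mamba":
        return MambaArgs.from_dict(params)
    return TransformerArgs.from_dict(params)
```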