@@ -22,7 +22,7 @@ class MoeArgs(Serializable):
 
 
 @dataclass
-class ModelArgs(Serializable):
+class TransformerArgs(Serializable):
     dim: int
     n_layers: int
     head_dim: int
@@ -80,7 +80,7 @@ def apply_rotary_emb(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.args = args
 
@@ -144,9 +144,7 @@ def forward(
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
         # Update cache
-        scatter_pos = positions[None, :, None, None].repeat(
-            bsz, 1, self.n_kv_heads, self.args.head_dim
-        )
+        scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
         cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk)
         cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv)
 
@@ -179,7 +177,7 @@ def forward(
 
 
 class FeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
         self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
@@ -214,9 +212,7 @@ def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs)
     def forward(self, inputs: torch.Tensor):
         inputs_squashed = inputs.view(-1, inputs.shape[-1])
         gate_logits = self.gate(inputs_squashed)
-        weights, selected_experts = torch.topk(
-            gate_logits, self.args.num_experts_per_tok
-        )
+        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
         weights = nn.functional.softmax(
             weights,
             dim=1,
@@ -225,14 +221,12 @@ def forward(self, inputs: torch.Tensor):
         results = torch.zeros_like(inputs_squashed)
         for i, expert in enumerate(self.experts):
             batch_idx, nth_expert = torch.where(selected_experts == i)
-            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
-                inputs_squashed[batch_idx]
-            )
+            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs_squashed[batch_idx])
         return results.view_as(inputs)
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
@@ -270,7 +264,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
 class Transformer(nn.Module):
     def __init__(
         self,
-        args: ModelArgs,
+        args: TransformerArgs,
         pipeline_rank: int = 0,
         num_pipeline_ranks: int = 1,
     ):
@@ -316,13 +310,9 @@ def freqs_cis(self) -> torch.Tensor:
         # from the module's dtype means we cannot register it as a buffer
         if self._precomputed_freqs_cis is None:
             theta = self.args.rope_theta or 1000000.0
-            self._precomputed_freqs_cis = precompute_freqs_cis(
-                self.args.head_dim, 128_000, theta
-            )
+            self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta)
         if self._precomputed_freqs_cis.device != self.device:
-            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(
-                device=self.device
-            )
+            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
         return self._precomputed_freqs_cis
 
     def forward(
@@ -341,9 +331,7 @@ def forward(
             assert h.shape == (bsz, seqlen, self.args.dim)
             assert h.dtype == self.dtype
         else:
-            h = torch.empty(
-                bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype
-            )
+            h = torch.empty(bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype)
             torch.distributed.recv(h, src=self.pipeline_rank - 1)
 
         mask: Optional[torch.Tensor] = None
@@ -361,9 +349,7 @@ def forward(
 
         if self.pipeline_rank < self.num_pipeline_ranks - 1:
             torch.distributed.send(h, dst=self.pipeline_rank + 1)
-            outs = torch.empty(
-                *h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype
-            )
+            outs = torch.empty(*h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype)
         else:
             assert self.output is not None
             assert self.norm is not None
@@ -422,7 +408,7 @@ def from_folder(
         dtype=torch.float16,
     ) -> "Transformer":
         with open(folder / "params.json", "r") as f:
-            model_args = ModelArgs.from_dict(json.load(f))
+            model_args = TransformerArgs.from_dict(json.load(f))
         model_args.max_batch_size = max_batch_size
         model_args.max_seq_len = max_seq_len
         if num_pipeline_ranks > 1:
@@ -457,9 +443,7 @@ def from_folder(
 
 
 def load_tokenizer(model_path: Path) -> MistralTokenizer:
-    tokenizer = [
-        f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")
-    ]
+    tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")]
     assert (
         len(tokenizer) == 1
     ), f"Multiple tokenizers {', '.join(tokenizer)} found in `model_path`, make sure to only have one tokenizer"
@@ -470,12 +454,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:
 
 
 @torch.no_grad()
-def generate(
-    prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int
-):
-    encoded_prompts = [
-        tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts
-    ]
+def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
+    encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
     prompt_lens = [len(x) for x in encoded_prompts]
     min_prompt_len = min(prompt_lens)
     max_prompt_len = max(prompt_lens)
@@ -498,23 +478,17 @@ def generate(
     # decode
     generated = []
     all_logprobs = [
-        logprobs[:, :-1, :]
-        .gather(2, input_tokens[:, 1:min_prompt_len, None])
-        .squeeze(-1),
+        logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
     ]
     for cur_pos in range(min_prompt_len, max_tokens):
         next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
         if cur_pos < input_mask.shape[1]:
-            next_token = torch.where(
-                input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token
-            )
+            next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token)
         all_logprobs.append(
             logprobs[:, -1, :].gather(1, next_token[:, None]),
         )
         generated.append(next_token[:, None])
-        logits = model.forward(
-            next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)
-        )
+        logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
         logprobs = nn.functional.log_softmax(logits, dim=-1)
 
     all_logprobs_merged = torch.cat(all_logprobs, 1)
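
Aside from the ModelArgs -> TransformerArgs rename, the hunks above only reflow multi-line calls onto single lines; behaviour is unchanged. For reference, the top-k routing touched in the MoeLayer.forward hunks boils down to the following standalone sketch; the toy nn.Linear experts and the sizes are made up for illustration, not the commit's FeedForward experts or MoeArgs values.

import torch
from torch import nn

# Toy sizes for illustration only.
num_experts, num_experts_per_tok, dim = 4, 2, 8
experts = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(num_experts)])
gate = nn.Linear(dim, num_experts, bias=False)

inputs = torch.randn(2, 3, dim)                      # (batch, seq, dim)
inputs_squashed = inputs.view(-1, inputs.shape[-1])  # (batch * seq, dim)
gate_logits = gate(inputs_squashed)
weights, selected_experts = torch.topk(gate_logits, num_experts_per_tok)
weights = nn.functional.softmax(weights, dim=1, dtype=torch.float).type_as(inputs)

results = torch.zeros_like(inputs_squashed)
for i, expert in enumerate(experts):
    # Rows routed to expert i, and which top-k slot it occupies in each row
    # (used to pick the matching gate weight).
    batch_idx, nth_expert = torch.where(selected_experts == i)
    results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs_squashed[batch_idx])

print(results.view_as(inputs).shape)  # torch.Size([2, 3, 8])

Each token thus mixes the outputs of its num_experts_per_tok selected experts, weighted by the softmax over their gate logits.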