-
Notifications
You must be signed in to change notification settings - Fork 322
Expand file tree
/
Copy pathmodel.py
More file actions
47 lines (37 loc) · 1.72 KB
/
model.py
File metadata and controls
47 lines (37 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from lightllm.models.qwen3_moe.model import Qwen3MOEModel
from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight
from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
from lightllm.common.basemodel import TpPartBaseModel
class Qwen3MOEMTPModel(Qwen3MOEModel):
    """Qwen3-MoE MTP model that borrows shared state from a main model.

    The instance is built from the same ``kvargs`` dict as its parent class,
    plus two extra entries that are removed from the dict before the parent
    constructor runs:

    * ``"main_model"`` (required): the model whose ``_cos_cached`` /
      ``_sin_cached`` tensors, ``req_manager``, ``mem_manager`` and
      ``wte_weight_`` / ``lm_head_weight_`` / ``final_norm_weight_`` weights
      are reused by reference.
    * ``"mem_layer_start"`` (optional, default ``0``): offset added to the
      ``layer_num_`` of every inference layer of this model.

    NOTE(review): the ``_init_*`` methods below are presumably hooks invoked
    by the parent constructor -- confirm against ``TpPartBaseModel``.
    """

    pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight
    pre_layer_infer_class = Deepseek3MTPPreLayerInfer

    def __init__(self, kvargs: dict):
        """Extract the MTP-specific entries from ``kvargs``, then run the parent init.

        ``kvargs`` is mutated: ``"main_model"`` and ``"mem_layer_start"`` are
        popped before the dict is handed to the parent constructor.
        """
        # Must happen first so that self.main_model / self.mem_layer_start
        # already exist while the parent constructor is running.
        self._pre_init(kvargs)
        super().__init__(kvargs)

    def _pre_init(self, kvargs: dict):
        """Pop the MTP-only settings out of ``kvargs`` and keep them on ``self``.

        Raises:
            KeyError: if ``kvargs`` has no ``"main_model"`` entry.
        """
        self.main_model: TpPartBaseModel = kvargs.pop("main_model")
        self.mem_layer_start = kvargs.pop("mem_layer_start", 0)

    def _init_custom(self):
        """Reuse the main model's cos/sin caches instead of building new ones."""
        main = self.main_model
        self._cos_cached = main._cos_cached
        self._sin_cached = main._sin_cached

    def _init_req_manager(self):
        """Share the main model's request manager."""
        self.req_manager = self.main_model.req_manager

    def _init_mem_manager(self):
        """Share the main model's memory manager."""
        self.mem_manager = self.main_model.mem_manager

    def _init_weights(self):
        """Run the parent weight init, then alias three weights from the main model."""
        super()._init_weights()
        own_weights = self.pre_post_weight
        main_weights = self.main_model.pre_post_weight
        # These are plain reference assignments: both models end up pointing
        # at the same weight objects.
        own_weights.wte_weight_ = main_weights.wte_weight_
        own_weights.lm_head_weight_ = main_weights.lm_head_weight_
        own_weights.final_norm_weight_ = main_weights.final_norm_weight_

    def _init_infer_layer(self):
        """Build the inference layers via the parent, then offset their layer indices."""
        super()._init_infer_layer()
        # Shift every layer_num_ in self.layers_infer by mem_layer_start.
        offset = self.mem_layer_start
        for infer_layer in self.layers_infer:
            infer_layer.layer_num_ = infer_layer.layer_num_ + offset