@@ -84,17 +84,14 @@ def forward(
             [self.enorm(input_embeds),
              self.hnorm(hidden_states)], dim=-1)
         hidden_states = self.fc(inputs)
-
-        # masking inputs at position=0
-        hidden_states[positions == 0] = 0
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states = residual + hidden_states
         return hidden_states, hidden_states
 
     def load_weights(self, weights: Iterable[tuple[str,
@@ -103,6 +100,8 @@ def load_weights(self, weights: Iterable[tuple[str,
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
+            ("fused_qkv_a_proj", "q_a_proj", 0),
+            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -131,7 +130,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
-                name = name.replace(weight_name, param_name)
+                name_mapped = name.replace(weight_name, param_name)
+
+                # QKV fusion is optional: fall back to normal weight
+                # loading if it is not enabled; if fusion is enabled,
+                # update the name to the fused parameter.
+                if ((param_name == "fused_qkv_a_proj")
+                        and name_mapped not in params_dict):
+                    continue
+                else:
+                    name = name_mapped
+
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
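
For context on the weight-loading change above: the two new stacked-params entries map the checkpoint's separate q_a_proj and kv_a_proj_with_mqa shards onto a single fused_qkv_a_proj parameter, and the loader only applies that mapping when the fused parameter actually exists in params_dict; otherwise the original name is kept so unfused loading still works. The following standalone sketch (not vLLM's actual loader; the weight names and params_dict contents are illustrative) shows that fallback logic in isolation.

    # Standalone sketch of the fused-QKV fallback; illustrative only.
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
        ("fused_qkv_a_proj", "q_a_proj", 0),
        ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
    ]

    def resolve_stacked_name(name, params_dict):
        """Return (resolved_name, shard_id); shard_id is None for non-stacked loads."""
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name_mapped = name.replace(weight_name, param_name)
            # QKV fusion is optional: if the fused parameter is absent from
            # params_dict, skip this mapping so the weight loads under its
            # original, unfused name.
            if param_name == "fused_qkv_a_proj" and name_mapped not in params_dict:
                continue
            return name_mapped, shard_id
        return name, None

    # Fusion enabled: the q_a_proj shard is routed into the fused parameter.
    fused = {"layers.0.self_attn.fused_qkv_a_proj.weight": object()}
    assert resolve_stacked_name("layers.0.self_attn.q_a_proj.weight", fused) == \
        ("layers.0.self_attn.fused_qkv_a_proj.weight", 0)

    # Fusion disabled: the name is left untouched and loads normally.
    assert resolve_stacked_name("layers.0.self_attn.q_a_proj.weight", {}) == \
        ("layers.0.self_attn.q_a_proj.weight", None)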