@@ -103,6 +103,8 @@ def load_weights(self, weights: Iterable[tuple[str,
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
+            ("fused_qkv_a_proj", "q_a_proj", 0),
+            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
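For readers unfamiliar with the stacked-parameter convention used here: the checkpoint stores q_a_proj and kv_a_proj_with_mqa as separate tensors, while the fused module owns a single stacked parameter, and shard_id selects which output-row slice each checkpoint tensor fills. Below is a minimal sketch of that stacking idea; the sizes and the plain tensor copy are illustrative stand-ins, not the module's actual weight_loader.

import torch

# Illustrative sizes only; real values come from the model config.
q_lora_rank, kv_lora_rank, qk_rope_head_dim, hidden = 8, 4, 2, 16

# The checkpoint stores two separate down-projections ...
q_a = torch.randn(q_lora_rank, hidden)                       # q_a_proj
kv_a = torch.randn(kv_lora_rank + qk_rope_head_dim, hidden)  # kv_a_proj_with_mqa

# ... while the fused module owns one stacked parameter.
fused = torch.empty(q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden)

# shard_id maps to an (offset, length) slice of the fused output dim.
shard_offsets = {
    0: (0, q_lora_rank),                                # shard 0: q_a_proj
    1: (q_lora_rank, kv_lora_rank + qk_rope_head_dim),  # shard 1: kv_a_proj_with_mqa
}
for shard_id, src in ((0, q_a), (1, kv_a)):
    start, length = shard_offsets[shard_id]
    fused[start:start + length].copy_(src)

assert torch.equal(fused[:q_lora_rank], q_a)
assert torch.equal(fused[q_lora_rank:], kv_a)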
@@ -131,7 +133,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
-                name = name.replace(weight_name, param_name)
+                name_mapped = name.replace(weight_name, param_name)
+
+                # QKV fusion is optional; fall back to normal weight
+                # loading when it is disabled. If the fused parameter
+                # exists, adopt the mapped name.
+                if (param_name == "fused_qkv_a_proj"
+                        and name_mapped not in params_dict):
+                    continue
+                else:
+                    name = name_mapped
+
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
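The fallback logic in this second hunk can be summarized with a condensed, self-contained sketch. The dict-based params_dict and the load helper below are hypothetical stand-ins for the real load_weights loop and its weight_loader calls:

stacked_params_mapping = [
    ("fused_qkv_a_proj", "q_a_proj", 0),
    ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

def load(weights, params_dict):
    for name, w in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name_mapped = name.replace(weight_name, param_name)
            if param_name == "fused_qkv_a_proj" and name_mapped not in params_dict:
                continue  # fusion disabled: fall through to the normal path
            params_dict[name_mapped][shard_id] = w  # stand-in for weight_loader
            break
        else:
            params_dict[name] = w  # normal (unfused) loading

# Fusion enabled: both checkpoint keys land in the fused slot.
fused = {"attn.fused_qkv_a_proj": {}}
load([("attn.q_a_proj", "Wq"), ("attn.kv_a_proj_with_mqa", "Wkv")], fused)
assert sorted(fused["attn.fused_qkv_a_proj"]) == [0, 1]

# Fusion disabled: the mapped key is absent, so names pass through unchanged.
plain = {}
load([("attn.q_a_proj", "Wq"), ("attn.kv_a_proj_with_mqa", "Wkv")], plain)
assert set(plain) == {"attn.q_a_proj", "attn.kv_a_proj_with_mqa"}

The for/else mirrors the patched loop: continue skips a mapping entry, so a checkpoint key whose fused target does not exist in params_dict falls through to the plain per-parameter path, which is exactly why the patch only renames when name_mapped is present.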