Skip to content

Commit b8c9490

Browse files
committed
Review changes
Signed-off-by: Xin Yang <[email protected]>
1 parent e90d5ac commit b8c9490

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

vllm/model_executor/models/deepseek_eagle.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,14 @@ def forward(
8484
[self.enorm(input_embeds),
8585
self.hnorm(hidden_states)], dim=-1)
8686
hidden_states = self.fc(inputs)
87-
88-
# masking inputs at position=0
89-
hidden_states[positions == 0] = 0
9087
residual = None
9188
for layer in self.layers:
9289
hidden_states, residual = layer(
9390
positions,
9491
hidden_states,
9592
residual,
9693
)
97-
hidden_states, _ = self.norm(hidden_states, residual)
94+
hidden_states = residual + hidden_states
9895
return hidden_states, hidden_states
9996

10097
def load_weights(self, weights: Iterable[tuple[str,
@@ -103,6 +100,8 @@ def load_weights(self, weights: Iterable[tuple[str,
103100
# (param_name, shard_name, shard_id)
104101
("gate_up_proj", "gate_proj", 0),
105102
("gate_up_proj", "up_proj", 1),
103+
("fused_qkv_a_proj", "q_a_proj", 0),
104+
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
106105
]
107106

108107
# Params for weights, fp8 weight scales, fp8 activation scales
@@ -131,7 +130,17 @@ def load_weights(self, weights: Iterable[tuple[str,
131130
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
132131
if ("mlp.experts." in name) and name not in params_dict:
133132
continue
134-
name = name.replace(weight_name, param_name)
133+
name_mapped = name.replace(weight_name, param_name)
134+
135+
# QKV fusion is optional, fall back to normal
136+
# weight loading if it's not enabled
137+
# When fusion is enabled, switch to the mapped (fused) parameter name.
138+
if ((param_name == "fused_qkv_a_proj")
139+
and name_mapped not in params_dict):
140+
continue
141+
else:
142+
name = name_mapped
143+
135144
# Skip loading extra bias for GPTQ models.
136145
if name.endswith(".bias") and name not in params_dict:
137146
continue

0 commit comments

Comments
 (0)