Skip to content

Commit 1fd1f32

Browse files
authored
fix dsv3 gate scaling (#3263)
1 parent 292793a · commit 1fd1f32

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

lmdeploy/pytorch/models/deepseek_v2.py

+9 -2
Original file line number | Diff line number | Diff line change
@@ -325,7 +325,14 @@ def forward(self, hidden_states: torch.Tensor):
325325
topk_weight = scores.gather(1, topk_idx)
326326
else:
327327
raise RuntimeError(f'Unsupported topk_method: {self.topk_method}')
328-
if not self.renormalize:
328+
329+
if self.renormalize:
330+
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
331+
topk_weight = topk_weight / denominator
332+
if not topk_weight.is_contiguous():
333+
topk_weight = topk_weight.contiguous()
334+
335+
if not self.renormalize or self.topk_method == 'noaux_tc':
329336
topk_weight = topk_weight * self.routed_scaling_factor
330337
return topk_weight, topk_idx
331338

@@ -354,7 +361,7 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device
354361
self.ffn_dim,
355362
self.num_experts,
356363
top_k=self.top_k,
357-
renormalize=self.renormalize,
364+
renormalize=False,
358365
dtype=dtype,
359366
device=device,
360367
all_reduce=False,

0 commit comments

Comments
 (0)