
Commit fd92c0c

shaochangxu authored and mzusman committed
[Bugfix] fused_experts_impl wrong compute type for float32 (vllm-project#11921)
Signed-off-by: shaochangxu.scx <[email protected]>
Co-authored-by: shaochangxu.scx <[email protected]>
1 parent ba48b5b commit fd92c0c

File tree

1 file changed: +8 −2 lines

vllm/model_executor/layers/fused_moe/fused_moe.py (+8, −2)
@@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
 
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+    if hidden_states.dtype == torch.bfloat16:
+        compute_type = tl.bfloat16
+    elif hidden_states.dtype == torch.float16:
+        compute_type = tl.float16
+    elif hidden_states.dtype == torch.float32:
+        compute_type = tl.float32
+    else:
+        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
 
     if inplace:
         out_hidden_states = hidden_states
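
For reference, a minimal standalone sketch of the dtype dispatch this patch introduces. The helper name select_compute_type is an assumption for illustration (in vLLM the logic sits inline in fused_experts_impl); the mapping mirrors the diff above: bfloat16 and float16 map to their Triton counterparts, float32 now maps to tl.float32 instead of falling back to tl.float16, and any other dtype raises.

# Minimal sketch (assumption: a standalone helper; the real change is inline
# in fused_experts_impl in vllm/model_executor/layers/fused_moe/fused_moe.py).
import torch
import triton.language as tl

def select_compute_type(dtype: torch.dtype):
    """Map a torch dtype to the Triton compute type used by the fused MoE kernel."""
    if dtype == torch.bfloat16:
        return tl.bfloat16
    elif dtype == torch.float16:
        return tl.float16
    elif dtype == torch.float32:
        # Before this fix, float32 inputs fell through to tl.float16.
        return tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {dtype}")

# Example: float32 activations now run the kernel in float32.
assert select_compute_type(torch.float32) == tl.float32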
