We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ba48b5b commit fd92c0cCopy full SHA for fd92c0c
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
701
device=hidden_states.device,
702
dtype=hidden_states.dtype)
703
704
- compute_type = (tl.bfloat16
705
- if hidden_states.dtype == torch.bfloat16 else tl.float16)
+ if hidden_states.dtype == torch.bfloat16:
+ compute_type = tl.bfloat16
706
+ elif hidden_states.dtype == torch.float16:
707
+ compute_type = tl.float16
708
+ elif hidden_states.dtype == torch.float32:
709
+ compute_type = tl.float32
710
+ else:
711
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
712
713
if inplace:
714
out_hidden_states = hidden_states
0 commit comments