Commit 3d3f5e7

support moe
1 parent 00c1bde commit 3d3f5e7

File tree

  • vllm/model_executor/layers/fused_moe

1 file changed: +31 −3 lines
vllm/model_executor/layers/fused_moe/layer.py

@@ -71,6 +71,15 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+        if current_platform.is_cpu():
+            import intel_extension_for_pytorch as ipex
+            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(layer.w13_weight,
+                                                             layer.w2_weight,
+                                                             use_prepack=True)
+
     def apply(
         self,
         layer: torch.nn.Module,
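
For context, here is a minimal unfused sketch of the computation that ipex.llm.modules.GatedMLPMOE presumably fuses. This is an assumption inferred from the w13_weight/w2_weight naming above (per-expert gate and up projections packed into w13, down projection in w2, SwiGLU-style activation), not code from this commit:

# Hypothetical reference (not from this commit): the unfused MoE forward that
# the IPEX fused op is assumed to replace.
import torch
import torch.nn.functional as F

def moe_reference(x, w13, w2, router_logits, top_k, renormalize):
    # x: [tokens, hidden]; w13: [experts, 2*inter, hidden]
    # w2: [experts, hidden, inter]; router_logits: [tokens, experts]
    weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for e in range(w13.shape[0]):
        token_idx, slot = (ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue  # no tokens routed to this expert
        gate, up = (x[token_idx] @ w13[e].t()).chunk(2, dim=-1)
        h = (F.silu(gate) * up) @ w2[e].t()  # SwiGLU, then down projection
        out.index_add_(0, token_idx,
                       h * weights[token_idx, slot].unsqueeze(-1))
    return out

The use_prepack=True argument in the diff suggests the IPEX module additionally repacks the weights at load time into a layout the fused CPU kernel prefers, which is why the construction happens in process_weights_after_loading rather than in the forward pass.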
@@ -122,9 +131,28 @@ def forward_cuda(
             topk_ids=topk_ids,
             inplace=True)
 
-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError(
-            "The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        **kwargs,
+    ):
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )
 
     def forward_tpu(
         self,

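How forward_cpu gets reached is not shown in this diff; a rough sketch, assuming a CustomOp-style layer that binds one backend implementation per platform (the class name and _forward attribute here are hypothetical, not vLLM's actual code):

# Hypothetical sketch of the per-platform dispatch that makes the new
# forward_cpu reachable: the layer selects one backend method at init time.
import torch
from vllm.platforms import current_platform

class DispatchSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        if current_platform.is_cpu():
            self._forward = self.forward_cpu   # new IPEX-fused path
        elif current_platform.is_tpu():
            self._forward = self.forward_tpu
        else:
            self._forward = self.forward_cuda  # default GPU path

    def forward(self, *args, **kwargs):
        return self._forward(*args, **kwargs)

    # Stubs standing in for the real backend implementations above.
    def forward_cpu(self, *args, **kwargs): ...
    def forward_cuda(self, *args, **kwargs): ...
    def forward_tpu(self, *args, **kwargs): ...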