diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index afb350591e562..d04cbbc0a9eed 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -34,33 +34,6 @@ def register_fake(fn):
         from torch.library import impl_abstract as register_fake
 
 
-# activation ops
-def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_and_mul(out, x)
-
-
-def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_tanh_and_mul(out, x)
-
-
-def fatrelu_and_mul(out: torch.Tensor,
-                    x: torch.Tensor,
-                    threshold: float = 0.0) -> None:
-    torch.ops._C.fatrelu_and_mul(out, x, threshold)
-
-
-def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_fast(out, x)
-
-
-def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_new(out, x)
-
-
-def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_quick(out, x)
-
-
 # page attention ops
 def paged_attention_v1(
     out: torch.Tensor,
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 32456fee06a28..2475190d197d3 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -30,6 +30,8 @@ class FatreluAndMul(CustomOp):
     def __init__(self, threshold: float = 0.):
         super().__init__()
         self.threshold = threshold
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.fatrelu_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -39,12 +41,10 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x1 * x2
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.fatrelu_and_mul(out, x, self.threshold)
+        self.op(out, x, self.threshold)
         return out
 
 
@@ -103,6 +103,17 @@ def __init__(self, approximate: str = "none"):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            if approximate == "none":
+                self.op = torch.ops._C.gelu_and_mul
+            elif approximate == "tanh":
+                self.op = torch.ops._C.gelu_tanh_and_mul
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            if approximate == "none":
+                self.op = ipex_ops.gelu_and_mul
+            else:
+                self.op = ipex_ops.gelu_tanh_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -110,27 +121,17 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def extra_repr(self) -> str:
@@ -140,6 +141,14 @@ def extra_repr(self) -> str:
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_new
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_new
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
@@ -147,58 +156,62 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
                                            (x + 0.044715 * torch.pow(x, 3.0))))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_new(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_new(x)
+        return self.op(x)
 
 
 @CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_fast
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_fast
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_fast(x)
+        return self.op(x)
 
 
 @CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_quick
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_quick
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return x * torch.sigmoid(1.702 * x)
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out
 
     # TODO implement forward_xpu for QuickGELU
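
The refactor applies one pattern throughout: each activation CustomOp now resolves its platform-specific kernel once in __init__ (torch.ops._C.* on CUDA-alike and CPU platforms, ipex_ops.* on XPU) and the forward paths call the bound self.op directly, instead of importing _custom_ops / _ipex_ops and re-branching on every call. That is also what makes the thin Python wrappers removed from vllm/_custom_ops.py unnecessary. The sketch below illustrates the same bind-once idea outside of vLLM; it is a minimal stand-alone example, not vLLM code, and the names PortableQuickGELU and _quick_gelu_out are invented for illustration (in the diff the bound callable is, e.g., torch.ops._C.gelu_quick).

import torch
import torch.nn as nn


def _quick_gelu_out(out: torch.Tensor, x: torch.Tensor) -> None:
    # Out-parameter stand-in for a custom kernel such as torch.ops._C.gelu_quick:
    # writes the QuickGELU result into a preallocated output tensor.
    out.copy_(x * torch.sigmoid(1.702 * x))


class PortableQuickGELU(nn.Module):
    """Illustrative only: binds its backend op once, like the CustomOps above."""

    def __init__(self) -> None:
        super().__init__()
        # Resolve the backend a single time here instead of importing and
        # branching inside forward() on every call.
        self.op = _quick_gelu_out if torch.cuda.is_available() else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.op is None:
            # Pure-PyTorch fallback, analogous to forward_native in the diff.
            return x * torch.sigmoid(1.702 * x)
        out = torch.empty_like(x)
        self.op(out, x)
        return out


if __name__ == "__main__":
    x = torch.randn(4, 8)
    assert torch.allclose(PortableQuickGELU()(x), x * torch.sigmoid(1.702 * x))

Binding the callable up front keeps forward() free of per-call imports and string comparisons, which is the behavior change visible in each forward_cuda/forward_xpu hunk above.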