[misc] remove python function call for custom activation op (vllm-project#11885)

Co-authored-by: youkaichao <[email protected]>
2 people authored and Ubuntu committed Jan 19, 2025
1 parent 6c2d820 commit 75d652b
Showing 2 changed files with 46 additions and 60 deletions.
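The whole commit follows one pattern: instead of routing every forward call through a thin Python wrapper in vllm/_custom_ops.py, each activation layer now resolves the underlying kernel handle once in __init__ and stores it as self.op, so the hot path calls the op directly. A minimal standalone sketch of that pattern, using a plain Python function as a stand-in for the compiled torch.ops._C kernel (the real op needs a vLLM build):

```python
import torch


def _quick_gelu_stand_in(out: torch.Tensor, x: torch.Tensor) -> None:
    # Stand-in for torch.ops._C.gelu_quick; writes the result into `out`.
    out.copy_(x * torch.sigmoid(1.702 * x))


class QuickGELUSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Resolve the op handle once, at construction time, instead of
        # going through a module-level Python wrapper on every forward.
        self.op = _quick_gelu_stand_in

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out


if __name__ == "__main__":
    x = torch.randn(4, 8)
    print(QuickGELUSketch()(x).shape)  # torch.Size([4, 8])
```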
27 changes: 0 additions & 27 deletions vllm/_custom_ops.py
@@ -34,33 +34,6 @@ def register_fake(fn):
from torch.library import impl_abstract as register_fake


# activation ops
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_tanh_and_mul(out, x)


def fatrelu_and_mul(out: torch.Tensor,
                    x: torch.Tensor,
                    threshold: float = 0.0) -> None:
    torch.ops._C.fatrelu_and_mul(out, x, threshold)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_quick(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
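With the wrappers deleted, any remaining caller reaches for the registered op handle directly. A hedged sketch, assuming a vLLM build whose compiled _C extension is loaded and a platform (CUDA-like or CPU) where these kernels are registered:

```python
import torch
from vllm import _custom_ops  # noqa: F401  # importing is assumed to load torch.ops._C

x = torch.randn(8, 16)
out = torch.empty_like(x)
# Same in-place call the removed gelu_fast() wrapper used to forward to.
torch.ops._C.gelu_fast(out, x)
```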
79 changes: 46 additions & 33 deletions vllm/model_executor/layers/activation.py
@@ -30,6 +30,8 @@ class FatreluAndMul(CustomOp):
    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.fatrelu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
@@ -39,12 +39,10 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.fatrelu_and_mul(out, x, self.threshold)
        self.op(out, x, self.threshold)
        return out
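As output_shape above shows, FatreluAndMul takes a tensor whose last dimension is twice the output width and returns the halved shape. A usage sketch against the PyTorch-native path, assuming a working vLLM installation (the hidden sizes are made up for illustration):

```python
import torch

from vllm.model_executor.layers.activation import FatreluAndMul

layer = FatreluAndMul(threshold=0.05)
x = torch.randn(2, 16, 4096)   # last dim = 2 * output width
y = layer.forward_native(x)    # native path, no compiled kernel needed
print(y.shape)                 # torch.Size([2, 16, 2048])
```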


@@ -103,34 +103,35 @@ def __init__(self, approximate: str = "none"):
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
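The approximate mode now picks the kernel once in __init__ (gelu_and_mul vs. gelu_tanh_and_mul, from torch.ops._C on CUDA-like/CPU or ipex_ops on XPU), so forward_cuda and forward_xpu both reduce to the same self.op(out, x) call. A small usage sketch via the native path, again assuming a vLLM installation and an illustrative hidden size:

```python
import torch

from vllm.model_executor.layers.activation import GeluAndMul

act = GeluAndMul(approximate="tanh")  # binds the tanh-approximate kernel once
x = torch.randn(2, 11008)             # [..., 2 * intermediate size]
y = act.forward_native(x)             # F.gelu(first half, approximate="tanh") * second half
print(y.shape)                        # torch.Size([2, 5504])
```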
@@ -140,65 +141,77 @@ def extra_repr(self) -> str:
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_new(x)
        return self.op(x)


@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_fast(x)
        return self.op(x)


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for QuickGELU
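The net effect across all five activation classes is one less Python frame per forward call: the _custom_ops wrapper no longer sits between the layer and the kernel handle. The saving is ordinary Python call overhead, illustrated below with a toy timing sketch that uses a stand-in kernel rather than any vLLM op (numbers are machine-dependent):

```python
import timeit

import torch


def kernel(out: torch.Tensor, x: torch.Tensor) -> None:
    out.copy_(x)


def wrapper(out: torch.Tensor, x: torch.Tensor) -> None:
    # Extra Python indirection, analogous to the deleted _custom_ops helpers.
    kernel(out, x)


x = torch.randn(1024)
out = torch.empty_like(x)

direct = min(timeit.repeat(lambda: kernel(out, x), number=10_000, repeat=5))
wrapped = min(timeit.repeat(lambda: wrapper(out, x), number=10_000, repeat=5))
print(f"direct:  {direct:.4f}s")
print(f"wrapped: {wrapped:.4f}s")
```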
