Skip to content

Commit 096dd4a

Browse files
add permute cols opcheck
1 parent c452a86 commit 096dd4a

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

tests/kernels/test_permute_cols.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
import torch
3+
4+
from tests.kernels.utils import opcheck
5+
from vllm._custom_ops import permute_cols
6+
7+
8+
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
9+
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
10+
def test_permute_cols(shape, dtype):
11+
x = torch.randn(shape, dtype=dtype).cuda()
12+
perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
13+
opcheck(torch.ops._C.permute_cols, (x, perm))
14+
y = permute_cols(x, perm)
15+
torch.testing.assert_close(y, x[:, perm])

vllm/_custom_ops.py

+12
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,18 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
576576
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
577577

578578

579+
# TODO: has to be a better way to do this
580+
try:
581+
torch.ops._C.permute_cols # noqa B018
582+
583+
@torch.library.register_fake("_C::permute_cols")
584+
def _permute_cols_fake(a: torch.Tensor,
585+
perm: torch.Tensor) -> torch.Tensor:
586+
return torch.empty_like(a)
587+
except Exception:
588+
pass
589+
590+
579591
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
580592
return torch.ops._C.permute_cols(a, perm)
581593

0 commit comments

Comments
 (0)