Skip to content

Commit c776771

Browse files
committed
2 parents 77531f3 + c25e216 commit c776771

File tree

4 files changed

+29
-20
lines changed

4 files changed

+29
-20
lines changed

README.md

+1-6
Original file line number | Diff line number | Diff line change
@@ -36,16 +36,11 @@ To test the correctness of NSA:
3636
pytest tests/test_nsa.py
3737
```
3838

39-
To validate the correctness of NSA with top‑k selection (ignoring the output from the compressed attention), run the command below. Please note that the initial trial may take some time as the kernel compiles, but subsequent runs will be faster.
39+
To validate the correctness of NSA with top‑k selection, run the command below. Please note that the initial trial may take some time as the kernel compiles, but subsequent runs will be faster.
4040
```py
4141
pytest tests/test_nsa_with_compression.py
4242
```
4343

44-
To verify the correctness of the top‑k selection, where sampling Q and K from a uniform distribution produces similar importance scores (resulting in slight variations in the top‑k selection), we validate this component separately. To run the test, execute:
45-
```py
46-
pytest tests/test_topk.py
47-
```
48-
4944
To measure the efficiency of NSA:
5045
```py
5146
python benchmark/benchmark_nsa.py

native_sparse_attention/ops/parallel.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -1367,7 +1367,7 @@ def parallel_nsa_bwd(
13671367
dk = dk.sum(0)
13681368
return dq, dk, dv
13691369

1370-
1370+
@torch.compile
13711371
class ParallelNSAFunction(torch.autograd.Function):
13721372

13731373
@staticmethod
@@ -1448,7 +1448,6 @@ def parallel_nsa_compression(
14481448
)
14491449

14501450

1451-
@torch.compile
14521451
def parallel_nsa(
14531452
q: torch.Tensor,
14541453
k: torch.Tensor,

tests/test_nsa.py

+12-6
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,12 @@ def test_parallel(
5252
torch.manual_seed(42)
5353
os.environ['TRITON_F32_DEFAULT'] = 'ieee'
5454

55-
q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
56-
k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
57-
v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
55+
perm_q = torch.randperm(T, device='cuda')
56+
perm_k = torch.randperm(T, device='cuda')
57+
perm_v = torch.randperm(T, device='cuda')
58+
q = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
59+
k = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
60+
v = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
5861
g_slc = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
5962
g_swa = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
6063
do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')
@@ -147,9 +150,12 @@ def test_parallel_varlen(
147150
torch.tensor([T], dtype=torch.long)
148151
], 0).cuda().sort()[0]
149152
# seq-first required for inputs with variable lengths
150-
q = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
151-
k = torch.randn((1, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
152-
v = torch.randn((1, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
153+
perm_q = torch.randperm(T, device='cuda')
154+
perm_k = torch.randperm(T, device='cuda')
155+
perm_v = torch.randperm(T, device='cuda')
156+
q = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
157+
k = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
158+
v = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
153159
g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
154160
g_swa = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
155161
do = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda')

tests/test_nsa_with_compression.py

+15-6
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import pytest
66
import torch
77

8-
from native_sparse_attention.ops.naive import naive_nsa
8+
from native_sparse_attention.ops.naive import naive_nsa_with_compression
99
from native_sparse_attention.ops.parallel import parallel_nsa_with_compression
1010

1111

@@ -50,9 +50,13 @@ def test_parallel(
5050
torch.manual_seed(42)
5151
os.environ['TRITON_F32_DEFAULT'] = 'ieee'
5252

53-
q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
54-
k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
55-
v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
53+
54+
perm_q = torch.randperm(T, device='cuda')
55+
perm_k = torch.randperm(T, device='cuda')
56+
perm_v = torch.randperm(T, device='cuda')
57+
q = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
58+
k = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
59+
v = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
5660
g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
5761
g_slc = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
5862
g_swa = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
@@ -77,27 +81,31 @@ def test_parallel(
7781
tri_dk, k.grad = k.grad.clone(), None
7882
tri_dv, v.grad = v.grad.clone(), None
7983
tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
84+
tri_dg_cmp, g_cmp.grad = g_cmp.grad.clone(), None
8085
if window_size > 0:
8186
tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None
8287

83-
ref = naive_nsa(
88+
ref, ref_topk = naive_nsa_with_compression(
8489
q=q,
8590
k=k,
8691
v=v,
92+
g_cmp=g_cmp,
8793
g_slc=g_slc,
8894
g_swa=g_swa,
89-
block_indices=tri_topk,
9095
block_counts=block_counts,
9196
block_size=block_size,
9297
window_size=window_size,
9398
scale=scale
9499
)
95100

101+
print((ref_topk != tri_topk[:, :, :, :ref_topk.shape[-1]]).float().mean())
102+
96103
ref.backward(do)
97104
ref_dq, q.grad = q.grad.clone(), None
98105
ref_dk, k.grad = k.grad.clone(), None
99106
ref_dv, v.grad = v.grad.clone(), None
100107
ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
108+
ref_dg_cmp, g_cmp.grad = g_cmp.grad.clone(), None
101109
if window_size > 0:
102110
ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None
103111

@@ -106,6 +114,7 @@ def test_parallel(
106114
assert_close("dk", ref_dk, tri_dk, 0.005)
107115
assert_close("dv", ref_dv, tri_dv, 0.005)
108116
assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005)
117+
assert_close("dg_cmp", ref_dg_cmp, tri_dg_cmp, 0.005)
109118
if window_size > 0:
110119
assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)
111120

0 commit comments

Comments (0)