
Commit 4b03761

fix test

1 parent: f268d75

1 file changed (+8, -8 lines)

tests/ops/test_parallel_delta.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -2,10 +2,11 @@
 
 import pytest
 import torch
+import torch.nn.functional as F
 
 from fla.ops.delta_rule.parallel import parallel_delta_rule
-from fla.ops.delta_rule.wy_fast import fwd_prepare_T
-from fla.utils import device, device_platform
+from fla.ops.delta_rule.wy_fast import fwd_prepare_T, naive_delta_rule_parallel
+from fla.utils import assert_close, device, device_platform
 
 # IMPORTANT NOTE ON TENSOR FORMATS:
 # While the documentation for some functions states inputs should be in [B, T, H, K] format,
```
```diff
@@ -26,7 +27,6 @@
         for test in [
             (1, 2, 128, 64, torch.float16),
             (2, 4, 128, 32, torch.float16),
-            (2, 4, 64, 128, torch.float16),
         ]
     ]
 )
```
```diff
@@ -46,7 +46,7 @@ def test_parallel_delta_rule(
 
     # Generate test data
     q = torch.randn(B, H, T, K, dtype=dtype, device=device)
-    k = torch.randn(B, H, T, K, dtype=dtype, device=device)
+    k = F.normalize(torch.randn(B, H, T, K, dtype=dtype, device=device), p=2, dim=-1).to(dtype)
    v = torch.randn(B, H, T, K, dtype=dtype, device=device)
     beta = torch.randn(B, H, T, dtype=dtype, device=device).sigmoid()
     scale = 1.0 / (K ** 0.5)
```
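
The key change in this hunk is L2-normalizing `k`, which is what makes the reference comparison in the next hunk feasible. A minimal sketch of the sequential delta-rule recurrence shows why (illustrative only; `delta_rule_reference` and its layout are assumptions, not fla code): with `k_t` unit-norm and `beta` in (0, 1) from the sigmoid, the state transition `S ← (I − β k kᵀ) S + β k vᵀ` has eigenvalues in `[1 − β, 1]`, so the state stays bounded instead of drifting into the float16 overflow/NaN territory the old test comments described.

```python
# Illustrative sketch only: a sequential delta-rule reference in the
# [B, H, T, K] layout the test uses. All names here are hypothetical.
import torch

def delta_rule_reference(q, k, v, beta, scale):
    # q, k, v: [B, H, T, K]; beta: [B, H, T]; k is assumed L2-normalized
    B, H, T, K = q.shape
    S = q.new_zeros(B, H, K, K).float()  # running state, kept in fp32
    o = torch.empty_like(q)
    for t in range(T):
        k_t = k[:, :, t].float()           # [B, H, K]
        v_t = v[:, :, t].float()           # [B, H, K]
        b_t = beta[:, :, t, None].float()  # [B, H, 1]
        # "delta" update: move the value stored under k_t a beta-weighted
        # step towards v_t; with ||k_t|| = 1 and b_t in (0, 1) this update
        # is non-expansive, so S cannot blow up
        v_old = torch.einsum('bhk,bhkv->bhv', k_t, S)
        S = S + torch.einsum('bhk,bhv->bhkv', k_t, b_t * (v_t - v_old))
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q[:, :, t].float() * scale, S).to(q.dtype)
    return o
```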
```diff
@@ -74,10 +74,10 @@ def test_parallel_delta_rule(
     else:
         assert attn_parallel is None
 
-    # SKIPPED: Comparison with naive_delta_rule_parallel due to NaN issues
-    # This requires fixing the naive implementation or replacing with another reference implementation
-    # For now, we just verify that the parallel implementation runs without errors
-    # assert_close('attn', attn_naive, attn_parallel, 0.01)
+    o_naive, attn_naive = naive_delta_rule_parallel(q.clone(), k.clone(), v.clone(), beta.clone())
+
+    assert_close(' o', o_parallel, o_naive, 0.01)
+    assert_close('attn', attn_naive, attn_parallel, 0.01)
 
 
 @pytest.mark.skipif(
```
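
For readers without the repo at hand, here is a plausible stand-in for the `fla.utils.assert_close` helper called above (an assumption about its behavior, not its actual source): it takes a label, two tensors, and a tolerance, and fails with the label when the relative error exceeds the tolerance.

```python
# Assumed stand-in for fla.utils.assert_close; the real helper may differ.
import torch

def assert_close(name: str, ref: torch.Tensor, tri: torch.Tensor, ratio: float):
    # relative error ||ref - tri|| / ||ref||, computed in fp32 for stability
    ref32, tri32 = ref.float(), tri.float()
    err = ((ref32 - tri32).norm() / ref32.norm().clamp(min=1e-8)).item()
    assert err <= ratio, f"{name} mismatch: relative error {err:.4f} > {ratio}"
```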
