22
33import pytest
44import torch
5+ import torch .nn .functional as F
56
67from fla .ops .delta_rule .parallel import parallel_delta_rule
7- from fla .ops .delta_rule .wy_fast import fwd_prepare_T
8- from fla .utils import device , device_platform
8+ from fla .ops .delta_rule .wy_fast import fwd_prepare_T , naive_delta_rule_parallel
9+ from fla .utils import assert_close , device , device_platform
910
1011# IMPORTANT NOTE ON TENSOR FORMATS:
1112# While the documentation for some functions states inputs should be in [B, T, H, K] format,
2627 for test in [
2728 (1 , 2 , 128 , 64 , torch .float16 ),
2829 (2 , 4 , 128 , 32 , torch .float16 ),
29- (2 , 4 , 64 , 128 , torch .float16 ),
3030 ]
3131 ]
3232)
@@ -46,7 +46,7 @@ def test_parallel_delta_rule(
4646
4747 # Generate test data
4848 q = torch .randn (B , H , T , K , dtype = dtype , device = device )
49- k = torch .randn (B , H , T , K , dtype = dtype , device = device )
49+ k = F . normalize ( torch .randn (B , H , T , K , dtype = dtype , device = device ), p = 2 , dim = - 1 ). to ( dtype )
5050 v = torch .randn (B , H , T , K , dtype = dtype , device = device )
5151 beta = torch .randn (B , H , T , dtype = dtype , device = device ).sigmoid ()
5252 scale = 1.0 / (K ** 0.5 )
@@ -74,10 +74,10 @@ def test_parallel_delta_rule(
7474 else :
7575 assert attn_parallel is None
7676
77- # SKIPPED: Comparison with naive_delta_rule_parallel due to NaN issues
78- # This requires fixing the naive implementation or replacing with another reference implementation
79- # For now, we just verify that the parallel implementation runs without errors
80- # assert_close('attn', attn_naive, attn_parallel, 0.01)
77+ o_naive , attn_naive = naive_delta_rule_parallel ( q . clone (), k . clone (), v . clone (), beta . clone ())
78+
79+ assert_close ( ' o' , o_parallel , o_naive , 0.01 )
80+ assert_close ('attn' , attn_naive , attn_parallel , 0.01 )
8181
8282
8383@pytest .mark .skipif (
0 commit comments