 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

 import warnings
-from typing import Optional, Union, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import triton
@@ -130,8 +130,8 @@ def parallel_nsa_kernel_topk(
     o_i = tl.zeros([BC], dtype=tl.int32)
     m_i = tl.arange(0, BC) < BC // 2

-    IC = (i_t + Q_OFFSET) // BS # Idx of the current query block
-    for i_c in range(0, IC + 1, BC): # +1, because the current block might also be included
+    IC = (i_t + Q_OFFSET) // BS  # Idx of the current query block
+    for i_c in range(0, IC + 1, BC):  # +1, because the current block might also be included
         o_c = i_c + tl.arange(0, BC)
         # Recall k: [B, TC, H, K], boc = i_b * TC
         # we first shift to k[i_b, 0, i_h], and read a block of transposed keys from k[i_b, i_c, i_h]
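A minimal plain-Python sketch of the candidate range scanned above, assuming `BS` is the selection block size and `BC` the number of candidate blocks scored per iteration; the `+1` is there because the block containing the query itself is a valid candidate:

    def candidate_block_offsets(i_t: int, q_offset: int, BS: int, BC: int):
        IC = (i_t + q_offset) // BS        # index of the block containing the current query token
        return list(range(0, IC + 1, BC))  # IC itself is included: the query may attend to its own block

    # e.g. i_t=5, q_offset=123, BS=64, BC=32 -> IC=2 -> candidate chunk starts: [0]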
@@ -207,7 +207,7 @@ def parallel_nsa_fwd_kernel(
     IS_VARLEN: tl.constexpr,
     USE_BLOCK_COUNTS: tl.constexpr
 ):
-    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) # i_t: token, i_v: value dim, i_bh: batch * kv head
+    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)  # i_t: token, i_v: value dim, i_bh: batch * kv head
     i_b, i_h = i_bh // H, i_bh % H
     # k: [B, TK, H, K], v: [B, TK, H, V], q: [B, TQ, HQ, K]
     # block_indices: [B, TQ, H, S]
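A hedged sketch of how the fused `i_bh` program id decomposes under GQA; the exact query-head range is an assumption inferred from the `[G, ...]` tiles used later, not something this hunk spells out:

    def unpack_program_id(i_bh: int, H: int, HQ: int):
        i_b, i_h = i_bh // H, i_bh % H  # batch index and kv-head index
        G = HQ // H                     # query heads served by each kv head (GQA group size)
        return i_b, i_h, list(range(i_h * G, (i_h + 1) * G))  # query heads handled as one [G, ...] tile

    # e.g. H=4 kv heads, HQ=32 query heads, i_bh=6 -> batch 1, kv head 2, query heads 16..23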
@@ -259,7 +259,7 @@ def parallel_nsa_fwd_kernel(
     # p_q then reads BK elements along the last dimension
     # the Q block is kept in shared memory throughout the whole kernel
     # [G, BK]
-    b_q = tl.load(p_q, boundary_check=(0, 1)) # note that BK >= K, but there is a boundary check
+    b_q = tl.load(p_q, boundary_check=(0, 1))  # note that BK >= K, but there is a boundary check
     b_q = (b_q * scale).to(b_q.dtype)

     p_o = tl.make_block_ptr(
@@ -275,10 +275,10 @@ def parallel_nsa_fwd_kernel(
     # [G, BV]
     b_o = tl.zeros([G, BV], dtype=tl.float32)

-    b_m = tl.full([G], float('-inf'), dtype=tl.float32) # running maximum
-    b_acc = tl.zeros([G], dtype=tl.float32) # sumexp
-    for i in range(NS): # number of blocks
-        i_s = tl.load(block_indices + i).to(tl.int32) * BS # i_s is the start token index of the current KV block
+    b_m = tl.full([G], float('-inf'), dtype=tl.float32)  # running maximum
+    b_acc = tl.zeros([G], dtype=tl.float32)  # sumexp
+    for i in range(NS):  # number of blocks
+        i_s = tl.load(block_indices + i).to(tl.int32) * BS  # i_s is the start token index of the current KV block
         # Here we assume that q tokens are the last TQ tokens
         if i_s <= Q_OFFSET + i_t and i_s >= 0:
             # Recall: k ([B, TK, H, K]) already shifted to the start of the current sequence at head i_h, i.e. k[i_b, 0, i_h]
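A small sketch of the visibility test above, assuming `Q_OFFSET` is the offset of the query tokens within the key sequence (e.g. `TK - TQ` when the queries are the last `TQ` tokens); blocks starting after the query's absolute position are skipped, and negative indices act as padding:

    def visible_block_starts(block_starts, i_t: int, q_offset: int):
        # keep selected blocks that start no later than the query's absolute position; drop padding (< 0)
        return [i_s for i_s in block_starts if 0 <= i_s <= q_offset + i_t]

    # e.g. block_starts=[0, 64, 128, -64], i_t=3, q_offset=100 -> [0, 64]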
@@ -306,11 +306,10 @@ def parallel_nsa_fwd_kernel(
             # [G, BS]
             b_p = exp(b_s - b_m[:, None])
             # [G]
-            b_acc = b_acc * b_r + tl.sum(b_p, 1) # summed over the T dimension
+            b_acc = b_acc * b_r + tl.sum(b_p, 1)  # summed over the T dimension
             # [G, BV]; note that b_p is fp32, while b_q may not be
             b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

-
     # o = o_n / a_n
     # lse = log( exp(m_n) * a_n )

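The `b_m` / `b_acc` / `b_r` bookkeeping above is a streaming (online) softmax. A minimal PyTorch sketch of the same recurrence for a single query, where `r` plays the role of the kernel's `b_r` (presumably `exp(m_old - m_new)`); the scale is applied to the scores here, which is equivalent to folding it into `b_q` as the kernel does:

    import torch

    def streaming_attention(q, kv_blocks, scale):
        # q: [K]; kv_blocks: list of (k_blk [BS, K], v_blk [BS, V]) pairs for the selected blocks
        m = torch.tensor(float('-inf'))             # running maximum (b_m)
        acc = torch.tensor(0.0)                     # running sum of exponentials (b_acc)
        o = torch.zeros(kv_blocks[0][1].shape[-1])  # running unnormalized output (b_o)
        for k_blk, v_blk in kv_blocks:
            s = (k_blk @ q) * scale                 # [BS] attention scores
            m_new = torch.maximum(m, s.max())
            r = torch.exp(m - m_new)                # rescales the previous accumulators
            p = torch.exp(s - m_new)                # [BS] unnormalized probabilities
            acc = acc * r + p.sum()
            o = o * r + p @ v_blk
            m = m_new
        # final normalization, matching "o = o_n / a_n" and "lse = log(exp(m_n) * a_n)" above
        return o / acc, m + torch.log(acc)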
@@ -319,6 +318,7 @@ def parallel_nsa_fwd_kernel(
     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))

+
 @triton.heuristics({
     'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
 })
@@ -548,6 +548,7 @@ def parallel_nsa_bwd_kernel_dkv(
     tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

+
 @contiguous
 def parallel_nsa_topk(
     q: torch.Tensor,
@@ -557,7 +558,7 @@ def parallel_nsa_topk(
     block_counts: Union[torch.LongTensor, int],
     block_size: int = 64,
     scale: float = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
 ) -> torch.LongTensor:
     B, TQ, HQ, K = q.shape
     _, TC, H, _ = k.shape
@@ -610,6 +611,7 @@ def parallel_nsa_topk(
     )
     return block_indices

+
 @contiguous
 def parallel_nsa_fwd(
     q: torch.Tensor,
@@ -655,7 +657,7 @@ def parallel_nsa_fwd(
         token_indices_q=token_indices_q,
         TQ=T_q,
         TK=T_kv,
-        H = H,
+        H=H,
         HQ=HQ,
         G=G,
         K=K,
@@ -855,6 +857,7 @@ def backward(ctx, do):
         )
         return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None

+
 @contiguous
 def parallel_nsa(
     q: torch.Tensor,
@@ -868,7 +871,7 @@ def parallel_nsa(
     block_size: int = 64,
     window_size: int = 0,
     scale: Optional[float] = None,
-    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor]] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
 ) -> torch.Tensor:
     r"""
     Args:
@@ -888,7 +891,7 @@ def parallel_nsa(
         block_indices (torch.LongTensor):
             Block indices of shape `[B, TQ, H, S]`.
             `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
-            If `g_cmp` is provided, the passed `block_indices` will be ignored.
+            If provided, it overrides the block indices computed from compression.
         block_counts (Optional[Union[torch.LongTensor, int]]):
             Number of selected blocks for each query.
             If a tensor is provided, with shape `[B, TQ, H]`,
@@ -901,9 +904,10 @@ def parallel_nsa(
         scale (Optional[float]):
             Scale factor for attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor] or None):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
+            When a tuple is provided, it should contain two tensors: `(cu_seqlens_q, cu_seqlens_k)`.

     Returns:
         o (torch.Tensor):
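A hedged sketch of how the varlen-related arguments documented above might be assembled; the lengths and sizes are made up for illustration, and only arguments described in this docstring are shown:

    import torch
    import torch.nn.functional as F

    # two packed sequences with different query and key/value lengths (e.g. queries are a suffix)
    seqlens_q = torch.tensor([5, 7])
    seqlens_k = torch.tensor([128, 200])
    cu_seqlens_q = F.pad(seqlens_q.cumsum(0), (1, 0))  # tensor([0, 5, 12])
    cu_seqlens_k = F.pad(seqlens_k.cumsum(0), (1, 0))  # tensor([0, 128, 328])
    cu_seqlens = (cu_seqlens_q, cu_seqlens_k)          # tuple form: (cu_seqlens_q, cu_seqlens_k)

    # block_counts may be a plain int (same number of selected blocks for every query)
    # or a [B, TQ, H] LongTensor with a per-query count
    B, TQ, H, S = 1, int(seqlens_q.sum()), 4, 16
    block_counts = torch.full((B, TQ, H), S, dtype=torch.long)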