Commit ff69e59

Varlen mode for NSA layer
1 parent 31e1806 commit ff69e59

2 files changed: +58 -30 lines changed

fla/layers/nsa.py

Lines changed: 41 additions & 22 deletions
@@ -13,6 +13,7 @@
 from fla.modules import RotaryEmbedding
 from fla.ops.nsa.parallel import parallel_nsa
 from fla.ops.utils.index import prepare_lens_from_mask
+from fla.layers.utils import pad_input, unpad_input

 if TYPE_CHECKING:
     from fla.models.utils import Cache
@@ -80,26 +81,24 @@ def forward(
                 "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
             )

-        batch_size, seq_len, _ = hidden_states.size()
+        batch_size, q_len, _ = hidden_states.size()

         q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
         k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
         v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
         g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
-        g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)

         cu_seqlens = kwargs.get('cu_seqlens', None)

-        seqlen_offset, max_seqlen = 0, seq_len
+        seqlen_offset, max_seqlen = 0, q_len
         if past_key_values is not None:
             seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
             max_seqlen = q.shape[1] + seqlen_offset

-        # Disable for now; varlen is not supported yet, and the "correct" RoPE offsets will disturb outputs
-        # if attention_mask is not None:
-        #     # to deliminate the offsets of padding tokens
-        #     seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-        #     max_seqlen = q.shape[1] + max(seqlen_offset)
+        if attention_mask is not None:
+            # to deliminate the offsets of padding tokens
+            seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
+            max_seqlen = q.shape[1] + max(seqlen_offset)

         if self.max_position_embeddings is not None:
             max_seqlen = max(max_seqlen, self.max_position_embeddings)
@@ -110,26 +109,46 @@ def forward(
             k_cached, v_cached = past_key_values.update(
                 attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                 layer_idx=self.layer_idx,
-                offset=seq_len,
+                offset=q_len,
             )['attn_state']
             if cache_has_content:
                 k, v = k_cached, v_cached
                 k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                 v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

-        o = parallel_nsa(
-            q=q,
-            k=k,
-            v=v,
-            g_cmp=g_cmp,
-            g_slc=g_slc,
-            g_swa=g_swa,
-            block_size=self.block_size,
-            block_counts=self.block_counts,
-            window_size=self.window_size,
-            cu_seqlens=cu_seqlens,
-        )
-        o = o.reshape(batch_size, seq_len, -1)
+        if attention_mask is not None:
+            (q, g), (k, v), indices_q, cu_seqlens, max_seq_lens = unpad_input(
+                (q, g), (k, v), attention_mask, q_len, keepdim=True)
+            g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
+            o = parallel_nsa(
+                q=q,
+                k=k,
+                v=v,
+                g_cmp=g_cmp,
+                g_slc=g_slc,
+                g_swa=g_swa,
+                block_size=self.block_size,
+                block_counts=self.block_counts,
+                window_size=self.window_size,
+                cu_seqlens=cu_seqlens,
+            ).squeeze(0)
+            o = pad_input(o, indices_q, batch_size, q_len)
+        else:
+            g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
+            o = parallel_nsa(
+                q=q,
+                k=k,
+                v=v,
+                g_cmp=g_cmp,
+                g_slc=g_slc,
+                g_swa=g_swa,
+                block_size=self.block_size,
+                block_counts=self.block_counts,
+                window_size=self.window_size,
+                cu_seqlens=cu_seqlens,
+            )
+
+        o = o.reshape(batch_size, q_len, -1)
         o = self.o_proj(o)

         if not output_attentions:
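
What the new forward path does when an attention_mask is present: the padded batch is flattened into one packed sequence (unpad_input), the kernel runs once over the packed tokens with cu_seqlens marking the sequence boundaries, and the output is scattered back to [batch_size, q_len, ...] (pad_input). Below is a minimal plain-PyTorch sketch of that unpad -> compute -> pad round trip; toy_unpad and toy_pad are illustrative stand-ins written for this note, not the library's unpad_input/pad_input, and the NSA kernel itself is elided.

import torch

def toy_unpad(x, attention_mask):
    # Keep only non-pad positions, flattening [batch, seq_len, ...] to [total_tokens, ...];
    # this mirrors the indices_q / cu_seqlens bookkeeping done by unpad_input.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(
        attention_mask.sum(-1).cumsum(0, dtype=torch.int32), (1, 0))
    return x.flatten(0, 1)[indices], indices, cu_seqlens

def toy_pad(x, indices, batch_size, seq_len):
    # Scatter unpadded tokens back into a zero-filled padded tensor (pad_input's job).
    out = x.new_zeros(batch_size * seq_len, *x.shape[1:])
    out[indices] = x
    return out.view(batch_size, seq_len, *x.shape[1:])

batch_size, seq_len, hidden = 2, 5, 8
hidden_states = torch.randn(batch_size, seq_len, hidden)
# Left-padded batch: 3 valid tokens in the first sequence, 5 in the second.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

x, indices, cu_seqlens = toy_unpad(hidden_states, attention_mask)
print(x.shape, cu_seqlens)  # torch.Size([8, 8]) tensor([0, 3, 8], dtype=torch.int32)

# A varlen kernel (parallel_nsa in this commit) would consume the packed x and cu_seqlens here.
y = toy_pad(x, indices, batch_size, seq_len)
assert torch.equal(y[attention_mask.bool()], hidden_states[attention_mask.bool()])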

fla/layers/utils.py

Lines changed: 17 additions & 8 deletions
@@ -3,7 +3,7 @@

 # Code is adapted from flash-attn.bert_padding.py

-from typing import Tuple
+from typing import Tuple, Union

 import torch
 from einops import rearrange, repeat
@@ -99,7 +99,7 @@ def get_unpad_data(


 def unpad_input(
-    q: torch.Tensor,
+    q: Union[torch.Tensor, Tuple[torch.Tensor]],
     states: Tuple[torch.Tensor],
     attention_mask: torch.Tensor,
     q_len: int,
@@ -111,8 +111,9 @@ def unpad_input(


     Arguments:
-        q (`torch.Tensor`):
+        q (`torch.Tensor` or `Tuple[torch.Tensor]`):
             Query state with padding. Shape: [batch_size, q_len, ...].
+            When it is a tuple, do unpadding for each tensor in the tuple.
         states (`Tuple[torch.Tensor]`):
             Attention state with padding. Shape: [batch_size, seq_len, ...].
         attention_mask (`torch.Tensor`):
@@ -123,9 +124,10 @@
             Whether to keep the batch dimension. Default: `False`.

     Return:
-        q (`torch.Tensor`):
+        q (`torch.Tensor` or `Tuple[torch.Tensor]`):
             Query state without padding.
             Shape: [1, total_target_length, ...] if `keepdim=True` else [total_target_length, ...].
+            When the `q` passed in is a tuple, return a tuple of such unpadded tensors.
         states (`Tuple[torch.Tensor]`):
             Attention state without padding.
             Shape: [1, total_source_length, ...] if `keepdim=True` else [total_source_length, ...].
@@ -146,23 +148,30 @@ def unpad_input(
         index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k)
         for s in states
     )
+    if isinstance(q, torch.Tensor):
+        q = (q,)
+        cast_tuple = True
+    else:
+        cast_tuple = False

     if q_len == seq_len:
-        q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
+        q = tuple(index_first_axis(rearrange(q_, "b s ... -> (b s) ..."), indices_k) for q_ in q)
         cu_seqlens_q = cu_seqlens_k
         max_seqlen_in_batch_q = max_seqlen_in_batch_k
         indices_q = indices_k
     elif q_len == 1:
         max_seqlen_in_batch_q = 1
-        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
+        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q[0].device)
         indices_q = cu_seqlens_q[:-1]
-        q = q.squeeze(1)
+        q = tuple(q_.squeeze(1) for q_ in q)
     else:
         raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")

     if keepdim:
-        q = q.unsqueeze(0)
+        q = tuple(q_.unsqueeze(0) for q_ in q)
         state = tuple(s.unsqueeze(0) for s in state)
+    if cast_tuple:
+        q = q[0]

     return (
         q,