@@ -13,6 +13,7 @@
 from fla.modules import RotaryEmbedding
 from fla.ops.nsa.parallel import parallel_nsa
 from fla.ops.utils.index import prepare_lens_from_mask
+from fla.layers.utils import pad_input, unpad_input
 
 if TYPE_CHECKING:
     from fla.models.utils import Cache
@@ -80,26 +81,24 @@ def forward(
8081 "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
8182 )
8283
83- batch_size , seq_len , _ = hidden_states .size ()
84+ batch_size , q_len , _ = hidden_states .size ()
8485
8586 q = rearrange (self .q_proj (hidden_states ), '... (h d) -> ... h d' , d = self .head_dim )
8687 k = rearrange (self .k_proj (hidden_states ), '... (h d) -> ... h d' , d = self .head_dim )
8788 v = rearrange (self .v_proj (hidden_states ), '... (h d) -> ... h d' , d = self .head_dim )
8889 g = rearrange (self .g_proj (hidden_states ), '... (h d) -> ... h d' , d = 3 )
89- g_cmp , g_slc , g_swa = g .sigmoid ().unbind (- 1 )
9090
9191 cu_seqlens = kwargs .get ('cu_seqlens' , None )
9292
93- seqlen_offset , max_seqlen = 0 , seq_len
93+ seqlen_offset , max_seqlen = 0 , q_len
9494 if past_key_values is not None :
9595 seqlen_offset = past_key_values .get_seq_length (self .layer_idx )
9696 max_seqlen = q .shape [1 ] + seqlen_offset
9797
-        # Disable for now; varlen is not supported yet, and the "correct" RoPE offsets will disturb outputs
-        # if attention_mask is not None:
-        #     # exclude padding tokens from the position offsets
-        #     seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-        #     max_seqlen = q.shape[1] + max(seqlen_offset)
+        if attention_mask is not None:
+            # exclude padding tokens from the position offsets
+            seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
+            max_seqlen = q.shape[1] + max(seqlen_offset)
 
         if self.max_position_embeddings is not None:
             max_seqlen = max(max_seqlen, self.max_position_embeddings)
@@ -110,26 +109,46 @@ def forward(
             k_cached, v_cached = past_key_values.update(
                 attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                 layer_idx=self.layer_idx,
-                offset=seq_len,
+                offset=q_len,
             )['attn_state']
             if cache_has_content:
                 k, v = k_cached, v_cached
                 k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                 v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
 
-        o = parallel_nsa(
-            q=q,
-            k=k,
-            v=v,
-            g_cmp=g_cmp,
-            g_slc=g_slc,
-            g_swa=g_swa,
-            block_size=self.block_size,
-            block_counts=self.block_counts,
-            window_size=self.window_size,
-            cu_seqlens=cu_seqlens,
-        )
-        o = o.reshape(batch_size, seq_len, -1)
+        if attention_mask is not None:
+            (q, g), (k, v), indices_q, cu_seqlens, max_seq_lens = unpad_input(
+                (q, g), (k, v), attention_mask, q_len, keepdim=True)
+            g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
+            o = parallel_nsa(
+                q=q,
+                k=k,
+                v=v,
+                g_cmp=g_cmp,
+                g_slc=g_slc,
+                g_swa=g_swa,
+                block_size=self.block_size,
+                block_counts=self.block_counts,
+                window_size=self.window_size,
+                cu_seqlens=cu_seqlens,
+            ).squeeze(0)
+            o = pad_input(o, indices_q, batch_size, q_len)
+        else:
+            g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
+            o = parallel_nsa(
+                q=q,
+                k=k,
+                v=v,
+                g_cmp=g_cmp,
+                g_slc=g_slc,
+                g_swa=g_swa,
+                block_size=self.block_size,
+                block_counts=self.block_counts,
+                window_size=self.window_size,
+                cu_seqlens=cu_seqlens,
+            )
+
+        o = o.reshape(batch_size, q_len, -1)
         o = self.o_proj(o)
 
         if not output_attentions:
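
A minimal sketch (not part of the commit) of how the new padded path treats a variable-length batch, assuming a 0-1 padding mask of shape [batch_size, seq_len]; the toy tensors and the inline cu_seqlens computation below are illustrative only:

    import torch

    # Toy padding mask: two sequences of lengths 3 and 4 in a padded batch of width 4.
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 1, 1]])
    lens = attention_mask.sum(-1)  # per-sequence lengths, tensor([3, 4]); presumably what prepare_lens_from_mask derives from the mask
    cu_seqlens = torch.cat([lens.new_zeros(1), lens.cumsum(0)]).to(torch.int32)  # tensor([0, 3, 7], dtype=torch.int32)
    # unpad_input packs only the valid tokens into a single [1, total_tokens, ...] sequence,
    # parallel_nsa then runs in varlen mode with cu_seqlens marking the sequence boundaries,
    # and pad_input scatters the outputs back into the padded [batch_size, q_len, ...] layout.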