@@ -156,6 +156,38 @@ def sparse_attn_func(
    return_softmax_lse=False,
    out=None,
):
+ """Compute attention with virtical and slash sparsity patterns.
160
+ Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
161
+ block_count and block_offset for slash sparsity patterns, and
162
+ column_count and column_index for virtical sparsity patterns.
163
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k: (batch_size, seqlen, nheads_k, headdim)
+        v: (batch_size, seqlen, nheads_k, headdim)
+        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
+        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = _sparse_attn_forward(
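For orientation, a minimal usage sketch of the interface documented above follows. It only illustrates the tensor shapes given in the docstring; the import path, the BLOCK_M value, the NNZ_S / NNZ_V budgets, the placeholder contents of the index tensors, and passing the four sparse-index tensors as keyword arguments are all assumptions and not part of this change.

# Minimal usage sketch (not part of the diff); shapes follow the docstring above.
import math

import torch

from sparse_attn import sparse_attn_func  # hypothetical import path

batch_size, seqlen, nheads, headdim = 1, 4096, 32, 128
BLOCK_M, NNZ_S, NNZ_V = 64, 16, 256       # assumed query-block size and sparsity budgets
num_blocks = math.ceil(seqlen / BLOCK_M)  # cdiv(seqlen, BLOCK_M)

q = torch.randn(batch_size, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Slash pattern: per query block, how many key blocks to visit and their offsets.
# The values here are placeholders with the right shape, not a meaningful pattern.
block_count = torch.full((batch_size, nheads, num_blocks), NNZ_S, dtype=torch.int32, device="cuda")
block_offset = torch.arange(NNZ_S, dtype=torch.int32, device="cuda").repeat(
    batch_size, nheads, num_blocks, 1
)

# Vertical pattern: per query block, how many individual key columns and their indices.
column_count = torch.full((batch_size, nheads, num_blocks), NNZ_V, dtype=torch.int32, device="cuda")
column_index = torch.arange(NNZ_V, dtype=torch.int32, device="cuda").repeat(
    batch_size, nheads, num_blocks, 1
)

out = sparse_attn_func(
    q, k, v,
    block_count=block_count, block_offset=block_offset,
    column_count=column_count, column_index=column_index,
    causal=True,
)
print(out.shape)  # (batch_size, seqlen, nheads, headdim)

Per the Return section of the docstring, passing return_softmax_lse=True would additionally return the (batch_size, nheads, seqlen) logsumexp tensor.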