Commit 92ae9df

Refine
1 parent 8701243 commit 92ae9df

2 files changed: +36 -4 lines changed

csrc/flash_attn/flash_api.cpp (+4 -4)

@@ -206,10 +206,10 @@ void set_params_fprop_sparse(Flash_fwd_params &params,
                              seqlenq_ngroups_swapped,
                              unpadded_lse
                              );
-    params.block_count = (int*)block_count.data_ptr();
-    params.block_offset = (int*)block_offset.data_ptr();
-    params.column_count = (int*)column_count.data_ptr();
-    params.column_index = (int*)column_index.data_ptr();
+    params.block_count = block_count.const_data_ptr<int>();
+    params.block_offset = block_offset.const_data_ptr<int>();
+    params.column_count = column_count.const_data_ptr<int>();
+    params.column_index = column_index.const_data_ptr<int>();
     TORCH_CHECK(block_count.size(2) == block_offset.size(2));
     TORCH_CHECK(column_index.size(2) == block_offset.size(2));
     TORCH_CHECK(column_count.size(2) == column_index.size(2));

vllm_flash_attn/flash_attn_interface.py (+32)

@@ -156,6 +156,38 @@ def sparse_attn_func(
     return_softmax_lse=False,
     out=None,
 ):
+    """Compute attention with vertical and slash sparsity patterns.
+    Most arguments are the same as the flash_attn_func interface, except for 4 extra args:
+    block_count and block_offset for slash sparsity patterns, and
+    column_count and column_index for vertical sparsity patterns.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k: (batch_size, seqlen, nheads_k, headdim)
+        v: (batch_size, seqlen, nheads_k, headdim)
+        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
+        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     out, softmax_lse = _sparse_attn_forward(
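
For readers of the new docstring, a minimal usage sketch follows. It is not part of the commit: the block size (BLOCK_M = 64), the toy sparsity pattern, and the tensor values are illustrative assumptions; only the argument order, the documented shapes, and the int32 dtype (implied by const_data_ptr<int>() in the C++ change above) are taken from the diff.

# Hypothetical usage sketch (not from this commit). Assumes BLOCK_M = 64 and a
# trivial sparsity pattern: one slash block and zero vertical columns per query
# block, chosen only to satisfy the shapes documented above.
import math

import torch

from vllm_flash_attn.flash_attn_interface import sparse_attn_func

batch_size, seqlen, nheads, headdim = 1, 256, 8, 128
BLOCK_M, NNZ_S, NNZ_V = 64, 1, 1            # BLOCK_M value is an assumption
num_blocks = -(-seqlen // BLOCK_M)          # cdiv(seqlen, BLOCK_M)

device, dtype = "cuda", torch.float16
q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)

# The C++ side reads these through const_data_ptr<int>(), so they are built as
# contiguous int32 tensors on the same device as q/k/v.
block_count = torch.ones(batch_size, nheads, num_blocks, dtype=torch.int32, device=device)
block_offset = torch.zeros(batch_size, nheads, num_blocks, NNZ_S, dtype=torch.int32, device=device)
column_count = torch.zeros(batch_size, nheads, num_blocks, dtype=torch.int32, device=device)
column_index = torch.zeros(batch_size, nheads, num_blocks, NNZ_V, dtype=torch.int32, device=device)

out = sparse_attn_func(
    q, k, v,
    block_count, block_offset, column_count, column_index,
    softmax_scale=1.0 / math.sqrt(headdim),
    causal=True,
)
print(out.shape)  # (batch_size, seqlen, nheads, headdim)

Per the Return section of the docstring, passing return_softmax_lse=True would additionally return the (batch_size, nheads, seqlen) logsumexp tensor.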
