Implements the attention kernel with the vertical-and-slash sparse pattern described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490 (as sparse_attn_func) #33
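A minimal usage sketch of the new sparse_attn_func entry point follows. The Python binding itself sits in the collapsed flash_api.cpp diff, so the import path, argument order, and tensor shapes below are assumptions inferred from the fields added to Flash_fwd_params (block_count, block_offset, column_count, column_index) and from the vertical-and-slash layout in the paper; treat it as an illustration, not the authoritative signature.

# Hypothetical usage sketch; the real binding is defined in flash_api.cpp (diff collapsed below).
import torch
from flash_attn import sparse_attn_func  # import path is an assumption; the function name comes from the PR title

B, H, S, D = 1, 32, 4096, 128        # batch, heads, sequence length, head dim
BLOCK_M = 64                         # query row-block size used by the launch template (kBlockM)
NUM_ROWS = (S + BLOCK_M - 1) // BLOCK_M
NNZ_S, NNZ_V = 16, 128               # max slash blocks / vertical columns kept per row block (illustrative)

q = torch.randn(B, S, H, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, H, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, S, H, D, dtype=torch.float16, device="cuda")

# Slash part of the pattern: per (batch, head, row block), how many KV blocks to visit and where each starts.
block_count  = torch.zeros(B, H, NUM_ROWS, dtype=torch.int32, device="cuda")
block_offset = torch.zeros(B, H, NUM_ROWS, NNZ_S, dtype=torch.int32, device="cuda")
# Vertical part: per (batch, head, row block), how many individual KV columns to visit and their indices.
column_count = torch.zeros(B, H, NUM_ROWS, dtype=torch.int32, device="cuda")
column_index = torch.zeros(B, H, NUM_ROWS, NNZ_V, dtype=torch.int32, device="cuda")

out = sparse_attn_func(q, k, v, block_count, block_offset, column_count, column_index, causal=True)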

Merged: 12 commits, Jan 15, 2025
482 changes: 482 additions & 0 deletions csrc/flash_attn/flash_api.cpp

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -142,6 +142,15 @@ struct Flash_fwd_params : public Qkv_params {

bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).

// For sparse attention
const int* block_count;
const int* block_offset;
const int* column_count;
const int* column_index;
int NUM_ROWS;
int NNZ_S;
int NNZ_V;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -189,6 +198,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_sparse_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
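The sparse-pattern fields added to Flash_fwd_params above can be read as follows, assuming the vertical-and-slash layout from the paper: block_count and block_offset describe, per (batch, head, query row block), how many dense KV blocks the slash (diagonal) part touches and where each starts, while column_count and column_index list the individual KV columns of the vertical part; NUM_ROWS, NNZ_S and NNZ_V are the corresponding extents. The authoritative traversal is in flash_fwd_sparse_kernel.h (collapsed below); the sketch here only illustrates that indexing convention.

# Illustrative decode of the sparse metadata for one (batch b, head h, row block m).
# Hypothetical helper; the real traversal is implemented in flash_fwd_sparse_kernel.h.
def kv_positions_for_row_block(block_count, block_offset, column_count, column_index,
                               b, h, m, block_n=64):
    positions = []
    # Slash part: block_count[b, h, m] KV blocks, each block_n keys wide, starting at block_offset[b, h, m, s].
    for s in range(int(block_count[b, h, m])):
        start = int(block_offset[b, h, m, s])
        positions.extend(range(start, start + block_n))
    # Vertical part: column_count[b, h, m] individual KV columns taken from column_index[b, h, m, :].
    for v in range(int(column_count[b, h, m])):
        positions.append(int(column_index[b, h, m, v]))
    return positions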
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, true>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::half_t, true>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::half_t, false>(params, stream);
}
685 changes: 685 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_kernel.h

Large diffs are not rendered by default.

125 changes: 125 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_launch_template.h
@@ -0,0 +1,125 @@
/******************************************************************************
* Copyright (c) 2024, PAI, Alibaba Cloud.
******************************************************************************/

#pragma once

#include "flash_fwd_launch_template.h"
#include "flash_fwd_sparse_kernel.h"

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_sparse_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_sparse_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_sparse_fwd(Flash_fwd_params &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);

// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21

const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr bool IsEvenMNConst = false;
constexpr bool Is_local = false;
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_sparse_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
14 changes: 13 additions & 1 deletion csrc/flash_attn/src/generate_kernels.py
@@ -25,6 +25,14 @@
}}
"""

KERNEL_IMPL_TEMPLATE_FWD_SPARSE = """#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
run_mha_fwd_sparse_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
}}
"""

KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
@@ -53,6 +61,10 @@ def template(self) -> str:
return KERNEL_IMPL_TEMPLATE_FWD.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
)
elif self.direction == "fwd_sparse":
return KERNEL_IMPL_TEMPLATE_FWD_SPARSE.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
)
elif self.direction == "bwd":
return KERNEL_IMPL_TEMPLATE_BWD.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
@@ -68,7 +80,7 @@ def filename(self) -> str:


def get_all_kernels() -> List[Kernel]:
for direction in ["fwd", "fwd_split"]:
for direction in ["fwd", "fwd_split", "fwd_sparse"]:
for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
for direction in ["bwd"]:
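With "fwd_sparse" added to the direction list, generate_kernels.py emits one translation unit per (dtype, head dim, causal) combination, following the naming visible in the files added above (for example flash_fwd_sparse_hdim128_bf16_causal_sm80.cu). The snippet below reproduces that naming as a sanity check; the dtype and head-dimension lists are assumptions taken from the conventional DTYPE_MAP and HEAD_DIMENSIONS in the (partially shown) script and from the head-dim variants in the launch template.

# Sketch of the file names the new "fwd_sparse" direction is expected to generate.
# The real logic lives in Kernel.filename(), which is only partially shown in this diff.
import itertools

DTYPES = ["fp16", "bf16"]                                  # assumed contents of DTYPE_MAP
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]    # matches run_mha_fwd_sparse_hdim* above
IS_CAUSAL = [False, True]
SM = [80]

for dtype, hdim, causal, sm in itertools.product(DTYPES, HEAD_DIMENSIONS, IS_CAUSAL, SM):
    suffix = "_causal" if causal else ""
    print(f"flash_fwd_sparse_hdim{hdim}_{dtype}{suffix}_sm{sm}.cu")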