
Commit ebeff9e

only keep hdim128

Signed-off-by: Minmin Sun <[email protected]>

1 parent: c9d548f

30 files changed: +5 -285 lines

csrc/flash_attn/flash_api.cpp (+4 -4)

@@ -236,11 +236,11 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split

 void run_mha_fwd_sparse(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     TORCH_CHECK(params.num_splits <= 1 && !force_split_kernel, "run_mha_fwd_sparse does not support splitkv.");
+    TORCH_CHECK(params.d == 128, "run_mha_fwd_sparse only supports headdim=128 for now to keep binary small.");
     FP16_SWITCH(!params.is_bf16, [&] {
-        HEADDIM_SWITCH(params.d, [&] {
-            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-                run_mha_fwd_sparse_<elem_type, kHeadDim, Is_causal>(params, stream);
-            });
+        constexpr static int kHeadDim = 128;
+        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+            run_mha_fwd_sparse_<elem_type, kHeadDim, Is_causal>(params, stream);
         });
     });
 }
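For context: these *_SWITCH macros turn runtime flags into compile-time constants by evaluating the condition once and invoking the lambda body with the constant in scope, so the compiler instantiates the body separately per branch. A minimal sketch of the two macros this function still uses, in the style of the repo's static_switch.h (assumed; that header is not part of this diff):

    // Sketch only; exact definitions in the tree may differ.
    // Each macro invokes the lambda body once per branch with a
    // compile-time constant visible inside it.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)        \
      [&] {                                           \
        if (COND) {                                   \
          constexpr static bool CONST_NAME = true;    \
          return __VA_ARGS__();                       \
        } else {                                      \
          constexpr static bool CONST_NAME = false;   \
          return __VA_ARGS__();                       \
        }                                             \
      }()

    // FP16_SWITCH picks the CUTLASS element type from the runtime dtype flag.
    #define FP16_SWITCH(COND, ...)                    \
      [&] {                                           \
        if (COND) {                                   \
          using elem_type = cutlass::half_t;          \
          return __VA_ARGS__();                       \
        } else {                                      \
          using elem_type = cutlass::bfloat16_t;      \
          return __VA_ARGS__();                       \
        }                                             \
      }()

Replacing HEADDIM_SWITCH(params.d, ...) with a fixed constexpr kHeadDim of 128 cuts the instantiations of run_mha_fwd_sparse_ from 8 head dims × 2 dtypes × 2 causal flags = 32 down to 4, which is what allows the 28 per-head-dim translation units below to be deleted.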

Deleted files (28 files, -10 lines each): the sparse forward kernel translation units for every head dim other than 128.

csrc/flash_attn/src/flash_fwd_sparse_hdim160_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim160_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim160_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim160_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim192_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim192_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim192_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim192_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim224_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim224_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim224_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim224_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim256_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim256_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim256_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim256_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim32_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim32_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim32_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim32_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim64_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim64_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim64_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim64_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim96_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim96_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim96_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_sparse_hdim96_fp16_sm80.cu
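Each deleted file was ten lines, which suggests a single explicit template instantiation per (head dim, dtype, causal) combination, split into separate translation units to parallelize compilation. A hypothetical reconstruction of one of them (the header name and launcher function are assumptions inferred from the run_mha_fwd_sparse_<elem_type, kHeadDim, Is_causal> call in flash_api.cpp; the diff does not show the file bodies):

    // Hypothetical reconstruction of flash_fwd_sparse_hdim64_fp16_sm80.cu;
    // the actual contents are not shown in this diff.
    #include "flash_fwd_sparse_launch_template.h"  // assumed header name

    template<>
    void run_mha_fwd_sparse_<cutlass::half_t, 64, /*Is_causal=*/false>(
            Flash_fwd_params &params, cudaStream_t stream) {
        // assumed launcher, mirroring the dense kernels' per-hdim layout
        run_mha_fwd_sparse_hdim64<cutlass::half_t, false>(params, stream);
    }

The _causal_ variants would instantiate the same specialization with Is_causal = true. Only the four hdim128 units survive, matching the new TORCH_CHECK above.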

tests/test_vllm_flash_attn.py (+1 -1)

@@ -270,7 +270,7 @@ def test_varlen_with_paged_kv(

 @pytest.mark.parametrize("seq_lens", [(1023, 2049), (1023, 1023), (32, 32), (65, 65), (129, 129)])
 @pytest.mark.parametrize("num_heads", [1, 2, 4])
-@pytest.mark.parametrize("head_size", [64, 128, 256])
+@pytest.mark.parametrize("head_size", [128])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("NNZ_S", [1, 2, 3, 7, 15, 32])
 @torch.inference_mode()
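Narrowing head_size to [128] keeps the test matrix consistent with the new TORCH_CHECK in flash_api.cpp: head sizes 64 and 256 would no longer select a kernel but fail at dispatch. A sketch of that failure mode on the C++ side (setup_params is a hypothetical stand-in for the real argument plumbing; TORCH_CHECK throws c10::Error, which PyTorch surfaces to Python as a RuntimeError):

    // Sketch only: what a caller hits for a non-128 head dim after this commit.
    Flash_fwd_params params = setup_params(/*headdim=*/64);  // hypothetical helper
    try {
        run_mha_fwd_sparse(params, /*stream=*/nullptr);  // nullptr = default stream
    } catch (const c10::Error &e) {
        // e.what() contains: "run_mha_fwd_sparse only supports headdim=128
        // for now to keep binary small."
    }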
