
Commit bdd361a

[FA2] fa2/hgemm manual smem swizzle 🎉 (#185)
* Update flash_attn_mma.py
* Create makefile
* Create README.md
* Update and rename matrix_trans_swizzle.cu to mat_trans_swizzle.cu
* Update hgemm_mma_swizzle.cu
* Update mat_trans_swizzle.cu
* Update and rename flash_attn_mma_swizzle_qkv.cu to flash_attn_mma_share_kv_swizzle.cu
* Create flash_attn_mma_share_qkv_swizzle.cu
* Create flash_attn_mma_split_q_swizzle.cu
* Create flash_attn_mma_split_kv_swizzle.cu
* Create flash_attn_mma_tiling_qk_swizzle.cu
* Create flash_attn_mma_tiling_qkv_swizzle.cu
* Update flash_attn_mma_share_qkv_swizzle.cu
* Update flash_attn_mma_split_kv_swizzle.cu
* Update flash_attn_mma_split_q_swizzle.cu
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn_mma_tiling_qkv_swizzle.cu
* Update README.md
* Update hgemm_mma_swizzle.cu
* Update makefile
* Update README.md
* Update README.md
* Update mat_trans_swizzle.cu
* Update makefile
* Update hgemm_mma_swizzle.cu
* Update hgemm_mma_swizzle.cu
* Update README.md
* Update hgemm_mma_stage.cu
* Update hgemm_mma.cu
* Update makefile
* Update utils.h
* Create mma_simple_swizzle.cu
* Update makefile
* Update mma_simple_swizzle.cu
* Update hgemm_mma_swizzle.cu
* Update makefile
* Update utils.py
* Update makefile
* Create hgemm_mma_stage_swizzle.cu
* Update hgemm.py
* Update hgemm.cc
* Update mat_trans_swizzle.cu
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update flash_attn_mma.py
* Update flash_attn_mma.py
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn_mma_share_kv_swizzle.cu
* Update README.md
* Update README.md
* Create print_swizzle_layout.py
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn_mma_share_kv_swizzle.cu
* Update README.md
* Update hgemm_mma_stage_swizzle.cu
* Update README.md
* Update README.md
* Update README.md
* Update mma_simple_swizzle.cu
* Create print_swizzle_layout.py
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
1 parent 98984e0 commit bdd361a

27 files changed: +3,940 −179 lines
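
The commit carries no description beyond the file list above, but the theme is clear from the filenames: the MMA-based FA2 and HGEMM kernels gain manually computed shared-memory (SMEM) swizzling, i.e. the 16-byte chunks of each staged tile are permuted within their row so that ldmatrix/MMA loads hit different banks, without relying on padding or on CuTe's built-in swizzle atoms. As orientation only, here is a minimal, hypothetical CUDA sketch of the XOR-swizzle idea; the kernel name, tile shape, and launch configuration are assumptions, not code from this repo:

```cuda
// Hypothetical sketch (not the repo's code): stage a 64x64 half tile into
// shared memory with an XOR-based swizzle. Each row is 64 halves = 128 bytes
// = 8 chunks of 16 bytes; XOR-ing the chunk index with the low row bits
// permutes chunks across banks, so later column-wise (ldmatrix-style) reads
// avoid bank conflicts without any padding columns.
#include <cuda_fp16.h>

__device__ __forceinline__ int swizzle_col(int row, int col) {
  int chunk          = col >> 3;           // 16-byte chunk index within the row
  int swizzled_chunk = chunk ^ (row & 7);  // XOR permutation, still in [0, 8)
  return (swizzled_chunk << 3) | (col & 7);
}

// Assumed launch: blockDim = (8, 64); each thread copies one 16-byte chunk.
// Also assumes ld is a multiple of 8 and gmem is 16-byte aligned.
__global__ void stage_tile_swizzled(const half *__restrict__ gmem, int ld) {
  __shared__ half smem[64][64];
  int row = threadIdx.y;
  int col = threadIdx.x << 3;  // 0, 8, ..., 56
  // 128-bit vectorized store into the permuted location.
  *reinterpret_cast<float4 *>(&smem[row][swizzle_col(row, col)]) =
      *reinterpret_cast<const float4 *>(&gmem[row * ld + col]);
  __syncthreads();
  // Any consumer (e.g. the address handed to ldmatrix) must apply the same
  // swizzle_col(row, col) mapping to find its fragment again.
}
```

The new kernels/swizzle/mma_simple_swizzle.cu and print_swizzle_layout.py files added in this commit presumably play exactly these two roles: a minimal kernel demonstrating the mapping, and a script that prints the resulting layout for inspection.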

README.md

Lines changed: 34 additions & 30 deletions
@@ -35,9 +35,9 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
 |✔️|✔️|✔️|✔️|
 |Copy Async|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages (2/3/4)|
 |✔️|✔️|✔️|✔️|
-|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe)|
+|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe/MMA)|
 |✔️|✔️|✔️|✔️|
-|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
+|Collective Store (Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
 |✔️|✔️|✔️|✔️|

@@ -48,7 +48,7 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
 |Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)|
 |:---:|:---:|:---:|:---:|
 |✔️|✔️|✔️|✔️|
-|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
+|Pack LDST (128 bits)|SMEM **Swizzle**/Padding |Copy Async|Tile MMA (More Threads)|
 |✔️|✔️|✔️|✔️|
 |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
 |✔️|✔️|✔️|✔️|
@@ -160,7 +160,6 @@ The kernels listed here will guide you through a step-by-step progression, rangi

 |📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level |
 |:---|:---|:---|:---|:---|
-| ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️|
 | ✔️ [elementwise_f32](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️|
 | ✔️ [elementwise_f32x4](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️|
 | ✔️ [elementwise_f16](./kernels/elementwise/elementwise.cu)|f16|/|[link](./kernels/elementwise/)|⭐️|
@@ -205,27 +204,27 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | ✔️ [mat_trans_f32_diagonal2d](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
 | ✔️ [mat_trans_f32x4_col2row{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
 | ✔️ [mat_trans_f32x4_row2col{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
-| ✔️ [warp_reduce_[all]](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [warp_reduce_{all}](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
+| ✔️ [block_all_reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
 | ✔️ [dot_product_f32](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f32x4](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f16_f32](./kernels/dot-product/dot_product.cu)|f16|f32|[link](./kernels/dot-product/)|⭐️⭐️|
@@ -262,7 +261,8 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | ✔️ [rms_norm_f16x8_pack_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
 | ✔️ [nms_f32](./kernels/nms/nms.cu)|f32|/|[link](./kernels/nms)|⭐️⭐️|
-| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️|
+| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️⭐️|
+| ✔️ [How to profile with nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️⭐️|

 ### 📚 Hard ⭐⭐⭐️ ([©️back👆🏻](#cuda-kernel))

@@ -284,7 +284,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k16...async](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_wmma_m16n16k8...stages*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
-| ✔️ [sgemm_wmma_m16n16k8...swizzle*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
+| ✔️ [sgemm_wmma_m16n16k8...swizzle{+block}*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_naive_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|
 | ✔️ [hgemm_sliced_k_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_t_8x8_sliced_k_f16x4](./kernels/hgemm/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
@@ -299,12 +299,13 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | ✔️ [hgemm_wmma_m16n16k16...dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m32n8k16....dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...stages*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_wmma_m16n16k16...swizzle*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_wmma_m16n16k16...swizzle{+block}*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_mma_m16n8k16...naive*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_mma_m16n8k16...mma2x4*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_mma_m16n8k16...stages*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_m16n8k16...swizzle*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_stages{swizzle}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...swizzle{+block}*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...swizzle{+smem}*](./kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_stages_swizzle{+smem}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_mma_cublas*](./kernels/hgemm/cublas/hgemm_cublas.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|

 ### 📚 Hard+ ⭐️⭐️⭐️⭐️ & Hard++ ⭐️⭐️⭐️⭐️⭐️ ([©️back👆🏻](#cuda-kernel))
@@ -318,11 +319,14 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma...tiling_qk_swizzle{+smem}*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
 | ? [flash_attn_mma_stages_split_kv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_split_kv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ? [flash_attn_mma_stages_split_q{f32}*](./kernels/flash-attn/mma/flash_attn_mma_split_q_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ? [flash_attn_mma_stages...shared_kv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_share_kv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
 | ? [flash_attn_mma_stages...shared_qkv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
 | ? [flash_attn_mma_stages...tiling_qk{f32}*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
+| ✔️ [How to implement MMA smem swizzle*](./kernels/swizzle/mma_simple_swizzle.cu)|f16|f16|[link](./kernels/swizzle)|⭐️⭐️⭐️⭐️|
+
 ## 📖 博客目录

 <div id="my-blogs-part-1"></div>
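
The two new rows above, hgemm_mma_m16n8k16...swizzle{+smem} and "How to implement MMA smem swizzle", point at kernels/swizzle/ for a standalone walk-through. Before wiring a swizzle pattern into a kernel, it helps to print the layout it produces; the following host-only sketch (an assumption in the spirit of the new print_swizzle_layout.py, not its actual contents) shows which physical 16-byte chunk each logical chunk of an 8×8 chunk block lands in under the XOR mapping used in the sketch earlier:

```cuda
// Hypothetical host-side check, compilable with nvcc or any C++ compiler:
// print the physical chunk index for each (row, logical chunk) pair under
// the chunk ^ (row & 7) swizzle.
#include <cstdio>

int main() {
  const int rows = 8, chunks = 8;  // one 8 x 8 block of 16-byte chunks
  for (int r = 0; r < rows; ++r) {
    printf("row %d: ", r);
    for (int c = 0; c < chunks; ++c) {
      printf("%d ", c ^ (r & 7));  // physical chunk for logical chunk c
    }
    printf("\n");
  }
  return 0;
}
```

Reading the output column-wise, every logical chunk visits all eight physical slots across the eight rows; that permutation property is what lets a column of fragments be loaded without the eight accesses piling onto the same bank group.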

kernels/flash-attn/README.md

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 |Tensor Cores|Loop over Seqlen/HeadDim |Tile Block (Br, Bc)|MMA (m16n8k16)|
 |:---:|:---:|:---:|:---:|
 |✔️|✔️|✔️|✔️|
-|Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
+|Pack LDST (pack 128 bits)|SMEM **Swizzle**/Padding |Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
 |✔️|✔️|✔️|✔️|
 |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shfl & Reg Reuse)|**Split KV/Q**|
 |✔️|✔️|✔️|✔️|

kernels/flash-attn/flash_attn_mma.py

Lines changed: 11 additions & 4 deletions
@@ -83,6 +83,7 @@ def get_args():
 './mma/flash_attn_mma_share_kv.cu',
 './mma/flash_attn_mma_share_qkv.cu',
 './mma/flash_attn_mma_tiling_qk.cu',
+'./mma/flash_attn_mma_tiling_qk_swizzle.cu',
 './pybind/flash_attn.cc'
 ],
 extra_cuda_cflags=[
@@ -218,11 +219,11 @@ def run_benchmark(perf_func: callable,
 else:
 improve = 0
 MAX_TFLOPS = TFLOPS
-print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, "
+print(f"{out_info:>38}: {out_val}, time:{mean_time:<.6f}ms, "
 f"TFLOPS:{TFLOPS:<6.2f}(+{improve:.2f}%)")
 else:
-if not only_show_improved or "flash" in tag:
-print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, "
+if not only_show_improved or "flash" in tag or "sdpa" in tag:
+print(f"{out_info:>38}: {out_val}, time:{mean_time:<.6f}ms, "
 f"TFLOPS:{TFLOPS:<6.2f}")

 if show_matrix: print(out)
@@ -296,7 +297,7 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
 diff = torch.abs(out_flash_or_sdpa.float() - out_mma.float())
 all_close = str(torch.allclose(out_flash_or_sdpa.float(), out_mma.float(), atol=1e-2))
 pretty_print_line(
-f"{true_tag} vs {tag:<18}, all close: {all_close:<6}, "
+f"{true_tag} vs {tag:<22}, all close: {all_close:<6}, "
 f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, "
 f"mean diff: {diff.mean().item():.6f}"
 )
@@ -340,6 +341,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
 out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage2)", o, stages=2)
 out_mma_tiling_qk1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk, q, k, v, "mma(split-q+tiling-qk+stage1)", o, stages=1)
 out_mma_tiling_qk2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk, q, k, v, "mma(split-q+tiling-qk+stage2)", o, stages=2)
+out_mma_tiling_qk_sw1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk_swizzle, q, k, v, "mma(split-q+tiling-qk+swizzle+stage1)", o, stages=1)
+out_mma_tiling_qk_sw2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk_swizzle, q, k, v, "mma(split-q+tiling-qk+swizzle+stage2)", o, stages=2)
 if D <= 256:
 out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
 if args.run_torch_sdpa:
@@ -360,9 +363,13 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
 check_all_close(out_flash, out_mma_share_qkv2, "out_mma_share_qkv2", args.check_all)
 check_all_close(out_flash, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all)
 check_all_close(out_flash, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all)
+check_all_close(out_flash, out_mma_tiling_qk_sw1, "out_mma_tiling_qk_sw1", args.check_all)
+check_all_close(out_flash, out_mma_tiling_qk_sw2, "out_mma_tiling_qk_sw2", args.check_all)
 pretty_print_line()
 elif args.run_torch_sdpa:
 pretty_print_line()
 check_all_close(out_sdpa, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all, False)
 check_all_close(out_sdpa, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all, False)
+check_all_close(out_sdpa, out_mma_tiling_qk_sw1, "out_mma_tiling_qk_sw1", args.check_all, False)
+check_all_close(out_sdpa, out_mma_tiling_qk_sw2, "out_mma_tiling_qk_sw2", args.check_all, False)
 pretty_print_line()
