@@ -42,24 +42,49 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
| Collective Store (Warp Shfl)| Row Major (NN)| Col Major (TN)| SGEMM FP32/TF32|
| ✔️| ✔️| ✔️| ✔️|

- I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp and Collective Store. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.
+
+ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Fully Shared QKV SMEM, Prefetch Q s2r, Collective Store, etc. Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, it can run faster than the official FA2 on some devices, for example, NVIDIA RTX 3080 Laptop.

![flash-attn-mma](https://github.com/user-attachments/assets/6f66796d-44d5-4ec1-b224-af997bd152b2)

+ - Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
+ ```bash
+ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
+ ------------------------------------------------------------------------------------------------------------------------
+ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10
+ ------------------------------------------------------------------------------------------------------------------------
+ B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
+ mma(split-kv+stage1): [' 0.01960754 ' , ' 0.01452637 ' , ' -0.02592468 ' ], time:5.586338ms, TFLOPS:25.08
+ mma(split-kv+stage2): [' 0.01960754 ' , ' 0.01452637 ' , ' -0.02592468 ' ], time:5.326223ms, TFLOPS:26.31
+ mma(split-q+stage1): [' 0.01960754 ' , ' 0.01452637 ' , ' -0.02592468 ' ], time:3.834152ms, TFLOPS:36.54
+ mma(split-q+stage2): [' 0.01960754 ' , ' 0.01452637 ' , ' -0.02592468 ' ], time:4.328346ms, TFLOPS:32.37
+ mma(split-q+share-kv+stage1): [' 0.01960754 ' , ' 0.01452637 ' , ' -0.02592468 ' ], time:2.636528ms, TFLOPS:53.15
+ mma(split-q+share-qkv+stage1): [' 0.01960754 ' , ' 0.01452637 ' , ' -0.02592468 ' ], time:2.594471ms, TFLOPS:54.01
+ mma(split-q+share-qkv+stage2): [' 0.01960754 ' , ' 0.01452637 ' , ' -0.02592468 ' ], time:2.574611ms, TFLOPS:54.42
+ (flash): [' 0.01963806 ' , ' 0.0145874 ' , ' -0.02593994 ' ], time:3.764462ms, TFLOPS:37.22
+ ------------------------------------------------------------------------------------------------------------------------
+ ```
+
+ However, for large-scale attention computations, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.

- | CUDA Cores| Sliced K (Loop over N/D) | Tile Block (Br, Bc, Bd)| MMA (m16n8k16)|
+ | Tensor Cores| Loop over Seqlen/Headdim | Tile Block (Br, Bc)| MMA (m16n8k16)|
| :---:| :---:| :---:| :---:|
| ✔️| ✔️| ✔️| ✔️|
- |Pack LDST (128 bits)|SMEM Padding|Copy Async |Tile MMAs (More Threads)
+ |Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)
| ✔️| ✔️| ✔️| ✔️|
- | Tile Warps (More Values)| Multi Stages (1/2)| Collective Store (Shfl)| **Split KV/Q** |
+ | Tile Warp (More Values)| Multi Stages (1/2)| Collective Store (Shfl)| **Split KV/Q** |
| ✔️| ✔️| ✔️| ✔️|
+ | **Shared KV** SMEM| Fully **Shared QKV** SMEM| **Prefetch Q** s2r| SMEM/Block Swizzle|
+ | ✔️| ✔️| ✔️| ?|

- The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
+ The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all QKV across MMA (Warps), is slower than the `Split Q` policy, which splits Q across MMA (Warps) while keeping KV accessible to all MMA (Warps).

- ![flash-attn](https://github.com/user-attachments/assets/11490fbc-2a4a-4630-abe8-91a9d1251cba)

+ <!--
+ ![flash-attn](https://github.com/user-attachments/assets/11490fbc-2a4a-4630-abe8-91a9d1251cba)
+ -->

- 📚 Split KV (Basic, FlashAttention-1)
+ <div id="mma-split-kv"></div>

```C++
// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
@@ -69,22 +94,6 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
- template <
-   const int kHeadDim,          // Headdim, 32,64,128
-   const int kMmaAtomM,         // MMA Atom M, 16
-   const int kMmaAtomN,         // MMA Atom N, 8
-   const int kMmaAtomK,         // MMA Atom K, 16
-   const int kMmaTileSeqLenQ,   // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
-   const int kMmaTileSeqLenK,   // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
-   const int kMmaTileSeqLenP,   // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
-   const int kMmaTileHeadDimV,  // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
-   const int kWarpTileSeqLenQ,  // 2, more values, M, Br=32*2=64, matmul M
-   const int kWarpTileSeqLenK,  // 2, more values, N, Bc=32*2=64, matmul N
-   const int kWarpTileSeqLenP,  // 2, more values, M, Br=32*2=64, matmul M
-   const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|...
-   const int kStage,            // only support 1 or 2 now.
-   const int kPad               // 0,8
-   >
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
                                      half* K, // [B, H, D, N] K^T transposed
@@ -94,6 +103,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
```

- 📚 Split Q (Faster, FlashAttention-2)
+ <div id="mma-split-q"></div>

```C++
// Split Q across MMA(Warps) and keep access KV for all MMA(Warps),
@@ -104,29 +114,40 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
- template<
-   const int kHeadDim,          // Headdim, 32,64,128
-   const int kMmaAtomM,         // MMA Atom M, 16
-   const int kMmaAtomN,         // MMA Atom N, 8
-   const int kMmaAtomK,         // MMA Atom K, 16
-   const int kMmaTileSeqLenQ,   // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
-   const int kMmaTileSeqLenK,   // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
-   const int kMmaTileSeqLenP,   // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
-   const int kMmaTileHeadDimV,  // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
-   const int kWarpTileSeqLenQ,  // 1, more values, M, Br=64*1=64, matmul M
-   const int kWarpTileSeqLenK,  // 8, more values, N, Bc=8*8 =64, matmul N
-   const int kWarpTileSeqLenP,  // 1, more values, M, Br=64*1=64, matmul M
-   const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
-   const int kStage,            // only support 1 or 2 now.
-   const int kPad               // 0,8
-   >
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
                                     half* K, // [B, H, D, N] K^T transposed
                                     half* V, // [B, H, N, D]
                                     half* O, // [B, H, N, D]
                                     int QKV_seqlen);
```
+
+ - 📚 Split Q + Shared KV SMEM (Faster+)
+ <div id="mma-share-kv"></div>
+
+ ```C++
+ // K, V share the same shared memory, improving block occupancy.
+ __global__ void
+ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
+                                                half* K,
+                                                half* V,
+                                                half* O,
+                                                int QKV_seqlen);
+ ```
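
The shared-KV trick is purely about SMEM reuse over time: within one step of the seqlen loop, the K tile is only needed for Q@K^T and the V tile only for P@V, so a single buffer can serve both in turn, roughly halving the per-block SMEM footprint and letting more blocks fit per SM. Below is a minimal sketch of that reuse pattern; it is not the repo's kernel, and `Bc`, `kHeadDim` and the flat single-head, row-major indexing are illustrative assumptions (the actual `flash_attn_mma_share_kv.cu` kernel presumably applies the same idea together with the Copy Async and Tile MMA/Warp features listed above).

```C++
// Minimal sketch of the shared-KV SMEM idea (illustrative only, not the repo's
// kernel): Bc/kHeadDim and the flat single-head indexing are assumptions.
#include <cuda_fp16.h>

template <int Bc, int kHeadDim>
__global__ void attn_shared_kv_smem_sketch(const half* K, const half* V, int seqlen) {
  // One Bc*kHeadDim buffer serves the K tile and then the V tile, instead of two buffers.
  __shared__ half smem_kv[Bc * kHeadDim];
  const int tid = threadIdx.x;

  // Assumes seqlen is a multiple of Bc (no tail handling in this sketch).
  for (int kv_start = 0; kv_start < seqlen; kv_start += Bc) {
    // 1) Stage the K tile, then compute S = Q@K^T (Q is assumed to be resident already).
    for (int i = tid; i < Bc * kHeadDim; i += blockDim.x)
      smem_kv[i] = K[kv_start * kHeadDim + i];
    __syncthreads();
    // ... mma.sync over smem_kv, online-softmax rescaling ...
    __syncthreads();  // every warp must finish reading K before it is overwritten

    // 2) Reuse the SAME buffer for the V tile, then compute O += P@V.
    for (int i = tid; i < Bc * kHeadDim; i += blockDim.x)
      smem_kv[i] = V[kv_start * kHeadDim + i];
    __syncthreads();
    // ... mma.sync over smem_kv for P@V ...
    __syncthreads();
  }
}
```
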
+ - 📚 Split Q + Fully Shared QKV SMEM (Faster++)
+
+ <div id="mma-share-qkv"></div>
+
+ ```C++
+ // Q, K, V fully share the same shared memory and prefetch Q s2r, improving block occupancy.
+ __global__ void
+ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
+                                                 half* K,
+                                                 half* V,
+                                                 half* O,
+                                                 int QKV_seqlen);
+ ```
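
The fully shared QKV variant pushes the same reuse one step further: Q is staged through shared memory once and prefetched into registers ("Prefetch Q s2r", shared memory to registers), after which the freed buffer is recycled for the K and V tiles, so all three operands share one SMEM region. A rough sketch of the prefetch step follows; `Br`/`Bc`/`kHeadDim`/`kFragSize` are hypothetical parameters and plain loads stand in for the `ldmatrix` fragment loads a real MMA kernel would use.

```C++
// Minimal sketch of "Prefetch Q s2r" with fully shared QKV SMEM (illustrative only).
#include <cuda_fp16.h>

template <int Br, int Bc, int kHeadDim, int kFragSize>
__global__ void attn_prefetch_q_s2r_sketch(const half* Q) {
  // One buffer, sized for the larger of the Q tile and a K/V tile, is used by
  // Q, K and V at different points in time.
  constexpr int kTileElems = (Br > Bc ? Br : Bc) * kHeadDim;
  __shared__ half smem_qkv[kTileElems];
  half q_frag[kFragSize];  // per-thread register-resident slice of Q

  const int tid = threadIdx.x;
  // Stage this block's Q tile [Br, kHeadDim] into SMEM once.
  for (int i = tid; i < Br * kHeadDim; i += blockDim.x)
    smem_qkv[i] = Q[blockIdx.x * Br * kHeadDim + i];
  __syncthreads();

  // Prefetch s2r (assumes blockDim.x * kFragSize == Br * kHeadDim).
  #pragma unroll
  for (int i = 0; i < kFragSize; ++i)
    q_frag[i] = smem_qkv[tid * kFragSize + i];
  __syncthreads();  // smem_qkv is now free to be reused for the K and V tiles

  // ... seqlen loop: reuse smem_qkv for K (Q@K^T against q_frag), then for V (P@V),
  // exactly as in the shared-KV sketch above ...
}
```
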
+
## ©️Citations🎉🎉

```BibTeX
@@ -144,11 +165,13 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
<div id="cuda-kernel"></div>

- | 📖 CUDA Kernel| 📖 Elem dtype | 📖 Acc dtype | 📖 Docs | 📖 Level |
+ | 📖 CUDA Kernel| 📖 Elem DType | 📖 Acc DType | 📖 Docs | 📖 Level |
| :---| :---| :---| :---| :---|
| ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)| /| /| [link](./kernels/nvidia-nsight/)| ⭐️|
| ✔️ [flash_attn_mma_stages_split_kv*](./kernels/flash-attn/mma/flash_attn_mma_split_kv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages_split_q*](./kernels/flash-attn/mma/flash_attn_mma_split_q.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
| ✔️ [sgemm_naive_f32](./kernels/sgemm/sgemm.cu)| f32| f32| [link](./kernels/sgemm/)| ⭐️⭐️|
| ✔️ [sgemm_sliced_k_f32](./kernels/sgemm/sgemm.cu)| f32| f32| [link](./kernels/sgemm/)| ⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_f32x4](./kernels/sgemm/sgemm.cu)| f32| f32| [link](./kernels/sgemm/)| ⭐️⭐️⭐️|