@@ -19,10 +19,9 @@ from_float(half& d, float s) { d = __float2half(s); }
 static __forceinline__ __device__ void
 from_float(__nv_bfloat16& d, float s) { d = __float2bfloat16(s); }
 
-// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
-// can be used to combine partial attention results (in the split-KV case)
-template <typename scalar_t, bool kLoopOverHead>
-__global__ void merge_attn_states_kernel(
+
+template <typename scalar_t>
+__device__ __forceinline__ void merge_attn_states_per_thread(
     scalar_t* output,       // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
     float* output_lse,      // [NUM_HEADS, NUM_TOKENS]
     const scalar_t* __restrict__ prefix_output,  // [NUM_TOKENS, NUM_HEADS,
@@ -33,140 +32,129 @@ __global__ void merge_attn_states_kernel(
     const float* __restrict__ suffix_lse,        // [NUM_HEADS, NUM_TOKENS]
     const uint num_tokens,  // NUM_TOKENS
     const uint num_heads,   // NUM QUERY HEADS
-    const uint head_size    // HEAD_SIZE, 32,48,64,...,512,etc
+    const uint head_size,   // HEAD_SIZE, 32,48,64,...,512,etc
+    const uint token_idx,
+    const uint head_idx,
+    const uint thr_idx
 ) {
-  // TODO(DefTruth): may need to support fp8?
-  if constexpr (kLoopOverHead) {
-    // May loop over num heads for large NUM_TOKENS
-    const uint token_idx = blockIdx.x;
-    const uint thread_idx = threadIdx.x;
+  using pack_128b_t = uint4;  // float -> 4, half/bf16 -> 8
+  constexpr uint pack_size = 16 / sizeof(scalar_t);
 
-#pragma unroll
-    for (uint head_idx = 0; head_idx < num_heads; ++head_idx) {
-      float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
-      float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
-      p_lse =
-          std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
-      s_lse =
-          std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
+  const uint thr_offset = thr_idx * pack_size;  // (0~15)*8, etc.
+  const uint blk_offset =
+      token_idx * num_heads * head_size + head_idx * head_size;
+  const scalar_t* prefix_output_blk = prefix_output + blk_offset;
+  const scalar_t* suffix_output_blk = suffix_output + blk_offset;
+  scalar_t* output_blk = output + blk_offset;
 
-      const float max_lse = fmaxf(p_lse, s_lse);
-      p_lse = p_lse - max_lse;
-      s_lse = s_lse - max_lse;
-      const float p_se = expf(p_lse);
-      const float s_se = expf(s_lse);
-      const float out_se = p_se + s_se;
+  float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
+  float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
+  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
+  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
 
-      if (output_lse != nullptr) {
-        float out_lse = logf(out_se) + max_lse;
-        output_lse[head_idx * num_tokens + token_idx] = out_lse;
-      }
+  const float max_lse = fmaxf(p_lse, s_lse);
+  p_lse = p_lse - max_lse;
+  s_lse = s_lse - max_lse;
+  const float p_se = expf(p_lse);
+  const float s_se = expf(s_lse);
+  const float out_se = p_se + s_se;
+  const float p_scale = p_se / out_se;
+  const float s_scale = s_se / out_se;
 
-      const uint blk_offset =
-          token_idx * num_heads * head_size + head_idx * head_size;
-      const scalar_t* prefix_output_blk = prefix_output + blk_offset;
-      const scalar_t* suffix_output_blk = suffix_output + blk_offset;
-      scalar_t* output_blk = output + blk_offset;
+  // We only need to write to output_lse once per head.
+  if (output_lse != nullptr && thr_idx == 0) {
+    float out_lse = logf(out_se) + max_lse;
+    output_lse[head_idx * num_tokens + token_idx] = out_lse;
+  }
 
-      // float -> 4, half/bf16 -> 8
-      using pack_128b_t = uint4;
-      constexpr uint pack_size = 16 / sizeof(scalar_t);
+  if (thr_offset < head_size) {
+    // Pack 128b load
+    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
+        prefix_output_blk)[thr_offset / pack_size];
+    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
+        suffix_output_blk)[thr_offset / pack_size];
+    pack_128b_t o_out_pack;
 
-      const uint thr_offset = thread_idx * pack_size;
-      const float p_scale = p_se / out_se;
-      const float s_scale = s_se / out_se;
+#pragma unroll
+    for (uint i = 0; i < pack_size; ++i) {
+      // Always use float for FMA to keep precision.
+      // half(uint16_t), bfloat16, float -> float.
+      const float p_out_f =
+          to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
+      const float s_out_f =
+          to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
+      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
+      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
+      // float -> half(uint16_t), bfloat16, float.
+      from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
+                 o_out_f);
+    }
 
-      if (thr_offset < head_size) {
-        // Pack 128b load
-        pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
-            prefix_output_blk)[thr_offset / pack_size];
-        pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
-            suffix_output_blk)[thr_offset / pack_size];
-        pack_128b_t o_out_pack;
+    // Pack 128b storage
+    reinterpret_cast<pack_128b_t*>(output_blk)[
+        thr_offset / pack_size] = o_out_pack;
+  }
+}
 
-#pragma unroll
-        for (uint i = 0; i < pack_size; ++i) {
-          // Always use float for FMA to keep precision.
-          // half(uint16_t), bfloat16, float -> float.
-          const float p_out_f =
-              to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
-          const float s_out_f =
-              to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
-          // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
-          const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
-          // float -> half(uint16_t), bfloat16, float.
-          from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
-                     o_out_f);
-        }
+// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+// Can be used to combine partial attention results (in the split-KV case).
+template <typename scalar_t, bool kLoopOverHead, bool kFlattenOverHead = false>
+__global__ void merge_attn_states_kernel(
+    scalar_t* output,       // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    float* output_lse,      // [NUM_HEADS, NUM_TOKENS]
+    const scalar_t* __restrict__ prefix_output,  // [NUM_TOKENS, NUM_HEADS,
+                                                 // HEAD_SIZE]
+    const float* __restrict__ prefix_lse,        // [NUM_HEADS, NUM_TOKENS]
+    const scalar_t* __restrict__ suffix_output,  // [NUM_TOKENS, NUM_HEADS,
+                                                 // HEAD_SIZE]
+    const float* __restrict__ suffix_lse,        // [NUM_HEADS, NUM_TOKENS]
+    const uint num_tokens,  // NUM_TOKENS
+    const uint num_heads,   // NUM QUERY HEADS
+    const uint head_size    // HEAD_SIZE, 32,48,64,...,512,etc
+) {
+  if constexpr (kLoopOverHead) {
+    // May loop over num heads for large num_tokens.
+    const uint token_idx = blockIdx.x;
+    const uint thread_idx = threadIdx.x;
+
+    if constexpr (kFlattenOverHead) {
+      // thread num = (num_heads * head_size) / pack_size
+      //            = num_heads * (head_size / pack_size), e.g. 16 * (128 / 8)
+      // tid: 0~255, 0~15 -> head 0, 16~31 -> head 1, ..., etc.
+      constexpr uint pack_size = 16 / sizeof(scalar_t);
+      const uint head_idx = thread_idx / (head_size / pack_size);
+      const uint thr_idx = thread_idx % (head_size / pack_size);
+      merge_attn_states_per_thread<scalar_t>(
+          output, output_lse, prefix_output,
+          prefix_lse, suffix_output, suffix_lse,
+          num_tokens, num_heads, head_size,
+          token_idx, head_idx, thr_idx
+      );
+    } else {
+      const uint thr_idx = thread_idx;
+#pragma unroll
+      for (uint head_idx = 0; head_idx < num_heads; ++head_idx) {
+        merge_attn_states_per_thread<scalar_t>(
+            output, output_lse, prefix_output,
+            prefix_lse, suffix_output, suffix_lse,
+            num_tokens, num_heads, head_size,
+            token_idx, head_idx, thr_idx
+        );
+      }  // End loop over heads
+    }  // End kFlattenOverHead
 
-        // Pack 128b storage
-        reinterpret_cast<pack_128b_t*>(output_blk)[
-            thr_offset / pack_size] = o_out_pack;
-      }
-    }  // End loop over heads
   } else {
     const uint token_idx = blockIdx.x;
     const uint head_idx = blockIdx.y;
     const uint thread_idx = threadIdx.x;
+    const uint thr_idx = thread_idx;
 
-    float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
-    float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
-    p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
-    s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
-
-    const float max_lse = fmaxf(p_lse, s_lse);
-    p_lse = p_lse - max_lse;
-    s_lse = s_lse - max_lse;
-    const float p_se = expf(p_lse);
-    const float s_se = expf(s_lse);
-    const float out_se = p_se + s_se;
-
-    if (output_lse != nullptr) {
-      float out_lse = logf(out_se) + max_lse;
-      output_lse[head_idx * num_tokens + token_idx] = out_lse;
-    }
-
-    const uint blk_offset =
-        token_idx * num_heads * head_size + head_idx * head_size;
-    const scalar_t* prefix_output_blk = prefix_output + blk_offset;
-    const scalar_t* suffix_output_blk = suffix_output + blk_offset;
-    scalar_t* output_blk = output + blk_offset;
-
-    // float -> 4, half/bf16 -> 8
-    using pack_128b_t = uint4;  // 16 bytes
-    constexpr uint pack_size = 16 / sizeof(scalar_t);
-
-    const uint thr_offset = thread_idx * pack_size;
-    const float p_scale = p_se / out_se;
-    const float s_scale = s_se / out_se;
-
-    if (thr_offset < head_size) {
-      // Pack 128b load
-      pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
-          prefix_output_blk)[thr_offset / pack_size];
-      pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
-          suffix_output_blk)[thr_offset / pack_size];
-      pack_128b_t o_out_pack;
-
-#pragma unroll
-      for (uint i = 0; i < pack_size; ++i) {
-        // Always use float for FMA to keep precision.
-        // half(uint16_t), bfloat16, float -> float.
-        const float p_out_f =
-            to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
-        const float s_out_f =
-            to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
-        // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
-        const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
-        // float -> half(uint16_t), bfloat16, float.
-        from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
-                   o_out_f);
-      }
-
-      // Pack 128b storage
-      reinterpret_cast<pack_128b_t*>(output_blk)[
-          thr_offset / pack_size] = o_out_pack;
-    }
+    merge_attn_states_per_thread<scalar_t>(
+        output, output_lse, prefix_output,
+        prefix_lse, suffix_output, suffix_lse,
+        num_tokens, num_heads, head_size,
+        token_idx, head_idx, thr_idx
+    );
   }
 }
 
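Reviewer note: per head and token, `merge_attn_states_per_thread` implements the standard log-sum-exp merge referenced in the section 2.2 comment above. A minimal host-side sketch of the same update on a single scalar, for checking the math only (`merge_one_ref` is a hypothetical name, not part of this PR):

    #include <cmath>
    #include <limits>

    // Merge one element of two partial attention results (prefix/suffix),
    // given each part's log-sum-exp (LSE) of its softmax denominator.
    static void merge_one_ref(float p_out, float p_lse, float s_out, float s_lse,
                              float& out, float& out_lse) {
      // As in the kernel: map any infinite LSE to -inf so that an empty
      // split gets weight exp(-inf) = 0.
      if (std::isinf(p_lse)) p_lse = -std::numeric_limits<float>::infinity();
      if (std::isinf(s_lse)) s_lse = -std::numeric_limits<float>::infinity();
      const float max_lse = std::fmax(p_lse, s_lse);  // stabilize exponentials
      const float p_se = std::exp(p_lse - max_lse);
      const float s_se = std::exp(s_lse - max_lse);
      const float out_se = p_se + s_se;
      // Convex combination weighted by each part's share of the denominator.
      out = p_out * (p_se / out_se) + s_out * (s_se / out_se);
      // Merged LSE = log(exp(p_lse) + exp(s_lse)), computed stably.
      out_lse = std::log(out_se) + max_lse;
    }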
@@ -183,9 +171,9 @@ __global__ void merge_attn_states_kernel(
   }                                                                     \
 }
 
-#define LAUNCH_MERGE_ATTN_STATES(SCALAR_T, kLoopOverHead)               \
+#define LAUNCH_MERGE_ATTN_STATES(SCALAR_T, kLoopOverHead, kFlattenOverHead) \
   {                                                                     \
-    merge_attn_states_kernel<SCALAR_T, kLoopOverHead>                   \
+    merge_attn_states_kernel<SCALAR_T, kLoopOverHead, kFlattenOverHead> \
         <<<grid, block>>>(                                              \
         reinterpret_cast<SCALAR_T*>(output.data_ptr()), output_lse_ptr, \
         reinterpret_cast<SCALAR_T*>(prefix_output.data_ptr()),          \
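To sanity-check the `kFlattenOverHead` thread mapping that this macro's new argument selects, here is a small standalone sketch with assumed illustrative values (fp16, `head_size = 128`, `num_heads = 16`; these numbers are for demonstration, not taken from the PR's tests):

    #include <cassert>

    int main() {
      // fp16: pack_size = 16 / sizeof(half) = 8, so each head needs
      // head_size / pack_size = 128 / 8 = 16 threads; one block of
      // num_heads * 16 = 256 threads covers every head of one token.
      const unsigned pack_size = 8;
      const unsigned head_size = 128;
      const unsigned lanes_per_head = head_size / pack_size;  // 16
      const unsigned thread_idx = 37;                         // example thread
      const unsigned head_idx = thread_idx / lanes_per_head;  // 37 / 16 = 2
      const unsigned thr_idx = thread_idx % lanes_per_head;   // 37 % 16 = 5
      assert(head_idx == 2 && thr_idx == 5);
      // thr_offset = thr_idx * pack_size = 40: this thread handles
      // elements 40..47 of head 2.
      return 0;
    }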
@@ -217,20 +205,33 @@ void merge_attn_states_launcher(
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
   }
+  // Keep the number of threads per block <= 512.
+  const bool skip_flatten_over_head = (
+      (num_heads * head_size) / pack_size > 512);
+
   const bool skip_loop_over_head = (
-      num_tokens <= 1024 || num_heads >= 64
-      || disable_loop_over_head
+      disable_loop_over_head || num_tokens <= 1024 ||
+      (num_heads >= 64 && skip_flatten_over_head)
   );
 
   if (skip_loop_over_head) {
     dim3 grid(num_tokens, num_heads);
     dim3 block(head_size / pack_size);
-    LAUNCH_MERGE_ATTN_STATES(SCALAR_T, false);
+    LAUNCH_MERGE_ATTN_STATES(SCALAR_T, false, false);
   } else {
-    // try loop over num heads for large NUM_TOKENS
-    dim3 grid(num_tokens);
-    dim3 block(head_size / pack_size);
-    LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true);
+    // Try looping over num heads for large num_tokens.
+    if (skip_flatten_over_head) {
+      dim3 grid(num_tokens);
+      dim3 block(head_size / pack_size);
+      LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true, false);
+    } else {
+      // Cases:
+      //   num_tokens 8192, num_heads 16, head_size 128
+      //   num_tokens 4096, num_heads 16, head_size 128
+      dim3 grid(num_tokens);
+      dim3 block((num_heads * head_size) / pack_size);
+      LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true, true);
+    }
   }
 }
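Tracing the new dispatch for the first case noted in the comment above (assumed dtype fp16; a standalone walkthrough for review, not an additional test in the PR):

    #include <cstdio>

    // Mirrors merge_attn_states_launcher's selection logic for
    // num_tokens = 8192, num_heads = 16, head_size = 128, fp16.
    int main() {
      const unsigned num_tokens = 8192, num_heads = 16, head_size = 128;
      const unsigned pack_size = 16 / 2;  // sizeof(half) == 2 -> 8
      const bool disable_loop_over_head = false;
      const bool skip_flatten_over_head =
          (num_heads * head_size) / pack_size > 512;  // 256 <= 512 -> false
      const bool skip_loop_over_head =
          disable_loop_over_head || num_tokens <= 1024 ||
          (num_heads >= 64 && skip_flatten_over_head);  // false
      // Both flags are false, so the launcher picks the flattened path:
      // grid(8192), block(256) via LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true, true),
      // i.e. each block merges all 16 heads of one token in a single pass.
      std::printf("flatten path: %d, block threads: %u\n",
                  !skip_loop_over_head && !skip_flatten_over_head,
                  (num_heads * head_size) / pack_size);
      return 0;
    }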