
Commit ba6fac2

feat: optimize merge_attn_states thread block dispatch (#279)
* [Kernel] opt cuda merge_attn_states kernel, part-1
* kernel: optimize merge_attn_states cuda kernel
1 parent 9c74b75 commit ba6fac2

3 files changed: +136 -135 lines changed


kernels/openai-triton/merge-attn-states/cuda_merge_attn_states.cu

Lines changed: 133 additions & 132 deletions
@@ -19,10 +19,9 @@ from_float(half& d, float s) { d = __float2half(s); }
 static __forceinline__ __device__ void
 from_float(__nv_bfloat16& d, float s) { d = __float2bfloat16(s); }
 
-// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
-// can be used to combine partial attention results (in the split-KV case)
-template <typename scalar_t, bool kLoopOverHead>
-__global__ void merge_attn_states_kernel(
+
+template <typename scalar_t>
+__device__ __forceinline__ void merge_attn_states_per_thread(
     scalar_t* output,                           // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
     float* output_lse,                          // [NUM_HEADS, NUM_TOKENS]
     const scalar_t* __restrict__ prefix_output, // [NUM_TOKENS, NUM_HEADS,
@@ -33,140 +32,129 @@ __global__ void merge_attn_states_kernel(
     const float* __restrict__ suffix_lse,       // [NUM_HEADS, NUM_TOKENS]
     const uint num_tokens,                      // NUM_TOKENS
     const uint num_heads,                       // NUM QUERY HEADS
-    const uint head_size                        // HEAD_SIZE, 32,48,64,...,512,etc
+    const uint head_size,                       // HEAD_SIZE, 32,48,64,...,512,etc
+    const uint token_idx,
+    const uint head_idx,
+    const uint thr_idx
 ) {
-  // TODO(DefTruth): may need to support fp8?
-  if constexpr (kLoopOverHead) {
-    // May loop over num heads for large NUM_TOKENS
-    const uint token_idx = blockIdx.x;
-    const uint thread_idx = threadIdx.x;
+  using pack_128b_t = uint4;  // float -> 4, half/bf16 -> 8
+  constexpr uint pack_size = 16 / sizeof(scalar_t);
 
-#pragma unroll
-    for (uint head_idx = 0; head_idx < num_heads; ++head_idx) {
-      float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
-      float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
-      p_lse =
-          std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
-      s_lse =
-          std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
+  const uint thr_offset = thr_idx * pack_size;  // (0~15)*8, etc.
+  const uint blk_offset =
+      token_idx * num_heads * head_size + head_idx * head_size;
+  const scalar_t* prefix_output_blk = prefix_output + blk_offset;
+  const scalar_t* suffix_output_blk = suffix_output + blk_offset;
+  scalar_t* output_blk = output + blk_offset;
 
-      const float max_lse = fmaxf(p_lse, s_lse);
-      p_lse = p_lse - max_lse;
-      s_lse = s_lse - max_lse;
-      const float p_se = expf(p_lse);
-      const float s_se = expf(s_lse);
-      const float out_se = p_se + s_se;
+  float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
+  float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
+  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
+  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
 
-      if (output_lse != nullptr) {
-        float out_lse = logf(out_se) + max_lse;
-        output_lse[head_idx * num_tokens + token_idx] = out_lse;
-      }
+  const float max_lse = fmaxf(p_lse, s_lse);
+  p_lse = p_lse - max_lse;
+  s_lse = s_lse - max_lse;
+  const float p_se = expf(p_lse);
+  const float s_se = expf(s_lse);
+  const float out_se = p_se + s_se;
+  const float p_scale = p_se / out_se;
+  const float s_scale = s_se / out_se;
 
-      const uint blk_offset =
-          token_idx * num_heads * head_size + head_idx * head_size;
-      const scalar_t* prefix_output_blk = prefix_output + blk_offset;
-      const scalar_t* suffix_output_blk = suffix_output + blk_offset;
-      scalar_t* output_blk = output + blk_offset;
+  // We only need to write to output_lse once per head.
+  if (output_lse != nullptr && thr_idx == 0) {
+    float out_lse = logf(out_se) + max_lse;
+    output_lse[head_idx * num_tokens + token_idx] = out_lse;
+  }
 
-      // float -> 4, half/bf16 -> 8
-      using pack_128b_t = uint4;
-      constexpr uint pack_size = 16 / sizeof(scalar_t);
+  if (thr_offset < head_size) {
+    // Pack 128b load
+    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
+        prefix_output_blk)[thr_offset / pack_size];
+    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
+        suffix_output_blk)[thr_offset / pack_size];
+    pack_128b_t o_out_pack;
 
-      const uint thr_offset = thread_idx * pack_size;
-      const float p_scale = p_se / out_se;
-      const float s_scale = s_se / out_se;
+#pragma unroll
+    for (uint i = 0; i < pack_size; ++i) {
+      // Always use float for FMA to keep precision.
+      // half(uint16_t), bfloat16, float -> float.
+      const float p_out_f =
+          to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
+      const float s_out_f =
+          to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
+      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
+      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
+      // float -> half(uint16_t), bfloat16, float.
+      from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
+                 o_out_f);
+    }
 
-      if (thr_offset < head_size) {
-        // Pack 128b load
-        pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
-            prefix_output_blk)[thr_offset / pack_size];
-        pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
-            suffix_output_blk)[thr_offset / pack_size];
-        pack_128b_t o_out_pack;
+    // Pack 128b storage
+    reinterpret_cast<pack_128b_t*>(output_blk)[
+        thr_offset / pack_size] = o_out_pack;
+  }
+}
 
-#pragma unroll
-        for (uint i = 0; i < pack_size; ++i) {
-          // Always use float for FMA to keep precision.
-          // half(uint16_t), bfloat16, float -> float.
-          const float p_out_f =
-              to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
-          const float s_out_f =
-              to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
-          // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
-          const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
-          // float -> half(uint16_t), bfloat16, float.
-          from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
-                     o_out_f);
-        }
+// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+// can be used to combine partial attention results (in the split-KV case)
+template <typename scalar_t, bool kLoopOverHead, bool kFlattenOverHead = false>
+__global__ void merge_attn_states_kernel(
+    scalar_t* output,                           // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    float* output_lse,                          // [NUM_HEADS, NUM_TOKENS]
+    const scalar_t* __restrict__ prefix_output, // [NUM_TOKENS, NUM_HEADS,
+                                                // HEAD_SIZE]
+    const float* __restrict__ prefix_lse,       // [NUM_HEADS, NUM_TOKENS]
+    const scalar_t* __restrict__ suffix_output, // [NUM_TOKENS, NUM_HEADS,
+                                                // HEAD_SIZE]
+    const float* __restrict__ suffix_lse,       // [NUM_HEADS, NUM_TOKENS]
+    const uint num_tokens,                      // NUM_TOKENS
+    const uint num_heads,                       // NUM QUERY HEADS
+    const uint head_size                        // HEAD_SIZE, 32,48,64,...,512,etc
+) {
+  if constexpr (kLoopOverHead) {
+    // May loop over num heads for large num_tokens
+    const uint token_idx = blockIdx.x;
+    const uint thread_idx = threadIdx.x;
+
+    if constexpr (kFlattenOverHead) {
+      // thread num = (num_heads * head_size) / pack_size
+      //            = num_heads * (head_size / pack_size), 16 * (128 / 8)
+      // tid: 0~255, 0~15->head 0, 16~31->head 1, ..., etc.
+      constexpr uint pack_size = 16 / sizeof(scalar_t);
+      const uint head_idx = thread_idx / (head_size / pack_size);
+      const uint thr_idx = thread_idx % (head_size / pack_size);
+      merge_attn_states_per_thread<scalar_t>(
+          output, output_lse, prefix_output,
+          prefix_lse, suffix_output, suffix_lse,
+          num_tokens, num_heads, head_size,
+          token_idx, head_idx, thr_idx
+      );
+    } else {
+      const uint thr_idx = thread_idx;
+#pragma unroll
+      for (uint head_idx = 0; head_idx < num_heads; ++head_idx) {
+        merge_attn_states_per_thread<scalar_t>(
+            output, output_lse, prefix_output,
+            prefix_lse, suffix_output, suffix_lse,
+            num_tokens, num_heads, head_size,
+            token_idx, head_idx, thr_idx
+        );
+      } // End loop over heads
+    } // End kFlattenOverHead
 
-        // Pack 128b storage
-        reinterpret_cast<pack_128b_t*>(output_blk)[
-            thr_offset / pack_size] = o_out_pack;
-      }
-    } // End loop over heads
   } else {
     const uint token_idx = blockIdx.x;
     const uint head_idx = blockIdx.y;
     const uint thread_idx = threadIdx.x;
+    const uint thr_idx = thread_idx;
 
-    float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
-    float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
-    p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
-    s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
-
-    const float max_lse = fmaxf(p_lse, s_lse);
-    p_lse = p_lse - max_lse;
-    s_lse = s_lse - max_lse;
-    const float p_se = expf(p_lse);
-    const float s_se = expf(s_lse);
-    const float out_se = p_se + s_se;
-
-    if (output_lse != nullptr) {
-      float out_lse = logf(out_se) + max_lse;
-      output_lse[head_idx * num_tokens + token_idx] = out_lse;
-    }
-
-    const uint blk_offset =
-        token_idx * num_heads * head_size + head_idx * head_size;
-    const scalar_t* prefix_output_blk = prefix_output + blk_offset;
-    const scalar_t* suffix_output_blk = suffix_output + blk_offset;
-    scalar_t* output_blk = output + blk_offset;
-
-    // float -> 4, half/bf16 -> 8
-    using pack_128b_t = uint4; // 16 bytes
-    constexpr uint pack_size = 16 / sizeof(scalar_t);
-
-    const uint thr_offset = thread_idx * pack_size;
-    const float p_scale = p_se / out_se;
-    const float s_scale = s_se / out_se;
-
-    if (thr_offset < head_size) {
-      // Pack 128b load
-      pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
-          prefix_output_blk)[thr_offset / pack_size];
-      pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
-          suffix_output_blk)[thr_offset / pack_size];
-      pack_128b_t o_out_pack;
-
-#pragma unroll
-      for (uint i = 0; i < pack_size; ++i) {
-        // Always use float for FMA to keep precision.
-        // half(uint16_t), bfloat16, float -> float.
-        const float p_out_f =
-            to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
-        const float s_out_f =
-            to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
-        // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
-        const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
-        // float -> half(uint16_t), bfloat16, float.
-        from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
-                   o_out_f);
-      }
-
-      // Pack 128b storage
-      reinterpret_cast<pack_128b_t*>(output_blk)[
-          thr_offset / pack_size] = o_out_pack;
-    }
+    merge_attn_states_per_thread<scalar_t>(
+        output, output_lse, prefix_output,
+        prefix_lse, suffix_output, suffix_lse,
+        num_tokens, num_heads, head_size,
+        token_idx, head_idx, thr_idx
+    );
   }
 }
 
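Both the old kernel body and the new merge_attn_states_per_thread helper implement the same log-sum-exp reweighting from section 2.2 of the linked paper. As a reading aid, here is a minimal PyTorch sketch of that math on whole tensors (an illustrative stand-in, not the merge_attn_states_torch reference from the test file):

import torch

def merge_attn_states_reference(p_out, p_lse, s_out, s_lse):
    # p_out, s_out: [num_tokens, num_heads, head_size]
    # p_lse, s_lse: [num_heads, num_tokens]
    # Map +/-inf LSE values (fully masked splits) to -inf, as the kernel does.
    p_lse = torch.where(torch.isinf(p_lse), torch.full_like(p_lse, float("-inf")), p_lse)
    s_lse = torch.where(torch.isinf(s_lse), torch.full_like(s_lse, float("-inf")), s_lse)
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)
    s_se = torch.exp(s_lse - max_lse)
    out_se = p_se + s_se
    out_lse = torch.log(out_se) + max_lse                    # [num_heads, num_tokens]
    p_scale = (p_se / out_se).transpose(0, 1).unsqueeze(-1)  # [num_tokens, num_heads, 1]
    s_scale = (s_se / out_se).transpose(0, 1).unsqueeze(-1)
    out = p_out.float() * p_scale + s_out.float() * s_scale  # accumulate in float32
    return out.to(p_out.dtype), out_lse

The CUDA path does the same arithmetic per 128-bit pack (4 floats or 8 half/bfloat16 values per thread), converting to float before the FMA and back afterwards; only the first thread of each head (thr_idx == 0) writes output_lse.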
@@ -183,9 +171,9 @@ __global__ void merge_attn_states_kernel(
   } \
   }
 
-#define LAUNCH_MERGE_ATTN_STATES(SCALAR_T, kLoopOverHead) \
+#define LAUNCH_MERGE_ATTN_STATES(SCALAR_T, kLoopOverHead, kFlattenOverHead) \
   { \
-    merge_attn_states_kernel<SCALAR_T, kLoopOverHead> \
+    merge_attn_states_kernel<SCALAR_T, kLoopOverHead, kFlattenOverHead> \
         <<<grid, block>>>( \
             reinterpret_cast<SCALAR_T*>(output.data_ptr()), output_lse_ptr, \
             reinterpret_cast<SCALAR_T*>(prefix_output.data_ptr()), \
@@ -217,20 +205,33 @@ void merge_attn_states_launcher(
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
   }
+  // Keep threads num <= 512 per thread block.
+  const bool skip_flatten_over_head = (
+      (num_heads * head_size) / pack_size > 512);
+
   const bool skip_loop_over_head = (
-    num_tokens <= 1024 || num_heads >= 64
-    || disable_loop_over_head
+      disable_loop_over_head || num_tokens <= 1024 ||
+      (num_heads >= 64 && skip_flatten_over_head)
   );
 
   if (skip_loop_over_head) {
     dim3 grid(num_tokens, num_heads);
     dim3 block(head_size / pack_size);
-    LAUNCH_MERGE_ATTN_STATES(SCALAR_T, false);
+    LAUNCH_MERGE_ATTN_STATES(SCALAR_T, false, false);
   } else {
-    // try loop over num heads for large NUM_TOKENS
-    dim3 grid(num_tokens);
-    dim3 block(head_size / pack_size);
-    LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true);
+    // try loop over num heads for large num_tokens
+    if (skip_flatten_over_head) {
+      dim3 grid(num_tokens);
+      dim3 block(head_size / pack_size);
+      LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true, false);
+    } else {
+      // cases:
+      // num_tokens 8192, num_heads 16, head_size 128
+      // num_tokens 4096, num_heads 16, head_size 128
+      dim3 grid(num_tokens);
+      dim3 block((num_heads * head_size) / pack_size);
+      LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true, true);
+    }
   }
 }
 
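The launcher hunk above is where the thread-block dispatch from the commit title happens: the block is flattened over heads only when it stays within 512 threads. A small Python sketch of that selection logic (the function name and dtype_size argument are illustrative, not part of the commit):

def select_launch_config(num_tokens, num_heads, head_size, dtype_size,
                         disable_loop_over_head=False):
    # pack_size: elements per 128-bit pack (float -> 4, half/bf16 -> 8).
    pack_size = 16 // dtype_size
    # Flattening over heads is only viable if the block stays <= 512 threads.
    skip_flatten_over_head = (num_heads * head_size) // pack_size > 512
    skip_loop_over_head = (disable_loop_over_head or num_tokens <= 1024
                           or (num_heads >= 64 and skip_flatten_over_head))
    if skip_loop_over_head:
        # 2D grid: one block per (token, head).
        return (num_tokens, num_heads), head_size // pack_size
    if skip_flatten_over_head:
        # One block per token; each thread loops over all heads.
        return (num_tokens,), head_size // pack_size
    # One block per token; threads are flattened across heads.
    return (num_tokens,), (num_heads * head_size) // pack_size

For example, num_tokens=4096, num_heads=16, head_size=128 in fp16 gives pack_size=8 and 16 * 128 / 8 = 256 threads, so the flattened kLoopOverHead + kFlattenOverHead path is chosen with grid (4096,) and 256-thread blocks.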
kernels/openai-triton/merge-attn-states/cuda_merge_attn_states.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
         "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
         "--expt-relaxed-constexpr",
         "--expt-extended-lambda",
-        "--use_fast_math"
+        # "--use_fast_math"
     ],
     extra_cflags=['-std=c++17'],
     verbose=True
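These nvcc flags feed a torch.utils.cpp_extension.load JIT build of the kernel. A minimal sketch of what such a call looks like (the module name and source list are assumptions; only the flags visible in the diff are taken from the file):

from torch.utils.cpp_extension import load

# Hypothetical JIT build; name/sources are placeholders for this repo's actual values.
merge_attn_states_ext = load(
    name="cuda_merge_attn_states",
    sources=["cuda_merge_attn_states.cu"],
    extra_cuda_cflags=[
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        # "--use_fast_math",  # commented out by this commit
    ],
    extra_cflags=["-std=c++17"],
    verbose=True,
)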

kernels/openai-triton/merge-attn-states/test_merge_attn_states.py

Lines changed: 2 additions & 2 deletions
@@ -42,12 +42,12 @@ def merge_attn_states_torch(
     return output, output_lse
 
 
-NUM_TOKENS = [256, 512, 613, 1024, 1536, 4096]
+NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536, 4096]
 NUM_QUERY_HEADS = [4, 8, 16, 32]
 HEAD_SIZES = [64, 96, 128]
 DTYPES = [torch.float32, torch.half, torch.bfloat16]
 
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
 @pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("output_dtype", DTYPES)
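Each parametrization builds tensors with the layouts the kernel documents ([NUM_TOKENS, NUM_HEADS, HEAD_SIZE] outputs, [NUM_HEADS, NUM_TOKENS] LSEs). A hedged sketch of one such case, independent of the test's actual setup code:

import torch

# One configuration from the matrix above: num_tokens=613, num_query_heads=8,
# head_size=128, output_dtype=torch.half (shapes follow the kernel's comments).
num_tokens, num_heads, head_size = 613, 8, 128
prefix_output = torch.randn(num_tokens, num_heads, head_size, dtype=torch.half, device="cuda")
suffix_output = torch.randn(num_tokens, num_heads, head_size, dtype=torch.half, device="cuda")
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
output = torch.empty_like(prefix_output)
output_lse = torch.empty(num_heads, num_tokens, dtype=torch.float32, device="cuda")

The test presumably feeds the same tensors to both merge_attn_states_torch and the CUDA kernel and compares the merged outputs and LSEs.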
