
Commit 5ae3c08

[Softmax] Add online softmax according to Nvidia Paper (#60)
1 parent 3f5ace3 commit 5ae3c08
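
Background, added here for context and not stated in the diff itself: the "Nvidia Paper" is presumably Milakov and Gionis, "Online normalizer calculation for softmax" (2018). The idea is to carry a running maximum m and a running denominator d over each row, so the max and the normalizer are found in a single pass and partial results can be merged associatively. In the notation of the MD struct added below, the sequential update over a row x_1..x_K is

\[
m_i = \max(m_{i-1}, x_i), \qquad
d_i = d_{i-1}\,e^{\,m_{i-1}-m_i} + e^{\,x_i - m_i},
\]

and merging two partial results \((m_a, d_a)\) and \((m_b, d_b)\), which is exactly what warp_reduce_md_op below computes, uses

\[
m = \max(m_a, m_b), \qquad
d = d_a\,e^{\,m_a - m} + d_b\,e^{\,m_b - m}.
\]

The final output for element j of the row is \(y_j = e^{\,x_j - m_K} / d_K\).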

File tree

2 files changed: +110 -0 lines changed


softmax/softmax.cu

Lines changed: 107 additions & 0 deletions
@@ -16,8 +16,35 @@
 #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
 #define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+// DS required for Online Softmax
+struct __align__(8) MD
+{
+  float m;
+  float d;
+};
 
 // -------------------------------------- FP32 --------------------------------------
+// Warp Reduce for Online Softmax
+
+template<const int kWarpSize = WARP_SIZE>
+__device__ __forceinline__ MD warp_reduce_md_op(MD value) {
+  unsigned int mask = 0xffffffff;
+  #pragma unroll
+  for (int stride = kWarpSize >> 1; stride >= 1; stride >>= 1) {
+    MD other;
+    other.m = __shfl_xor_sync(mask, value.m, stride);
+    other.d = __shfl_xor_sync(mask, value.d, stride);
+
+    bool value_bigger = (value.m > other.m);
+    MD bigger_m  = value_bigger ? value : other;
+    MD smaller_m = value_bigger ? other : value;
+
+    value.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m);
+    value.m = bigger_m.m;
+  }
+  return value;
+}
+
 // Warp Reduce Sum
 template<const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ float warp_reduce_sum_f32(float val) {
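For reference, a minimal single-threaded CPU sketch of the same (m, d) bookkeeping. It is not part of this commit, and the helper name online_softmax_ref plus the test values are made up here, but it is handy for sanity-checking what the warp reduction above accumulates in parallel:

// Hypothetical CPU reference (not in softmax.cu): sequential online softmax over one row.
#include <cmath>
#include <cstdio>
#include <vector>

static void online_softmax_ref(const float* x, float* y, int K) {
  float m = -INFINITY;  // running max, like MD::m
  float d = 0.0f;       // running sum of exp(x_i - m), like MD::d
  for (int i = 0; i < K; ++i) {
    float m_new = std::fmax(m, x[i]);
    d = d * std::exp(m - m_new) + std::exp(x[i] - m_new);  // rescale old sum, add new term
    m = m_new;
  }
  for (int i = 0; i < K; ++i) y[i] = std::exp(x[i] - m) / d;
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f}, y(4);
  online_softmax_ref(x.data(), y.data(), 4);
  for (float v : y) std::printf("%f ", v);  // the four values should sum to 1
  std::printf("\n");
  return 0;
}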
@@ -289,6 +316,40 @@ __global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half* x, half* y, i
 // TODO: support non 8-multiple K here
 }
 
+template<const int NUM_THREADS = 256>
+__global__ void online_softmax_f32_per_token_kernel(const float* x, float* y, int N) {
+
+  int local_tid = threadIdx.x;
+  int global_tid = blockIdx.x * NUM_THREADS + threadIdx.x;
+  const int WAPR_NUM = NUM_THREADS / WARP_SIZE;
+  int warp_id = local_tid / WARP_SIZE;
+  int lane_id = local_tid % WARP_SIZE;
+  MD val;
+  val.m = global_tid < N ? x[global_tid] : -FLT_MAX;
+  val.d = global_tid < N ? 1.0f : 0.0f;
+
+  __shared__ MD shared[WAPR_NUM];
+  MD res = warp_reduce_md_op<WARP_SIZE>(val);
+
+  if (lane_id == 0) shared[warp_id] = res;
+  __syncthreads();
+
+  if (local_tid < WARP_SIZE) {
+    MD block_res = shared[local_tid];
+    block_res = warp_reduce_md_op<WAPR_NUM>(block_res);
+    if (local_tid == 0) {
+      shared[0] = block_res;
+    }
+  }
+  __syncthreads();
+
+  MD final_res = shared[0];
+  float d_total_inverse = __fdividef(1.0f, final_res.d);
+  if (global_tid < N) {
+    y[global_tid] = __expf(x[global_tid] - final_res.m) * d_total_inverse;
+  }
+}
+
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
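A usage sketch, not part of the commit (the repo drives the kernel through the dispatch macros and PyTorch bindings further down): the kernel expects one block per token/row with NUM_THREADS == H threads, H a multiple of WARP_SIZE and at most 1024, and N == S * H elements in total. A hypothetical standalone harness, assuming the definitions above are in the same translation unit:

// Hypothetical standalone launch of online_softmax_f32_per_token_kernel (sketch).
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int S = 4, H = 256;            // S rows (tokens), H elements per row
  const int N = S * H;
  std::vector<float> h_x(N, 1.0f), h_y(N, 0.0f);

  float *d_x = nullptr, *d_y = nullptr;
  cudaMalloc(&d_x, N * sizeof(float));
  cudaMalloc(&d_y, N * sizeof(float));
  cudaMemcpy(d_x, h_x.data(), N * sizeof(float), cudaMemcpyHostToDevice);

  // One block per row, one thread per element, as the dispatch macros below also do.
  online_softmax_f32_per_token_kernel<H><<<S, H>>>(d_x, d_y, N);

  cudaMemcpy(h_y.data(), d_y, N * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("y[0] = %f (uniform rows, so every entry should be 1/%d)\n", h_y[0], H);

  cudaFree(d_x);
  cudaFree(d_y);
  return 0;
}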
@@ -440,6 +501,41 @@ safe_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
 break; \
 }
 
+// online softmax per token
+#define LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(H)        \
+  online_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
+      reinterpret_cast<float*>(x.data_ptr()),                \
+      reinterpret_cast<float*>(y.data_ptr()),                \
+      N);
+
+#define DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)   \
+  dim3 block((H));                                           \
+  dim3 grid((S));                                            \
+  switch ((H))                                               \
+  {                                                          \
+    case 32:                                                 \
+      LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(32)         \
+      break;                                                 \
+    case 64:                                                 \
+      LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(64)         \
+      break;                                                 \
+    case 128:                                                \
+      LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(128)        \
+      break;                                                 \
+    case 256:                                                \
+      LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(256)        \
+      break;                                                 \
+    case 512:                                                \
+      LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(512)        \
+      break;                                                 \
+    case 1024:                                               \
+      LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(1024)       \
+      break;                                                 \
+    default:                                                 \
+      throw std::runtime_error(                              \
+        "only support H: 64/128/256/512/1024");              \
+      break;                                                 \
+  }
 #define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \
 safe_softmax_f32x4_per_token_kernel<(H)/4><<< \
 grid, block>>>( \
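For illustration only (this shows the macro's effective result, not additional code in the diff): inside the host wrapper added further down, DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) with H = 256 boils down to

dim3 block(256);
dim3 grid(S);
online_softmax_f32_per_token_kernel<(256)><<<grid, block>>>(
    reinterpret_cast<float*>(x.data_ptr()),
    reinterpret_cast<float*>(y.data_ptr()),
    N);

so each supported H simply selects a matching NUM_THREADS template instantiation.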
@@ -674,6 +770,16 @@ void safe_softmax_f16x8_pack_f32_per_token(torch::Tensor x, torch::Tensor y) {
   DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H)
 }
 
+void online_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
+  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+  CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+  CHECK_TORCH_TENSOR_SHAPE(x, y)
+  const int S = x.size(0); // seqlens
+  const int H = x.size(1); // head size/kv_len
+  const int N = S * H;
+  DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
+}
+
 // grid memory fence fp32
 TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
 TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
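A hedged C++ call-site sketch for the wrapper above, using libtorch directly and assuming online_softmax_f32_per_token is declared in the including translation unit; the repo's own benchmarks call it from Python via the lib module, and the shape {4096, 256} is just an example that hits a supported H:

// Hypothetical libtorch usage of online_softmax_f32_per_token (sketch, not in the repo).
#include <torch/torch.h>

void demo_online_softmax() {
  auto x = torch::randn({4096, 256},
                        torch::device(torch::kCUDA).dtype(torch::kFloat32));
  auto y = torch::empty_like(x);
  online_softmax_f32_per_token(x, y);       // custom kernel: softmax over dim=1 (per token)
  auto ref = torch::softmax(x, /*dim=*/1);  // PyTorch reference for comparison
  TORCH_CHECK(torch::allclose(y, ref, /*rtol=*/1e-4, /*atol=*/1e-4));
}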
@@ -688,4 +794,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16_f32_per_token)
   TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token)
   TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token)
+  TORCH_BINDING_COMMON_EXTENSION(online_softmax_f32_per_token)
 }

softmax/softmax.py

Lines changed: 3 additions & 0 deletions
@@ -77,6 +77,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 
 print("-" * 100)
@@ -99,6 +100,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 
 print("-" * 100)
@@ -121,6 +123,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 
 print("-" * 100)
