#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+ // State required for Online Softmax: m is the running maximum and
+ // d is the running denominator, i.e. the sum of exp(x_i - m)
+ struct __align__(8) MD
+ {
+   float m;
+   float d;
+ };
// -------------------------------------- FP32 --------------------------------------
+ // Warp Reduce for Online Softmax
+
+ template <const int kWarpSize = WARP_SIZE>
+ __device__ __forceinline__ MD warp_reduce_md_op(MD value) {
+   unsigned int mask = 0xffffffff;
+ #pragma unroll
+   for (int stride = kWarpSize >> 1; stride >= 1; stride >>= 1) {
+     MD other;
+     other.m = __shfl_xor_sync(mask, value.m, stride);
+     other.d = __shfl_xor_sync(mask, value.d, stride);
+
+     bool value_bigger = (value.m > other.m);
+     MD bigger_m = value_bigger ? value : other;
+     MD smaller_m = value_bigger ? other : value;
+
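+     // merge rule for two partial softmax states (m1, d1) and (m2, d2):
+     //   m = max(m1, m2);  d = d1 * exp(m1 - m) + d2 * exp(m2 - m)
+     // the smaller-max term is rescaled so both sums share the new maximum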
+     value.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m);
+     value.m = bigger_m.m;
+   }
+   return value;
+ }
+
// Warp Reduce Sum
template <const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
@@ -289,6 +316,40 @@ __global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half* x, half* y, int N)
// TODO: support non 8-multiple K here
}

+ template <const int NUM_THREADS = 256>
+ __global__ void online_softmax_f32_per_token_kernel(const float* x, float* y, int N) {
+   int local_tid = threadIdx.x;
+   int global_tid = blockIdx.x * NUM_THREADS + threadIdx.x;
+   const int WARP_NUM = NUM_THREADS / WARP_SIZE;
+   int warp_id = local_tid / WARP_SIZE;
+   int lane_id = local_tid % WARP_SIZE;
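+   // each thread seeds its state from one element: m = x[i], d = exp(x[i] - m) = 1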
+   MD val;
+   val.m = global_tid < N ? x[global_tid] : -FLT_MAX;
+   val.d = global_tid < N ? 1.0f : 0.0f;
+
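+   // stage 1: each warp folds its 32 per-thread states into one (m, d);
+   // lane 0 stores the warp's partial state to shared memory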
+   __shared__ MD shared[WARP_NUM];
+   MD res = warp_reduce_md_op<WARP_SIZE>(val);
+
+   if (lane_id == 0) shared[warp_id] = res;
+   __syncthreads();
+
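+   // stage 2: the first warp folds the WARP_NUM per-warp partials into the
+   // block-wide (m, d)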
+   if (local_tid < WARP_SIZE) {
+     // guard the load: only the first WARP_NUM slots of shared are valid,
+     // so pad the remaining lanes with the identity state (m = -inf, d = 0)
+     MD block_res = local_tid < WARP_NUM ? shared[local_tid] : MD{-FLT_MAX, 0.0f};
+     block_res = warp_reduce_md_op<WARP_NUM>(block_res);
+     if (local_tid == 0) {
+       shared[0] = block_res;
+     }
+   }
+   __syncthreads();
+
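+   // epilogue: every thread reads the final (m, d) and normalizes its element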
+   MD final_res = shared[0];
+   float d_total_inverse = __fdividef(1.0f, final_res.d);
+   if (global_tid < N) {
+     y[global_tid] = __expf(x[global_tid] - final_res.m) * d_total_inverse;
+   }
+ }
+
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -440,6 +501,41 @@ safe_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
  break; \
}
+ // online softmax per token: grid = (S) and block = (H), so one block handles
+ // one token row and H is limited to the block sizes switched on below
+ #define LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(H) \
+   online_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
+     reinterpret_cast<float*>(x.data_ptr()), \
+     reinterpret_cast<float*>(y.data_ptr()), \
+     N);
+
+ #define DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) \
+   dim3 block((H)); \
+   dim3 grid((S)); \
+   switch ((H)) \
+   { \
+   case 32: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(32) \
+     break; \
+   case 64: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(64) \
+     break; \
+   case 128: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(128) \
+     break; \
+   case 256: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(256) \
+     break; \
+   case 512: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(512) \
+     break; \
+   case 1024: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(1024) \
+     break; \
+   default: \
+     throw std::runtime_error( \
+       "only support H: 32/64/128/256/512/1024"); \
+     break; \
+   }
#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \
  safe_softmax_f32x4_per_token_kernel<(H)/4><<< \
    grid, block>>>( \
@@ -674,6 +770,16 @@ void safe_softmax_f16x8_pack_f32_per_token(torch::Tensor x, torch::Tensor y) {
  DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H)
}

+ void online_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
+   CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+   CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+   CHECK_TORCH_TENSOR_SHAPE(x, y)
+   const int S = x.size(0); // seqlens
+   const int H = x.size(1); // head size/kv_len
+   const int N = S * H;
+   DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
+ }
+
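+ // A minimal sanity check from Python, assuming the extension is compiled and
+ // imported as `lib` (module name and tensor sizes are illustrative only):
+ //   x = torch.randn(4096, 256, device="cuda", dtype=torch.float32)
+ //   y = torch.empty_like(x)
+ //   lib.online_softmax_f32_per_token(x, y)
+ //   torch.testing.assert_close(y, torch.softmax(x, dim=-1))
+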
// grid memory fence fp32
TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
@@ -688,4 +794,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16_f32_per_token)
  TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token)
  TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token)
+ TORCH_BINDING_COMMON_EXTENSION(online_softmax_f32_per_token)
}