|
1 | 1 | // Copyright (c) OpenMMLab. All rights reserved.
|
2 | 2 |
|
| 3 | +#include <stdexcept> |
| 4 | + |
3 | 5 | #include "cub/block/block_reduce.cuh"
|
4 | 6 |
|
5 | 7 | #include "src/turbomind/kernels/core/array_ops.h"
|
6 | 8 | #include "src/turbomind/kernels/core/common.h"
|
| 9 | +#include "src/turbomind/kernels/core/math.h" |
| 10 | +#include "src/turbomind/kernels/core/meta.h" |
| 11 | + |
7 | 12 | #include "src/turbomind/kernels/norm/rms_norm.h"
|
| 13 | +#include "src/turbomind/utils/Tensor.h" |
8 | 14 |
|
9 | 15 | namespace turbomind {
|
10 | 16 |
|
@@ -115,6 +121,104 @@ template void invokeRMSNorm(nv_bfloat16* dst,
|
115 | 121 | cudaStream_t st);
|
116 | 122 | #endif
|
117 | 123 |
|
// In-place RMSNorm over individual query/key head vectors.
//
// Each head vector (`dim` elements, padded up to `max_dim`) is handled by a
// group of `max_dim / vec_size` consecutive threads; each thread owns one
// `vec_size`-wide 16-byte chunk. Grid layout: 1-D, with `blockDim.x` a
// multiple of the group size (the launcher uses 512).
//
// T        element type (half / nv_bfloat16), A accumulator type (float)
// data     [token_num, ld] row-major; head `hi` of token `ti` starts at
//          data + ti * ld + hi * dim
// weight   per-channel RMSNorm scale, `dim` elements
// inv_dim  precomputed 1.f / dim
template<class T, class A, int vec_size, int max_dim>
__global__ void QkRMSNormKernel(T* data,  //
                                int ld,
                                const T* weight,
                                int dim,
                                int n,
                                int token_num,
                                float eps,
                                float inv_dim)
{
    static_assert((max_dim & (max_dim - 1)) == 0, "max_dim must be a power of 2");

    constexpr int thr_per_qk = max_dim / vec_size;

    // The butterfly reduction below assumes a group never spans a warp.
    static_assert(thr_per_qk <= 32, "reduction group must fit in one warp");

    const int bi = (threadIdx.x + blockIdx.x * blockDim.x) / thr_per_qk;
    const int di = threadIdx.x % thr_per_qk * vec_size;  // chunk offset in the head
    const int ti = bi / n;                               // token index
    const int hi = bi % n;                               // head index

    // NOTE: no early `return` for tail threads. Every lane named in the
    // full mask of `__shfl_xor_sync` below must execute it — exiting lanes
    // first is undefined behavior. Out-of-range lanes just carry a zero
    // vector through the reduction and skip the loads/stores.
    const bool valid = bi < token_num * n && di < dim;

    // 64-bit offset: `ti * ld` can overflow 32-bit for large batches.
    T* ptr = data + (int64_t)ti * ld + hi * dim;

    Array<T, vec_size> vec{};
    if (valid) {
        Load(vec, &ptr[di]);
    }

    using namespace ops;
    auto acc = cast<A>(vec);
    acc      = acc * acc;

    float sum{};
    PRAGMA_UNROLL
    for (int i = 0; i < vec_size; ++i) {
        sum += acc[i];
    }

    // Butterfly-reduce the partial sums across the group. Groups are
    // power-of-2 sized and aligned inside the warp, so xor offsets below
    // `thr_per_qk` never cross a group boundary.
    PRAGMA_UNROLL
    for (int mask = thr_per_qk / 2; mask >= 1; mask /= 2) {
        sum += __shfl_xor_sync((uint32_t)-1, sum, mask);
    }

    sum = rsqrtf(sum * inv_dim + eps);

    if (valid) {
        Array<T, vec_size> w;
        Ldg(w, &weight[di]);
        PRAGMA_UNROLL
        for (int i = 0; i < vec_size; ++i) {
            vec[i] = (T)((float)vec[i] * sum) * w[i];
        }
        Store(&ptr[di], vec);
    }
}
| 181 | + |
// Launches in-place RMSNorm over the query/key heads of every token.
//
// data      [token_num, ld] row-major; `n` heads of `head_dim` elements
//           each, starting at the beginning of every row
// weight    per-channel scale, `head_dim` elements
// dtype     element type tag; FP16 and BF16 are supported
// stream    CUDA stream the kernel is enqueued on
//
// Preconditions (checked): head_dim divisible by the 16-byte vector width
// and head_dim <= 128. Throws std::runtime_error for unsupported dtypes.
void invokeQkRMSNorm(void* data,
                     int ld,
                     const void* weight,
                     DataType dtype,
                     int head_dim,
                     int n,
                     int token_num,
                     float eps,
                     cudaStream_t stream)
{
    auto invoke = [&](auto t, auto max_dim_t) {
        using T = decltype(t);

        // One thread handles one 16-byte vector of the head.
        constexpr int vec_size   = sizeof(uint4) / sizeof(T);
        constexpr int max_dim    = max_dim_t.value;
        constexpr int thr_per_qk = max_dim / vec_size;

        FT_CHECK(head_dim % vec_size == 0);

        // Keep the thread count in 64-bit: the previous code narrowed the
        // int64_t product back to `int`, silently truncating for large
        // `token_num * n` and launching an undersized grid.
        const int64_t threads   = (int64_t)thr_per_qk * n * token_num;
        const int     block_dim = 512;
        const int     grid_dim  = (int)((threads + block_dim - 1) / block_dim);

        QkRMSNormKernel<T, float, vec_size, max_dim><<<grid_dim, block_dim, 0, stream>>>(
            (T*)data, ld, (const T*)weight, head_dim, n, token_num, eps, 1.f / head_dim);
    };

    constexpr constant<128> max_dim{};
    FT_CHECK(head_dim <= max_dim);

    switch (dtype) {
        case TYPE_FP16:
            return invoke(half{}, max_dim);
        case TYPE_BF16:
            return invoke(nv_bfloat16{}, max_dim);
        default:
            throw std::runtime_error("not implemented");
    }
}
| 221 | + |
118 | 222 | // r' <- r + (h + b)
|
119 | 223 | // h' <- norm(r') * w
|
120 | 224 | template<class T, class Tacc, int block_dim, int vec_size>
|
|
0 commit comments