Skip to content

Commit 8c988fa

Browse files
authored
CUDA: add fused rms norm (#14800)
1 parent acd6cb1 commit 8c988fa

File tree

4 files changed

+144
-9
lines changed

4 files changed

+144
-9
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#include <cstddef>
5656
#include <cstdint>
5757
#include <float.h>
58+
#include <initializer_list>
5859
#include <limits>
5960
#include <map>
6061
#include <memory>
@@ -2765,6 +2766,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
27652766
}
27662767
#endif
27672768

2769+
// Decide whether the nodes starting at node_idx can be executed by a single
// fused CUDA op. Currently only the RMS_NORM -> MUL pair is special-cased;
// any other op list is accepted as long as the generic ggml_can_fuse check
// (adjacency, matching ops, single-use intermediates) passes.
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
        return false;
    }

    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
        const ggml_tensor * rms_norm = cgraph->nodes[node_idx];
        const ggml_tensor * mul      = cgraph->nodes[node_idx+1];

        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);

        // rms norm only supports F32
        if (mul->src[0]->type != GGML_TYPE_F32 ||
            mul->src[1]->type != GGML_TYPE_F32 ||
            mul->type         != GGML_TYPE_F32) {
            return false;
        }

        // If rms norm is the B operand of the mul, the fused kernel cannot
        // broadcast the norm output over a larger A operand, so A must have
        // exactly the norm output's shape.
        // BUG FIX: the original compared against rms_norm->src[1], which is
        // always NULL for GGML_OP_RMS_NORM (it has a single source tensor),
        // so ggml_are_same_shape would dereference a null pointer whenever
        // this branch was taken. Compare against the norm output itself.
        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
            return false;
        }

        // rms_norm kernel assumes contiguous rows
        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
            return false;
        }
    }

    return true;
}
2801+
27682802
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
27692803
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
27702804
// flag used to determine whether it is an integrated_gpu
@@ -2774,13 +2808,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
27742808
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
27752809
// With the use of CUDA graphs, the execution will be performed by the graph launch.
27762810
if (!use_cuda_graph || cuda_graph_update_required) {
2811+
27772812
for (int i = 0; i < cgraph->n_nodes; i++) {
27782813
ggml_tensor * node = cgraph->nodes[i];
27792814

27802815
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
27812816
continue;
27822817
}
27832818

2819+
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
2820+
if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
2821+
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
2822+
i++;
2823+
continue;
2824+
}
27842825
#ifndef NDEBUG
27852826
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
27862827
for (int j = 0; j < GGML_MAX_SRC; j++) {

ggml/src/ggml-cuda/norm.cu

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
104104
}
105105
}
106106

107-
template <int block_size>
107+
template <int block_size, bool do_multiply = false>
108108
static __global__ void rms_norm_f32(
109109
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
110-
const int64_t stride_sample, const float eps) {
110+
const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
111+
const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
112+
const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
111113
const int nrows = gridDim.x;
112114
const int nchannels = gridDim.y;
113115

@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
119121
x += sample*stride_sample + channel*stride_channel + row*stride_row;
120122
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
121123

124+
if constexpr (do_multiply) {
125+
const int mul_row = row % mul_nrows;
126+
const int mul_channel = channel % mul_nchannels;
127+
const int mul_sample = sample % mul_nsamples;
128+
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
129+
}
130+
122131
float tmp = 0.0f; // partial sum for thread in warp
123132

124133
for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
145154
const float scale = rsqrtf(mean + eps);
146155

147156
for (int col = tid; col < ncols; col += block_size) {
148-
dst[col] = scale * x[col];
157+
if constexpr (do_multiply) {
158+
const int mul_col = col % mul_ncols;
159+
dst[col] = scale * x[col] * mul[mul_col];
160+
} else {
161+
dst[col] = scale * x[col];
162+
}
149163
}
150164
}
151165

@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
310324
const dim3 blocks_num(nrows, nchannels, nsamples);
311325
if (ncols < 1024) {
312326
const dim3 block_dims(WARP_SIZE, 1, 1);
313-
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
327+
rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
328+
} else {
329+
const dim3 block_dims(1024, 1, 1);
330+
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
331+
}
332+
}
333+
334+
// Host-side launcher for the fused RMS-norm * multiplier kernel.
// x/dst strides are in elements (rows assumed contiguous); the mul_* extents
// and strides describe the multiplier tensor, which the kernel broadcasts
// over the output via modulo indexing. Passing mul == nullptr degrades to the
// plain (non-fused) RMS-norm launch.
static void rms_norm_mul_f32_cuda(
        const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
        const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
        const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
        const float eps, cudaStream_t stream) {
    // No multiplier: defer to the ordinary RMS-norm path.
    if (mul == nullptr) {
        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
        return;
    }

    // One block per (row, channel, sample); a single warp per block unless the
    // rows are wide enough to warrant a full 1024-thread block.
    const dim3 grid(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block(WARP_SIZE, 1, 1);
        rms_norm_f32<WARP_SIZE, true><<<grid, block, 0, stream>>>(
            x, dst, ncols, stride_row, stride_channel, stride_sample, eps,
            mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
            mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
    } else {
        const dim3 block(1024, 1, 1);
        rms_norm_f32<1024, true><<<grid, block, 0, stream>>>(
            x, dst, ncols, stride_row, stride_channel, stride_sample, eps,
            mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
            mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
    }
}
319353

@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
407441
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
408442
}
409443

444+
// Execute a fused RMS_NORM + MUL pair on CUDA.
// dst is the (elided) RMS_NORM node; mul_tensor is the MUL node consuming it.
// The fused result is written directly into mul_tensor->data, skipping the
// intermediate norm output entirely.
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];

    // epsilon is stored in the RMS_NORM node's op_params
    float eps = 0.0f;
    memcpy(&eps, dst->op_params, sizeof(float));

    // The multiplier is whichever MUL operand is not the norm output; the norm
    // may appear as either A or B.
    const ggml_tensor * mul_src = nullptr;
    if (mul_tensor->src[0] == dst) {
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    const float * src0_d = (const float *) rms_norm_src->data;
    const float * mul_d  = (const float *) mul_src->data;
    float       * dst_d  = (float       *) mul_tensor->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type          == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type   == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    // Convert byte strides to element strides; rows must be contiguous.
    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    // Multiplier extents, used by the kernel for modulo broadcasting.
    const int mul_ncols     = mul_src->ne[0];
    const int mul_nrows     = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples  = mul_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d,
                          ne00, ne01, ne02, ne03,
                          s01, s02, s03,
                          mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          eps, stream);
}
496+
410497
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
411498
const ggml_tensor * grad = dst->src[0]; // gradients
412499
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass

ggml/src/ggml-cuda/norm.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
66

77
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
88

9+
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
10+
911
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
1012

1113
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2641,6 +2641,7 @@ struct test_rms_norm_mul_add : public test_case {
26412641
const ggml_type type;
26422642
const std::array<int64_t, 4> ne;
26432643
const float eps;
2644+
const bool broadcast;
26442645

26452646
std::string op_desc(ggml_tensor * t) override {
26462647
GGML_UNUSED(t);
@@ -2650,18 +2651,21 @@ struct test_rms_norm_mul_add : public test_case {
26502651
bool run_whole_graph() override { return true; }
26512652

26522653
std::string vars() override {
2653-
return VARS_TO_STR3(type, ne, eps);
2654+
return VARS_TO_STR4(type, ne, eps, broadcast);
26542655
}
26552656

26562657
test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
26572658
std::array<int64_t, 4> ne = {64, 5, 4, 3},
2658-
float eps = 1e-6f)
2659-
: type(type), ne(ne), eps(eps) {}
2659+
float eps = 1e-6f, bool broadcast = false)
2660+
: type(type), ne(ne), eps(eps), broadcast(broadcast) {}
26602661

26612662
ggml_tensor * build_graph(ggml_context * ctx) override {
2662-
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
2663+
std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
2664+
2665+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
26632666
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
26642667
ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
2668+
26652669
ggml_set_param(a);
26662670
ggml_set_name(a, "a");
26672671
ggml_set_param(b);
@@ -5354,6 +5358,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53545358
}
53555359
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
53565360
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
5361+
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
53575362
}
53585363

53595364
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));

0 commit comments

Comments
 (0)