@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
 }
 }
 
-template <int block_size>
+template <int block_size, bool do_multiply = false>
 static __global__ void rms_norm_f32(
         const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
-        const int64_t stride_sample, const float eps) {
+        const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
+        const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
+        const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
     const int nrows     = gridDim.x;
     const int nchannels = gridDim.y;
 
@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
     x   += sample*stride_sample + channel*stride_channel + row*stride_row;
     dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
+    if constexpr (do_multiply) {
+        const int mul_row     = row     % mul_nrows;
+        const int mul_channel = channel % mul_nchannels;
+        const int mul_sample  = sample  % mul_nsamples;
+        mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
+    }
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[col] = scale * x[col];
+        if constexpr (do_multiply) {
+            const int mul_col = col % mul_ncols;
+            dst[col] = scale * x[col] * mul[mul_col];
+        } else {
+            dst[col] = scale * x[col];
+        }
     }
 }
 
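To make the broadcast semantics of the new `do_multiply` path concrete, here is a small CPU reference sketch (not part of the patch) of what one block computes for a single (sample, channel, row): the RMS scale over `ncols` elements, followed by an element-wise multiply where the `mul` operand is broadcast with modulo indexing, matching ggml's repeat-style broadcasting. The function name `rms_norm_mul_ref` and its parameters are illustrative only.

```cpp
// Reference (CPU) sketch of the fused kernel's per-row math; illustrative only.
#include <cmath>

static void rms_norm_mul_ref(
        const float * x, const float * mul, float * dst, int ncols, float eps,
        int mul_ncols /* mul may have fewer columns than x and is broadcast */) {
    // 1) mean of squares over the row
    float sum_sq = 0.0f;
    for (int col = 0; col < ncols; ++col) {
        sum_sq += x[col] * x[col];
    }
    const float scale = 1.0f / std::sqrt(sum_sq / ncols + eps);

    // 2) scale and multiply, broadcasting mul with modulo indexing
    for (int col = 0; col < ncols; ++col) {
        dst[col] = scale * x[col] * mul[col % mul_ncols];
    }
}
```

The kernel applies the same modulo trick one level up as well: row, channel, and sample indices are reduced modulo the corresponding `mul` extents before the column loop runs.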
@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
+static void rms_norm_mul_f32_cuda(
+        const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+        const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
+        const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
+        const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (mul == nullptr) {
+        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
+        return;
+    }
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
     }
 }
 
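The new launcher keeps the existing block-size heuristic (one warp per row when `ncols < 1024`, otherwise 1024 threads) and falls back to the non-fused path when `mul` is null. Below is a hypothetical usage sketch for the common RMSNorm case: a contiguous activation tensor of shape [ncols, nrows, 1, 1] and a weight vector of shape [ncols, 1, 1, 1] broadcast over every row. The wrapper name `launch_example` and the pointers `d_x`, `d_w`, `d_dst` are illustrative; since `rms_norm_mul_f32_cuda` is `static`, such a call would have to live in the same translation unit.

```cpp
// Hypothetical call into the fused launcher for a contiguous [ncols, nrows] input
// and a per-column weight vector. All pointers are assumed to be device memory.
static void launch_example(const float * d_x, const float * d_w, float * d_dst,
                           int ncols, int nrows, float eps, cudaStream_t stream) {
    rms_norm_mul_f32_cuda(
        d_x, d_w, d_dst,
        ncols, nrows, /*nchannels=*/1, /*nsamples=*/1,
        // element strides of a contiguous [ncols, nrows, 1, 1] tensor
        /*stride_row=*/ncols, /*stride_channel=*/(int64_t) ncols*nrows, /*stride_sample=*/(int64_t) ncols*nrows,
        // element strides of a contiguous [ncols, 1, 1, 1] weight; the broadcast
        // dims have extent 1, so these strides are multiplied by 0 in the kernel
        /*mul_stride_row=*/ncols, /*mul_stride_channel=*/ncols, /*mul_stride_sample=*/ncols,
        /*mul_ncols=*/ncols, /*mul_nrows=*/1, /*mul_nchannels=*/1, /*mul_nsamples=*/1,
        eps, stream);
}
```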
@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
+    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+    float eps = 0.0f;
+
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const float * src0_d = (const float *) rms_norm_src->data;
+    const float * mul_d  = nullptr;
+    const ggml_tensor * mul_src = nullptr;
+
+    if (mul_tensor->src[0] == dst) {
+        mul_d   = (float *) mul_tensor->src[1]->data;
+        mul_src = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == dst) {
+        mul_d   = (float *) mul_tensor->src[0]->data;
+        mul_src = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    float * dst_d = (float *) mul_tensor->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type          == GGML_TYPE_F32);
+    GGML_ASSERT(mul_tensor->type   == GGML_TYPE_F32);
+    GGML_ASSERT(eps >= 0.0f);
+
+    const int64_t ne00 = rms_norm_src->ne[0];
+    const int64_t ne01 = rms_norm_src->ne[1];
+    const int64_t ne02 = rms_norm_src->ne[2];
+    const int64_t ne03 = rms_norm_src->ne[3];
+
+    const size_t ts0 = ggml_type_size(rms_norm_src->type);
+    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+    const int64_t s01 = rms_norm_src->nb[1] / ts0;
+    const int64_t s02 = rms_norm_src->nb[2] / ts0;
+    const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+    const size_t ts_mul = ggml_type_size(mul_src->type);
+    GGML_ASSERT(mul_src->nb[0] == ts_mul);
+    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+    const int mul_ncols     = mul_src->ne[0];
+    const int mul_nrows     = mul_src->ne[1];
+    const int mul_nchannels = mul_src->ne[2];
+    const int mul_nsamples  = mul_src->ne[3];
+
+    rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+}
+
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * grad  = dst->src[0]; // gradients
     const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
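This diff only adds the fused CUDA entry point; the decision to take it is made during graph evaluation, outside these hunks. The sketch below is a hypothetical illustration of that dispatch, not the PR's actual logic: the helper name `try_fuse_rms_norm_mul` is invented, and a real check would also have to verify that nothing else consumes the intermediate RMS_NORM result before skipping the standalone MUL node.

```cpp
// Hypothetical fusion check during graph evaluation (illustrative only).
// Returns true if node i was executed as a fused RMS_NORM + MUL, in which case
// the caller should also skip node i + 1.
static bool try_fuse_rms_norm_mul(ggml_backend_cuda_context & ctx, ggml_cgraph * cgraph, int i) {
    ggml_tensor * node = cgraph->nodes[i];
    if (node->op != GGML_OP_RMS_NORM || i + 1 >= cgraph->n_nodes) {
        return false;
    }
    ggml_tensor * next = cgraph->nodes[i + 1];
    if (next->op != GGML_OP_MUL || (next->src[0] != node && next->src[1] != node)) {
        return false;
    }
    // ggml_cuda_op_rms_norm_fused writes the MUL node's output directly,
    // so the standalone GGML_OP_MUL no longer needs to run.
    ggml_cuda_op_rms_norm_fused(ctx, node, next);
    return true;
}
```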