
Commit 4160b93

MollySophia authored and ggerganov committed
ggml : add epsilon as a parameter for group_norm (llama/8818)
Signed-off-by: Molly Sophia <[email protected]>
1 parent 7a96e66 commit 4160b93

5 files changed: +30 −20 lines changed

ggml/include/ggml.h (+4 −3)

@@ -1140,16 +1140,17 @@ extern "C" {
 
     // group normalize along ne0*ne1*n_groups
     // used in stable-diffusion
-    // TODO: eps is hardcoded to 1e-6 for now
     GGML_API struct ggml_tensor * ggml_group_norm(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_groups);
+            int                   n_groups,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_group_norm_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_groups);
+            int                   n_groups,
+            float                 eps);
 
     // a - x
     // b - dy
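
The only caller-facing change is the extra trailing argument. A minimal sketch of using the new signature (the shapes, context size, and eps value below are illustrative; 1e-6f reproduces the previously hardcoded default):

#include "ggml.h"

// Group-normalize a 64x64x32 activation tensor over 8 groups.
struct ggml_init_params params = {
    /*.mem_size   =*/ 16*1024*1024,
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ false,
};
struct ggml_context * ctx = ggml_init(params);

struct ggml_tensor * a   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 64, 32);
struct ggml_tensor * out = ggml_group_norm(ctx, a, /*n_groups =*/ 8, /*eps =*/ 1e-6f);

Note that this is an API break: existing callers (e.g. stable-diffusion graphs) must now pass an explicit eps.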

ggml/src/ggml-cuda/norm.cu (+6 −3)

@@ -142,8 +142,7 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
         }
     }
 
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
-    static const float eps = 1e-6f;
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
     if (group_size < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -196,8 +195,12 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
+    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
 }
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
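
For context on where eps enters the kernels above: each group is normalized by 1/sqrt(variance + eps), so eps keeps the scale finite when a group is nearly constant. A plain scalar sketch of one group (an illustrative reference only, not the CUDA kernel):

#include <math.h>

// Normalize n values of one group in place: x = (x - mean) / sqrt(var + eps).
static void group_norm_ref(float * x, int n, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += x[i];
    }
    const float mean = sum / n;

    float sum2 = 0.0f;
    for (int i = 0; i < n; i++) {
        const float d = x[i] - mean;
        sum2 += d * d;
    }
    const float variance = sum2 / n;

    const float scale = 1.0f / sqrtf(variance + eps);
    for (int i = 0; i < n; i++) {
        x[i] = (x[i] - mean) * scale;
    }
}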

ggml/src/ggml-metal.m (+2 −4)

@@ -2236,10 +2236,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 GGML_ASSERT(ne00 % 4 == 0);
                 GGML_ASSERT(ggml_is_contiguous(src0));
 
-                //float eps;
-                //memcpy(&eps, dst->op_params, sizeof(float));
-
-                const float eps = 1e-6f; // TODO: temporarily hardcoded
+                float eps;
+                memcpy(&eps, dst->op_params + 1, sizeof(float));
 
                 const int32_t n_groups = ((int32_t *) dst->op_params)[0];
 

ggml/src/ggml-sycl/norm.cpp (+6 −3)

@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
 }
 
 static void group_norm_f32_sycl(const float* x, float* dst,
-                                const int num_groups, const int group_size,
+                                const int num_groups, const float eps, const int group_size,
                                 const int ne_elements, queue_ptr stream, int device) {
-    static const float eps = 1e-6f;
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
 
     (void)src1;
     (void)dst;

ggml/src/ggml.c (+12 −7)

@@ -5377,6 +5377,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         int n_groups,
+        float eps,
         bool inplace) {
 
     bool is_node = false;
@@ -5387,7 +5388,8 @@ static struct ggml_tensor * ggml_group_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op_params[0] = n_groups;
+    ggml_set_op_params_i32(result, 0, n_groups);
+    ggml_set_op_params_f32(result, 1, eps);
 
     result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5399,15 +5401,17 @@ static struct ggml_tensor * ggml_group_norm_impl(
 struct ggml_tensor * ggml_group_norm(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, false);
+        int n_groups,
+        float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
 }
 
 struct ggml_tensor * ggml_group_norm_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, true);
+        int n_groups,
+        float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }
 
 // ggml_mul_mat
@@ -12098,10 +12102,11 @@ static void ggml_compute_forward_group_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    const float eps = 1e-6f; // TODO: make this a parameter
-
     // TODO: optimize
 
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int n_channels = src0->ne[2];
     int n_groups = dst->op_params[0];
     int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
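
Since op_params is an int32 array, eps travels as raw float bits: ggml_set_op_params_f32 stores them in slot 1, and each backend reads them back with memcpy rather than a pointer cast, which avoids strict-aliasing problems. A standalone sketch of the round trip (the 16-slot array stands in for tensor->op_params; the real size is GGML_MAX_OP_PARAMS / sizeof(int32_t)):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    int32_t op_params[16] = {0};          // stands in for tensor->op_params

    // Store: slot 0 holds n_groups as an int, slot 1 holds eps as float bits.
    op_params[0] = 8;
    const float eps_in = 1e-6f;
    memcpy(&op_params[1], &eps_in, sizeof(float));

    // Load: the same memcpy idiom used by the CPU/CUDA/SYCL/Metal paths above.
    const int n_groups = op_params[0];
    float eps;
    memcpy(&eps, op_params + 1, sizeof(float));

    printf("n_groups = %d, eps = %g\n", n_groups, eps);
    return 0;
}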
