Skip to content

Commit 189504e

Browse files
committed
SGD op param store weight-decay and not 1-alpha*wd
1 parent 71ffb4b commit 189504e

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10387,7 +10387,7 @@ void ggml_compute_forward_opt_step_adamw(
1038710387
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
1038810388
const ggml_tensor * src0 = dst->src[0];
1038910389
const ggml_tensor * src0_grad = dst->src[1];
10390-
const ggml_tensor * sgd_params = dst->src[2];
10390+
const ggml_tensor * sgd_params = dst->src[2];
1039110391

1039210392
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
1039310393
GGML_ASSERT(ggml_nelements(sgd_params) == 2);
@@ -10410,7 +10410,7 @@ static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * pa
1041010410
// using adamw param subset we care about - alpha, wd - could have a separate struct
1041110411
const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
1041210412
const float alpha = sgd_params_ptr[0];
10413-
const float keep = sgd_params_ptr[1];
10413+
const float keep = 1.f - alpha * sgd_params_ptr[1];
1041410414

1041510415
for (int ir = ir0; ir < ir1; ++ir) {
1041610416
const int64_t i03 = ir / (ne02 * ne01);

ggml/src/ggml-cuda/opt-step-sgd.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ static __global__ void opt_step_sgd_f32(
1111

1212
if (i >= k)
1313
return;
14-
x[i] = x[i] * pars[1] - pars[0] * g[i];
14+
x[i] = x[i] * (1.f - pars[0] * pars[1]) - pars[0] * g[i];
1515
}
1616

1717
static void opt_step_sgd_f32_cuda(

ggml/src/ggml-opt.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,8 @@ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
824824
GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
825825
GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
826826
float * sgd = ggml_get_data_f32(opt_ctx->adamw_params);
827-
sgd[1] = 1. - (sgd[0] = opt_pars.sgd.alpha) * opt_pars.sgd.wd;
827+
sgd[0] = opt_pars.sgd.alpha;
828+
sgd[1] = opt_pars.sgd.wd;
828829
}
829830
break;
830831

0 commit comments

Comments
 (0)