@@ -30,12 +30,16 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
30
30
if (lr_policy == " fixed" ) {
31
31
rate = this ->param_ .base_lr ();
32
32
} else if (lr_policy == " step" ) {
33
+ CHECK_GT (this ->param_ .stepsize (), 0 );
33
34
this ->current_step_ = this ->iter_ / this ->param_ .stepsize ();
35
+ CHECK_GE (this ->param_ .gamma (), 0 );
34
36
rate = this ->param_ .base_lr () *
35
37
pow (this ->param_ .gamma (), this ->current_step_ );
36
38
} else if (lr_policy == " exp" ) {
39
+ CHECK_GE (this ->param_ .gamma (), 0 );
37
40
rate = this ->param_ .base_lr () * pow (this ->param_ .gamma (), this ->iter_ );
38
41
} else if (lr_policy == " inv" ) {
42
+ CHECK_GE (this ->param_ .gamma (), 0 );
39
43
rate = this ->param_ .base_lr () *
40
44
pow (Dtype (1 ) + this ->param_ .gamma () * this ->iter_ ,
41
45
- this ->param_ .power ());
@@ -46,13 +50,16 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
46
50
LOG (INFO) << " MultiStep Status: Iteration " <<
47
51
this ->iter_ << " , step = " << this ->current_step_ ;
48
52
}
53
+ CHECK_GE (this ->param_ .gamma (), 0 );
49
54
rate = this ->param_ .base_lr () *
50
55
pow (this ->param_ .gamma (), this ->current_step_ );
51
56
} else if (lr_policy == " poly" ) {
52
57
rate = this ->param_ .base_lr () * pow (Dtype (1 .) -
53
58
(Dtype (this ->iter_ ) / Dtype (this ->param_ .max_iter ())),
54
59
this ->param_ .power ());
55
60
} else if (lr_policy == " sigmoid" ) {
61
+ CHECK_GE (this ->param_ .gamma (), 0 );
62
+ CHECK_GT (this ->param_ .stepsize (), 0 );
56
63
rate = this ->param_ .base_lr () * (Dtype (1 .) /
57
64
(Dtype (1 .) + exp (-this ->param_ .gamma () * (Dtype (this ->iter_ ) -
58
65
Dtype (this ->param_ .stepsize ())))));
0 commit comments