Commit 38896cc

WIP optimizer optimization
1 parent bf1478e commit 38896cc

4 files changed: +164 −59 lines

src/nf/nf_dense_layer.f90

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,7 @@ module nf_dense_layer
     procedure :: backward
     procedure :: forward
     procedure :: get_gradients
+    procedure :: get_gradients_ptr
     procedure :: get_num_params
     procedure :: get_params
     procedure :: get_params_ptr
@@ -112,6 +113,12 @@ module function get_gradients(self) result(gradients)
       !! Gradients of this layer
   end function get_gradients
 
+  module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
+    class(dense_layer), intent(in), target :: self
+    real, pointer :: dw_ptr(:,:)
+    real, pointer :: db_ptr(:)
+  end subroutine get_gradients_ptr
+
   module subroutine set_params(self, params)
     !! Set the parameters of this layer.
     !! The parameters are ordered as weights first, biases second.

src/nf/nf_dense_layer_submodule.f90

Lines changed: 9 additions & 0 deletions
@@ -102,6 +102,15 @@ module function get_gradients(self) result(gradients)
   end function get_gradients
 
 
+  module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
+    class(dense_layer), intent(in), target :: self
+    real, pointer :: dw_ptr(:,:)
+    real, pointer :: db_ptr(:)
+    dw_ptr => self % dw
+    db_ptr => self % db
+  end subroutine get_gradients_ptr
+
+
   module subroutine set_params(self, params)
     class(dense_layer), intent(in out) :: self
     real, intent(in), target :: params(:)
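For context, here is a minimal standalone sketch (not part of this commit; toy_layer and the program name are made up for illustration) of the pointer-association pattern that get_params_ptr and get_gradients_ptr rely on: the caller receives pointers into the layer's own dw and db storage, so any update through those pointers happens in place, with no allocation or copy of the kind a get_gradients()-style function returns.

program pointer_access_sketch
  implicit none

  type :: toy_layer
    real, allocatable :: dw(:,:)  ! weight gradients
    real, allocatable :: db(:)    ! bias gradients
  end type toy_layer

  type(toy_layer), target :: layer
  real, pointer :: dw_ptr(:,:)
  real, pointer :: db_ptr(:)

  allocate(layer % dw(3, 2), layer % db(2))
  layer % dw = 1
  layer % db = 1

  ! Same idea as get_gradients_ptr: associate pointers with the layer's storage.
  call get_gradients_ptr(layer, dw_ptr, db_ptr)

  ! Assignments through the pointers modify layer % dw and layer % db directly,
  ! e.g. scaling by a batch size before an optimizer step.
  dw_ptr = dw_ptr / 32
  db_ptr = db_ptr / 32

  print *, layer % dw(1, 1), layer % db(1)  ! both are now 0.03125

contains

  subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
    type(toy_layer), intent(in), target :: self
    real, pointer :: dw_ptr(:,:)
    real, pointer :: db_ptr(:)
    dw_ptr => self % dw
    db_ptr => self % db
  end subroutine get_gradients_ptr

end program pointer_access_sketch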

src/nf/nf_network_submodule.f90

Lines changed: 4 additions & 2 deletions
@@ -649,7 +649,7 @@ module subroutine update(self, optimizer, batch_size)
     integer, intent(in), optional :: batch_size
     integer :: batch_size_
     real, allocatable :: params(:)
-    real, pointer :: weights(:), biases(:), gradient(:)
+    real, pointer :: weights(:,:), biases(:), dw(:,:), db(:)
     integer :: n
 
     ! Passing the optimizer instance is optional. If not provided, and if the
@@ -702,7 +702,9 @@ module subroutine update(self, optimizer, batch_size)
       select type(this_layer => self % layers(n) % p)
         type is(dense_layer)
           call this_layer % get_params_ptr(weights, biases)
-          call self % optimizer % minimize(weights, biases, self % get_gradients() / batch_size_)
+          call this_layer % get_gradients_ptr(dw, db)
+          call self % optimizer % minimize(weights, dw / batch_size_)
+          call self % optimizer % minimize(biases, db / batch_size_)
           !call this_layer % set_params(weights, biases)
       end select
     end do

src/nf/nf_optimizers.f90

Lines changed: 144 additions & 57 deletions
@@ -19,7 +19,9 @@ module nf_optimizers
     real :: learning_rate = 0.01
   contains
     procedure(init), deferred :: init
-    procedure(minimize), deferred :: minimize
+    procedure(minimize_1d), deferred :: minimize_1d
+    procedure(minimize_2d), deferred :: minimize_2d
+    generic :: minimize => minimize_1d, minimize_2d
   end type optimizer_base_type
 
   abstract interface
@@ -30,13 +32,19 @@ impure elemental subroutine init(self, num_params)
       integer, intent(in) :: num_params
     end subroutine init
 
-    pure subroutine minimize(self, weights, biases, gradient)
+    pure subroutine minimize_1d(self, param, gradient)
       import :: optimizer_base_type
       class(optimizer_base_type), intent(inout) :: self
-      real, intent(inout), pointer :: weights(:)
-      real, intent(inout), pointer :: biases(:)
-      real, intent(in), pointer :: gradient(:)
-    end subroutine minimize
+      real, intent(inout) :: param(:)
+      real, intent(in) :: gradient(:)
+    end subroutine minimize_1d
+
+    pure subroutine minimize_2d(self, param, gradient)
+      import :: optimizer_base_type
+      class(optimizer_base_type), intent(inout) :: self
+      real, intent(inout) :: param(:,:)
+      real, intent(in) :: gradient(:,:)
+    end subroutine minimize_2d
 
   end interface
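As an aside, the sketch below (independent of neural-fortran; toy_optimizer and the toy_* names are hypothetical) shows how a generic type-bound name such as the minimize introduced above resolves to a rank-1 or rank-2 specific at compile time. This is the same mechanism that routes the rank-2 weights/dw to minimize_2d and the rank-1 biases/db to minimize_1d in the nf_network_submodule.f90 hunk earlier in this commit.

module generic_minimize_sketch
  implicit none

  type :: toy_optimizer
    real :: learning_rate = 0.01
  contains
    procedure :: minimize_1d => toy_minimize_1d
    procedure :: minimize_2d => toy_minimize_2d
    ! One generic name; the compiler picks the specific by argument rank.
    generic :: minimize => minimize_1d, minimize_2d
  end type toy_optimizer

contains

  pure subroutine toy_minimize_1d(self, param, gradient)
    class(toy_optimizer), intent(inout) :: self
    real, intent(inout) :: param(:)
    real, intent(in) :: gradient(:)
    param = param - self % learning_rate * gradient
  end subroutine toy_minimize_1d

  pure subroutine toy_minimize_2d(self, param, gradient)
    class(toy_optimizer), intent(inout) :: self
    real, intent(inout) :: param(:,:)
    real, intent(in) :: gradient(:,:)
    param = param - self % learning_rate * gradient
  end subroutine toy_minimize_2d

end module generic_minimize_sketch

program generic_dispatch_demo
  use generic_minimize_sketch
  implicit none
  type(toy_optimizer) :: opt
  real :: biases(3), gb(3), weights(2, 3), gw(2, 3)
  biases = 1; gb = 1; weights = 1; gw = 1
  call opt % minimize(biases, gb)    ! rank-1 arguments -> minimize_1d
  call opt % minimize(weights, gw)   ! rank-2 arguments -> minimize_2d
  print *, biases(1), weights(1, 1)  ! both updated to about 0.99
end program generic_dispatch_demo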

@@ -47,7 +55,8 @@ end subroutine minimize
     real, allocatable, private :: velocity(:)
   contains
     procedure :: init => init_sgd
-    procedure :: minimize => minimize_sgd
+    procedure :: minimize_1d => minimize_sgd_1d
+    procedure :: minimize_2d => minimize_sgd_2d
   end type sgd
 
   type, extends(optimizer_base_type) :: rmsprop
@@ -62,7 +71,8 @@ end subroutine minimize
     real, allocatable, private :: rms_gradient(:)
   contains
     procedure :: init => init_rmsprop
-    procedure :: minimize => minimize_rmsprop
+    procedure :: minimize_1d => minimize_rmsprop_1d
+    procedure :: minimize_2d => minimize_rmsprop_2d
   end type rmsprop
 
   type, extends(optimizer_base_type) :: adam
@@ -85,7 +95,8 @@ end subroutine minimize
     integer, private :: t = 0
   contains
     procedure :: init => init_adam
-    procedure :: minimize => minimize_adam
+    procedure :: minimize_1d => minimize_adam_1d
+    procedure :: minimize_2d => minimize_adam_2d
   end type adam
 
   type, extends(optimizer_base_type) :: adagrad
@@ -102,7 +113,8 @@ end subroutine minimize
     integer, private :: t = 0
   contains
     procedure :: init => init_adagrad
-    procedure :: minimize => minimize_adagrad
+    procedure :: minimize_1d => minimize_adagrad_1d
+    procedure :: minimize_2d => minimize_adagrad_2d
   end type adagrad
 
 contains
@@ -117,35 +129,30 @@ impure elemental subroutine init_sgd(self, num_params)
   end subroutine init_sgd
 
 
-  pure subroutine minimize_sgd(self, weights, biases, gradient)
+  pure subroutine minimize_sgd_1d(self, param, gradient)
     !! Concrete implementation of a stochastic gradient descent optimizer
     !! update rule.
     class(sgd), intent(inout) :: self
-    real, intent(inout), pointer :: weights(:)
-    real, intent(inout), pointer :: biases(:)
-    real, intent(in), pointer :: gradient(:)
+    real, intent(inout) :: param(:)
+    real, intent(in) :: gradient(:)
 
     if (self % momentum > 0) then
       ! Apply momentum update
       self % velocity = self % momentum * self % velocity &
         - self % learning_rate * gradient
       if (self % nesterov) then
         ! Apply Nesterov update
-        weights = weights + self % momentum * self % velocity &
-          - self % learning_rate * gradient
-        biases = biases + self % momentum * self % velocity &
+        param = param + self % momentum * self % velocity &
           - self % learning_rate * gradient
       else
-        weights = weights + self % velocity
-        biases = biases + self % velocity
+        param = param + self % velocity
       end if
     else
       ! Apply regular update
-      weights = weights - self % learning_rate * gradient
-      biases = biases - self % learning_rate * gradient
+      param = param - self % learning_rate * gradient
     end if
 
-  end subroutine minimize_sgd
+  end subroutine minimize_sgd_1d
 
 
   impure elemental subroutine init_rmsprop(self, num_params)
@@ -158,24 +165,21 @@ impure elemental subroutine init_rmsprop(self, num_params)
   end subroutine init_rmsprop
 
 
-  pure subroutine minimize_rmsprop(self, weights, biases, gradient)
+  pure subroutine minimize_rmsprop_1d(self, param, gradient)
     !! Concrete implementation of a RMSProp optimizer update rule.
     class(rmsprop), intent(inout) :: self
-    real, intent(inout), pointer :: weights(:)
-    real, intent(inout), pointer :: biases(:)
-    real, intent(in), pointer :: gradient(:)
+    real, intent(inout) :: param(:)
+    real, intent(in) :: gradient(:)
 
     ! Compute the RMS of the gradient using the RMSProp rule
     self % rms_gradient = self % decay_rate * self % rms_gradient &
       + (1 - self % decay_rate) * gradient**2
 
     ! Update the network parameters based on the new RMS of the gradient
-    weights = weights - self % learning_rate &
-      / sqrt(self % rms_gradient + self % epsilon) * gradient
-    biases = biases - self % learning_rate &
+    param = param - self % learning_rate &
       / sqrt(self % rms_gradient + self % epsilon) * gradient
 
-  end subroutine minimize_rmsprop
+  end subroutine minimize_rmsprop_1d
 
 
   impure elemental subroutine init_adam(self, num_params)
@@ -189,18 +193,17 @@ impure elemental subroutine init_adam(self, num_params)
   end subroutine init_adam
 
 
-  pure subroutine minimize_adam(self, weights, biases, gradient)
+  pure subroutine minimize_adam_1d(self, param, gradient)
     !! Concrete implementation of an Adam optimizer update rule.
     class(adam), intent(inout) :: self
-    real, intent(inout), pointer :: weights(:)
-    real, intent(inout), pointer :: biases(:)
-    real, intent(in), pointer :: gradient(:)
+    real, intent(inout) :: param(:)
+    real, intent(in) :: gradient(:)
 
     self % t = self % t + 1
 
     ! If weight_decay_l2 > 0, use L2 regularization;
     ! otherwise, default to regular Adam.
-    associate(g => gradient + self % weight_decay_l2 * weights)
+    associate(g => gradient + self % weight_decay_l2 * param)
       self % m = self % beta1 * self % m + (1 - self % beta1) * g
       self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
     end associate
@@ -212,19 +215,13 @@ pure subroutine minimize_adam(self, weights, biases, gradient)
     )
 
       ! Update parameters.
-      weights = weights &
+      param = param &
         - self % learning_rate * (m_hat / (sqrt(v_hat) + self % epsilon) &
-        + self % weight_decay_decoupled * weights)
-
-      ! Update biases (without weight decay for biases)
-      associate(g => gradient)
-        biases = biases &
-          - self % learning_rate * (m_hat / (sqrt(v_hat) + self % epsilon))
-      end associate
+        + self % weight_decay_decoupled * param)
 
     end associate
 
-  end subroutine minimize_adam
+  end subroutine minimize_adam_1d
 
 
   impure elemental subroutine init_adagrad(self, num_params)
@@ -237,43 +234,133 @@ impure elemental subroutine init_adagrad(self, num_params)
   end subroutine init_adagrad
 
 
-  pure subroutine minimize_adagrad(self, weights, biases, gradient)
+  pure subroutine minimize_adagrad_1d(self, param, gradient)
     !! Concrete implementation of an Adagrad optimizer update rule.
     class(adagrad), intent(inout) :: self
-    real, intent(inout), pointer :: weights(:)
-    real, intent(inout), pointer :: biases(:)
-    real, intent(in), pointer :: gradient(:)
+    real, intent(inout) :: param(:)
+    real, intent(in) :: gradient(:)
 
     ! Update the current time step
     self % t = self % t + 1
 
-    ! For weights
     associate( &
       ! If weight_decay_l2 > 0, use L2 regularization;
       ! otherwise, default to regular Adagrad.
-      g => gradient + self % weight_decay_l2 * weights, &
+      g => gradient + self % weight_decay_l2 * param, &
       ! Amortize the learning rate as function of the current time step.
       learning_rate => self % learning_rate &
        / (1 + (self % t - 1) * self % learning_rate_decay) &
     )
 
      self % sum_squared_gradient = self % sum_squared_gradient + g**2
 
-     weights = weights - learning_rate * g / (sqrt(self % sum_squared_gradient) &
+     param = param - learning_rate * g / (sqrt(self % sum_squared_gradient) &
        + self % epsilon)
 
     end associate
-
-    ! For biases (without weight decay)
+
+  end subroutine minimize_adagrad_1d
+
+
+  pure subroutine minimize_sgd_2d(self, param, gradient)
+    !! Concrete implementation of a stochastic gradient descent optimizer
+    !! update rule for 2D arrays.
+    class(sgd), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    if (self % momentum > 0) then
+      ! Apply momentum update
+      self % velocity = self % momentum * self % velocity &
+        - self % learning_rate * reshape(gradient, [size(gradient)])
+      if (self % nesterov) then
+        ! Apply Nesterov update
+        param = param + reshape(self % momentum * self % velocity &
+          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
+      else
+        param = param + reshape(self % velocity, shape(param))
+      end if
+    else
+      ! Apply regular update
+      param = param - self % learning_rate * gradient
+    end if
+
+  end subroutine minimize_sgd_2d
+
+
+  pure subroutine minimize_rmsprop_2d(self, param, gradient)
+    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
+    class(rmsprop), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    ! Compute the RMS of the gradient using the RMSProp rule
+    self % rms_gradient = self % decay_rate * self % rms_gradient &
+      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
+
+    ! Update the network parameters based on the new RMS of the gradient
+    param = param - self % learning_rate &
+      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
+
+  end subroutine minimize_rmsprop_2d
+
+
+  pure subroutine minimize_adam_2d(self, param, gradient)
+    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
+    class(adam), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    self % t = self % t + 1
+
+    ! If weight_decay_l2 > 0, use L2 regularization;
+    ! otherwise, default to regular Adam.
+    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
+      self % m = self % beta1 * self % m + (1 - self % beta1) * g
+      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
+    end associate
+
+    ! Compute bias-corrected first and second moment estimates.
+    associate( &
+      m_hat => self % m / (1 - self % beta1**self % t), &
+      v_hat => self % v / (1 - self % beta2**self % t) &
+    )
+
+      ! Update parameters.
+      param = param &
+        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
+        - self % learning_rate * self % weight_decay_decoupled * param
+
+    end associate
+
+  end subroutine minimize_adam_2d
+
+
+  pure subroutine minimize_adagrad_2d(self, param, gradient)
+    !! Concrete implementation of an Adagrad optimizer update rule for 2D arrays.
+    class(adagrad), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    ! Update the current time step
+    self % t = self % t + 1
+
     associate( &
-      g => gradient, &
+      ! If weight_decay_l2 > 0, use L2 regularization;
+      ! otherwise, default to regular Adagrad.
+      g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]), &
+      ! Amortize the learning rate as function of the current time step.
       learning_rate => self % learning_rate &
        / (1 + (self % t - 1) * self % learning_rate_decay) &
     )
-      biases = biases - learning_rate * g / (sqrt(self % sum_squared_gradient) &
-        + self % epsilon)
+
+      self % sum_squared_gradient = self % sum_squared_gradient + g**2
+
+      param = param - learning_rate * reshape(g / (sqrt(self % sum_squared_gradient) &
+        + self % epsilon), shape(param))
 
     end associate
 
-  end subroutine minimize_adagrad
+  end subroutine minimize_adagrad_2d
 
 end module nf_optimizers
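One more standalone sketch (not from the commit) of the reshape round trip the new *_2d routines lean on: a rank-2 gradient is flattened in column-major order to match the optimizer's rank-1 state arrays (velocity, rms_gradient, m, v, sum_squared_gradient), and the result is folded back to the parameter shape with shape(param).

program reshape_roundtrip_sketch
  implicit none
  real :: dw(2, 3), state(6), update(2, 3)
  integer :: i

  dw = reshape([(real(i), i = 1, 6)], [2, 3])

  ! Flatten, as in: reshape(gradient, [size(gradient)])
  state = reshape(dw, [size(dw)])

  ! Fold back, as in: reshape(..., shape(param))
  update = reshape(state, shape(dw))

  ! Column-major flattening and unflattening is a lossless round trip.
  print *, all(update == dw)  ! prints T
end program reshape_roundtrip_sketch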
