
Commit bf1478e

WIP optimizer refactor w/ pointers
1 parent 402b84a commit bf1478e

4 files changed, +82 −27 lines


src/nf/nf_dense_layer.f90

Lines changed: 7 additions & 0 deletions
@@ -36,6 +36,7 @@ module nf_dense_layer
     procedure :: get_gradients
     procedure :: get_num_params
     procedure :: get_params
+    procedure :: get_params_ptr
     procedure :: init
     procedure :: set_params
 
@@ -96,6 +97,12 @@ module function get_params(self) result(params)
         !! Parameters of this layer
     end function get_params
 
+    module subroutine get_params_ptr(self, w_ptr, b_ptr)
+      class(dense_layer), intent(in), target :: self
+      real, pointer :: w_ptr(:,:)
+      real, pointer :: b_ptr(:)
+    end subroutine get_params_ptr
+
     module function get_gradients(self) result(gradients)
       !! Return the gradients of this layer.
       !! The gradients are ordered as weights first, biases second.

src/nf/nf_dense_layer_submodule.f90

Lines changed: 9 additions & 0 deletions
@@ -77,6 +77,15 @@ module function get_params(self) result(params)
   end function get_params
 
 
+  module subroutine get_params_ptr(self, w_ptr, b_ptr)
+    class(dense_layer), intent(in), target :: self
+    real, pointer :: w_ptr(:,:)
+    real, pointer :: b_ptr(:)
+    w_ptr => self % weights
+    b_ptr => self % biases
+  end subroutine get_params_ptr
+
+
   module function get_gradients(self) result(gradients)
     class(dense_layer), intent(in), target :: self
     real, allocatable :: gradients(:)
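
For context, a minimal usage sketch of the new accessor (the helper subroutine and the scaling factor are illustrative, not part of this commit). It shows the point of the pointer route: the layer's weights and biases are modified in place, with no intermediate copy as with get_params followed by set_params.

! Hypothetical helper: rescale a dense layer's parameters in place through
! the new pointer accessor, with no copy of the parameter vector.
subroutine scale_dense_params(layer, factor)
  use nf_dense_layer, only: dense_layer
  implicit none
  type(dense_layer), intent(inout), target :: layer
  real, intent(in) :: factor
  real, pointer :: w(:,:), b(:)

  call layer % get_params_ptr(w, b)
  ! w and b are now associated with layer % weights and layer % biases,
  ! so these assignments modify the layer directly.
  w = w * factor
  b = b * factor
end subroutine scale_dense_params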

src/nf/nf_network_submodule.f90

Lines changed: 14 additions & 3 deletions
@@ -649,6 +649,7 @@ module subroutine update(self, optimizer, batch_size)
     integer, intent(in), optional :: batch_size
     integer :: batch_size_
     real, allocatable :: params(:)
+    real, pointer :: weights(:), biases(:), gradient(:)
     integer :: n
 
     ! Passing the optimizer instance is optional. If not provided, and if the
@@ -693,9 +694,19 @@ module subroutine update(self, optimizer, batch_size)
     end do
 #endif
 
-    params = self % get_params()
-    call self % optimizer % minimize(params, self % get_gradients() / batch_size_)
-    call self % set_params(params)
+    !params = self % get_params()
+    !call self % optimizer % minimize(params, self % get_gradients() / batch_size_)
+    !call self % set_params(params)
+
+    do n = 2, size(self % layers)
+      select type(this_layer => self % layers(n) % p)
+        type is(dense_layer)
+          call this_layer % get_params_ptr(weights, biases)
+          call self % optimizer % minimize(weights, biases, self % get_gradients() / batch_size_)
+          !call this_layer % set_params(weights, biases)
+      end select
+    end do
+
 
     ! Flush network gradients to zero.
     do n = 2, size(self % layers)
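
The replacement loop dispatches on each layer's dynamic type, so in this WIP state only dense layers are updated in place. A stripped-down sketch of the same pattern follows (the subroutine name is hypothetical, and it assumes the network type and the layer container's p component are accessible from the calling scope, as they are inside the library; the rank-2 weight pointer matches the get_params_ptr interface from this commit):

! Illustrative sketch, not part of this commit: walk the polymorphic layer
! container and access dense-layer parameters in place through pointers.
subroutine touch_dense_layers(net)
  use nf_network, only: network
  use nf_dense_layer, only: dense_layer
  implicit none
  type(network), intent(inout), target :: net
  real, pointer :: weights(:,:), biases(:)
  integer :: n

  do n = 2, size(net % layers)
    select type (this_layer => net % layers(n) % p)
      type is (dense_layer)
        call this_layer % get_params_ptr(weights, biases)
        ! ... hand the pointers to the optimizer for an in-place update ...
      class default
        ! Non-dense layers are left untouched by this sketch.
    end select
  end do
end subroutine touch_dense_layers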

src/nf/nf_optimizers.f90

Lines changed: 52 additions & 24 deletions
@@ -30,11 +30,12 @@ impure elemental subroutine init(self, num_params)
       integer, intent(in) :: num_params
     end subroutine init
 
-    pure subroutine minimize(self, param, gradient)
+    pure subroutine minimize(self, weights, biases, gradient)
      import :: optimizer_base_type
      class(optimizer_base_type), intent(inout) :: self
-      real, intent(inout) :: param(:)
-      real, intent(in) :: gradient(:)
+      real, intent(inout), pointer :: weights(:)
+      real, intent(inout), pointer :: biases(:)
+      real, intent(in), pointer :: gradient(:)
     end subroutine minimize
 
   end interface
@@ -116,27 +117,32 @@ impure elemental subroutine init_sgd(self, num_params)
   end subroutine init_sgd
 
 
-  pure subroutine minimize_sgd(self, param, gradient)
+  pure subroutine minimize_sgd(self, weights, biases, gradient)
     !! Concrete implementation of a stochastic gradient descent optimizer
     !! update rule.
     class(sgd), intent(inout) :: self
-    real, intent(inout) :: param(:)
-    real, intent(in) :: gradient(:)
+    real, intent(inout), pointer :: weights(:)
+    real, intent(inout), pointer :: biases(:)
+    real, intent(in), pointer :: gradient(:)
 
     if (self % momentum > 0) then
       ! Apply momentum update
       self % velocity = self % momentum * self % velocity &
         - self % learning_rate * gradient
       if (self % nesterov) then
         ! Apply Nesterov update
-        param = param + self % momentum * self % velocity &
+        weights = weights + self % momentum * self % velocity &
+          - self % learning_rate * gradient
+        biases = biases + self % momentum * self % velocity &
          - self % learning_rate * gradient
       else
-        param = param + self % velocity
+        weights = weights + self % velocity
+        biases = biases + self % velocity
       end if
     else
       ! Apply regular update
-      param = param - self % learning_rate * gradient
+      weights = weights - self % learning_rate * gradient
+      biases = biases - self % learning_rate * gradient
     end if
 
   end subroutine minimize_sgd
@@ -152,18 +158,21 @@ impure elemental subroutine init_rmsprop(self, num_params)
   end subroutine init_rmsprop
 
 
-  pure subroutine minimize_rmsprop(self, param, gradient)
+  pure subroutine minimize_rmsprop(self, weights, biases, gradient)
     !! Concrete implementation of a RMSProp optimizer update rule.
     class(rmsprop), intent(inout) :: self
-    real, intent(inout) :: param(:)
-    real, intent(in) :: gradient(:)
+    real, intent(inout), pointer :: weights(:)
+    real, intent(inout), pointer :: biases(:)
+    real, intent(in), pointer :: gradient(:)
 
     ! Compute the RMS of the gradient using the RMSProp rule
     self % rms_gradient = self % decay_rate * self % rms_gradient &
       + (1 - self % decay_rate) * gradient**2
 
     ! Update the network parameters based on the new RMS of the gradient
-    param = param - self % learning_rate &
+    weights = weights - self % learning_rate &
+      / sqrt(self % rms_gradient + self % epsilon) * gradient
+    biases = biases - self % learning_rate &
      / sqrt(self % rms_gradient + self % epsilon) * gradient
 
   end subroutine minimize_rmsprop
@@ -180,17 +189,18 @@ impure elemental subroutine init_adam(self, num_params)
   end subroutine init_adam
 
 
-  pure subroutine minimize_adam(self, param, gradient)
+  pure subroutine minimize_adam(self, weights, biases, gradient)
    !! Concrete implementation of an Adam optimizer update rule.
    class(adam), intent(inout) :: self
-    real, intent(inout) :: param(:)
-    real, intent(in) :: gradient(:)
+    real, intent(inout), pointer :: weights(:)
+    real, intent(inout), pointer :: biases(:)
+    real, intent(in), pointer :: gradient(:)
 
    self % t = self % t + 1
 
    ! If weight_decay_l2 > 0, use L2 regularization;
    ! otherwise, default to regular Adam.
-    associate(g => gradient + self % weight_decay_l2 * param)
+    associate(g => gradient + self % weight_decay_l2 * weights)
      self % m = self % beta1 * self % m + (1 - self % beta1) * g
      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
    end associate
@@ -202,9 +212,15 @@ pure subroutine minimize_adam(self, param, gradient)
    )
 
      ! Update parameters.
-      param = param &
+      weights = weights &
        - self % learning_rate * (m_hat / (sqrt(v_hat) + self % epsilon) &
-        + self % weight_decay_decoupled * param)
+        + self % weight_decay_decoupled * weights)
+
+      ! Update biases (without weight decay for biases)
+      associate(g => gradient)
+        biases = biases &
+          - self % learning_rate * (m_hat / (sqrt(v_hat) + self % epsilon))
+      end associate
 
    end associate
 
@@ -221,30 +237,42 @@ impure elemental subroutine init_adagrad(self, num_params)
   end subroutine init_adagrad
 
 
-  pure subroutine minimize_adagrad(self, param, gradient)
+  pure subroutine minimize_adagrad(self, weights, biases, gradient)
    !! Concrete implementation of an Adagrad optimizer update rule.
    class(adagrad), intent(inout) :: self
-    real, intent(inout) :: param(:)
-    real, intent(in) :: gradient(:)
+    real, intent(inout), pointer :: weights(:)
+    real, intent(inout), pointer :: biases(:)
+    real, intent(in), pointer :: gradient(:)
 
    ! Update the current time step
    self % t = self % t + 1
 
+    ! For weights
    associate( &
      ! If weight_decay_l2 > 0, use L2 regularization;
      ! otherwise, default to regular Adagrad.
-      g => gradient + self % weight_decay_l2 * param, &
+      g => gradient + self % weight_decay_l2 * weights, &
      ! Amortize the learning rate as function of the current time step.
      learning_rate => self % learning_rate &
        / (1 + (self % t - 1) * self % learning_rate_decay) &
    )
 
      self % sum_squared_gradient = self % sum_squared_gradient + g**2
 
-      param = param - learning_rate * g / (sqrt(self % sum_squared_gradient) &
+      weights = weights - learning_rate * g / (sqrt(self % sum_squared_gradient) &
        + self % epsilon)
 
    end associate
+
+    ! For biases (without weight decay)
+    associate( &
+      g => gradient, &
+      learning_rate => self % learning_rate &
+        / (1 + (self % t - 1) * self % learning_rate_decay) &
+    )
+      biases = biases - learning_rate * g / (sqrt(self % sum_squared_gradient) &
+        + self % epsilon)
+    end associate
 
  end subroutine minimize_adagrad
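
Finally, a small self-contained sketch of driving the new pointer-based minimize signature directly (array sizes, values, and the learning rate are made up for illustration). Two things worth noting: the actual arguments must be pointers, or at least targets, because the dummies are now declared as pointers; and this WIP signature applies the single gradient array to both the weights and the biases, so the sketch keeps all three arrays the same size to stay conformable.

! Illustrative driver, not part of this commit.
program sgd_pointer_step
  use nf_optimizers, only: sgd
  implicit none
  type(sgd) :: opt
  real, target :: w(3) = [0.1, 0.2, 0.3]
  real, target :: b(3) = [0.0, 0.0, 0.0]
  real, target :: g(3) = [0.01, 0.02, 0.03]
  real, pointer :: wp(:), bp(:), gp(:)

  opt = sgd(learning_rate=0.1)
  call opt % init(size(w))

  wp => w
  bp => b
  gp => g
  ! Plain SGD step (momentum defaults to 0): both weights and biases move
  ! against the same gradient under this WIP interface.
  call opt % minimize(wp, bp, gp)

  print *, 'updated weights:', w
  print *, 'updated biases: ', b
end program sgd_pointer_step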
