Skip to content

Commit 31fc061

Browse files
authored
SGD optimizer stub (#139)
* Define the SGD minimization step in the optimizer type
* Add note about refactor needed
* Pass optimizer instance down to layer % update()
* Apply the optimizer update step in layer % update
* Change tests and examples to account for the API change in network % update()
* Make optimizer optional; default to SGD with learning rate of 1
* Apply optimizer to conv2d layer
1 parent 44833c2 commit 31fc061

14 files changed

+138
-81
lines changed

example/get_set_network_params.f90

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
program get_set_network_params
22
use nf, only: dense, input, network
3+
use nf_optimizers, only: sgd
34
implicit none
45
type(network) :: net1, net2
56
real :: x(1), y(1)
@@ -37,7 +38,7 @@ program get_set_network_params
3738

3839
call net1 % forward(x)
3940
call net1 % backward(y)
40-
call net1 % update(1.)
41+
call net1 % update(sgd(learning_rate=1.))
4142

4243
if (mod(n, 10000) == 0) then
4344
ypred1 = [(net1 % predict([xtest(i)]), i=1, test_size)]

example/quadratic.f90

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ program quadratic_fit
44
! descent.
55
use nf, only: dense, input, network
66
use nf_dense_layer, only: dense_layer
7+
use nf_optimizers, only: sgd
78

89
implicit none
910
type(network) :: net_sgd, net_batch_sgd, net_minibatch_sgd, net_rms_prop
@@ -97,7 +98,7 @@ subroutine sgd_optimizer(net, x, y, learning_rate, num_epochs)
9798
do i = 1, size(x)
9899
call net % forward([x(i)])
99100
call net % backward([y(i)])
100-
call net % update(learning_rate)
101+
call net % update(sgd(learning_rate=learning_rate))
101102
end do
102103
end do
103104

@@ -120,7 +121,7 @@ subroutine batch_gd_optimizer(net, x, y, learning_rate, num_epochs)
120121
call net % forward([x(i)])
121122
call net % backward([y(i)])
122123
end do
123-
call net % update(learning_rate / size(x))
124+
call net % update(sgd(learning_rate=learning_rate / size(x)))
124125
end do
125126

126127
end subroutine batch_gd_optimizer
@@ -164,7 +165,7 @@ subroutine minibatch_gd_optimizer(net, x, y, learning_rate, num_epochs, batch_si
164165
call net % backward([y(i)])
165166
end do
166167

167-
call net % update(learning_rate / batch_size)
168+
call net % update(sgd(learning_rate=learning_rate / batch_size))
168169
end do
169170
end do
170171
end subroutine minibatch_gd_optimizer

example/simple.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ program simple
2424

2525
call net % forward(x)
2626
call net % backward(y)
27-
call net % update(1.)
27+
call net % update()
2828

2929
if (mod(n, 50) == 0) &
3030
print '(i4,2(3x,f8.6))', n, net % predict(x)

example/sine.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ program sine
3131

3232
call net % forward(x)
3333
call net % backward(y)
34-
call net % update(1.)
34+
call net % update()
3535

3636
if (mod(n, 10000) == 0) then
3737
ypred = [(net % predict([xtest(i)]), i = 1, test_size)]

src/nf/nf_conv2d_layer.f90

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ module nf_conv2d_layer
3636
procedure :: get_num_params
3737
procedure :: get_params
3838
procedure :: set_params
39-
procedure :: update
4039

4140
end type conv2d_layer
4241

@@ -105,14 +104,6 @@ module subroutine set_params(self, params)
105104
!! Parameters to set
106105
end subroutine set_params
107106

108-
module subroutine update(self, learning_rate)
109-
!! Update the weights and biases.
110-
class(conv2d_layer), intent(in out) :: self
111-
!! Dense layer instance
112-
real, intent(in) :: learning_rate
113-
!! Learning rate (must be > 0)
114-
end subroutine update
115-
116107
end interface
117108

118109
end module nf_conv2d_layer

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -225,20 +225,4 @@ module subroutine set_params(self, params)
225225

226226
end subroutine set_params
227227

228-
229-
module subroutine update(self, learning_rate)
230-
class(conv2d_layer), intent(in out) :: self
231-
real, intent(in) :: learning_rate
232-
233-
! Sum weight and bias gradients across images, if any
234-
call co_sum(self % dw)
235-
call co_sum(self % db)
236-
237-
self % kernel = self % kernel - learning_rate * self % dw
238-
self % biases = self % biases - learning_rate * self % db
239-
self % dw = 0
240-
self % db = 0
241-
242-
end subroutine update
243-
244228
end submodule nf_conv2d_layer_submodule

src/nf/nf_dense_layer.f90

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ module nf_dense_layer
3737
procedure :: get_params
3838
procedure :: set_params
3939
procedure :: init
40-
procedure :: update
4140

4241
end type dense_layer
4342

@@ -115,14 +114,6 @@ module subroutine init(self, input_shape)
115114
!! Shape of the input layer
116115
end subroutine init
117116

118-
module subroutine update(self, learning_rate)
119-
!! Update the weights and biases.
120-
class(dense_layer), intent(in out) :: self
121-
!! Dense layer instance
122-
real, intent(in) :: learning_rate
123-
!! Learning rate (must be > 0)
124-
end subroutine update
125-
126117
end interface
127118

128119
end module nf_dense_layer

src/nf/nf_dense_layer_submodule.f90

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,4 @@ module subroutine init(self, input_shape)
128128

129129
end subroutine init
130130

131-
module subroutine update(self, learning_rate)
132-
class(dense_layer), intent(in out) :: self
133-
real, intent(in) :: learning_rate
134-
135-
! Sum weight and bias gradients across images, if any
136-
call co_sum(self % dw)
137-
call co_sum(self % db)
138-
139-
self % weights = self % weights - learning_rate * self % dw
140-
self % biases = self % biases - learning_rate * self % db
141-
self % dw = 0
142-
self % db = 0
143-
144-
end subroutine update
145-
146131
end submodule nf_dense_layer_submodule

src/nf/nf_layer.f90

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ module nf_layer
44
!! user-facing API.
55

66
use nf_base_layer, only: base_layer
7+
use nf_optimizers, only: optimizer_base_type
78

89
implicit none
910

@@ -144,16 +145,18 @@ module subroutine set_params(self, params)
144145
!! Parameters of this layer
145146
end subroutine set_params
146147

147-
impure elemental module subroutine update(self, learning_rate)
148+
impure elemental module subroutine update(self, optimizer, batch_size)
148149
!! Update the weights and biases on the layer using the stored
149150
!! gradients (from backward passes), and flush those same stored
150151
!! gradients to zero.
151152
!! This changes the state of the layer.
152153
!! Typically used only internally from the `network % update` method.
153154
class(layer), intent(in out) :: self
154155
!! Layer instance
155-
real, intent(in) :: learning_rate
156-
!! Learning rate to use; must be > 0.
156+
class(optimizer_base_type), intent(in) :: optimizer
157+
!! Optimizer instance to use
158+
integer, intent(in), optional :: batch_size
159+
!! Batch size (default 1)
157160
end subroutine update
158161

159162
end interface

src/nf/nf_layer_submodule.f90

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use nf_input3d_layer, only: input3d_layer
99
use nf_maxpool2d_layer, only: maxpool2d_layer
1010
use nf_reshape_layer, only: reshape3d_layer
11+
use nf_optimizers, only: optimizer_base_type
1112

1213
contains
1314

@@ -382,15 +383,54 @@ module subroutine set_params(self, params)
382383
end subroutine set_params
383384

384385

385-
impure elemental module subroutine update(self, learning_rate)
386+
impure elemental module subroutine update(self, optimizer, batch_size)
386387
class(layer), intent(in out) :: self
387-
real, intent(in) :: learning_rate
388+
class(optimizer_base_type), intent(in) :: optimizer
389+
integer, intent(in), optional :: batch_size
390+
integer :: batch_size_
391+
392+
batch_size_ = 1
393+
if (present(batch_size)) batch_size_ = batch_size
394+
395+
select type (this_layer => self % p)
396+
type is (dense_layer)
397+
398+
! Sum weight and bias gradients across images, if any
399+
call co_sum(this_layer % dw)
400+
call co_sum(this_layer % db)
401+
402+
call optimizer % minimize( &
403+
this_layer % weights, &
404+
this_layer % dw / batch_size_ &
405+
)
406+
call optimizer % minimize( &
407+
this_layer % biases, &
408+
this_layer % db / batch_size_ &
409+
)
410+
411+
! Reset gradients.
412+
this_layer % dw = 0
413+
this_layer % db = 0
414+
415+
type is (conv2d_layer)
416+
417+
! Sum weight and bias gradients across images, if any
418+
call co_sum(this_layer % dw)
419+
call co_sum(this_layer % db)
420+
421+
call optimizer % minimize( &
422+
this_layer % kernel, &
423+
this_layer % dw / batch_size_ &
424+
)
425+
call optimizer % minimize( &
426+
this_layer % biases, &
427+
this_layer % db / batch_size_ &
428+
)
429+
430+
! Reset gradients.
431+
this_layer % dw = 0
432+
this_layer % db = 0
388433

389-
select type(this_layer => self % p)
390-
type is(dense_layer)
391-
call this_layer % update(learning_rate)
392-
type is(conv2d_layer)
393-
call this_layer % update(learning_rate)
394434
end select
395435

396436
end subroutine update

0 commit comments

Comments (0)