Skip to content

Commit 8fb2ef3

Browse files
milancurcicjvo203
andauthored
Get set network parameters (#111)
* Add an example to get/set network parameters * Bump version * Add get_num_params, get_params, and set_params implementations Co-authored-by: Christopher Zapart <[email protected]> * Make get_params() a function; Make set_params() a subroutine; Co-authored-by: Christopher Zapart <[email protected]> * Tidy up example * Make layer % get_num_parameters() elemental * Begin test suite for getting and setting network params * Simplify network % get_num_params() * Warn on stderr if the user attempts to set_params to a non-zero param layer * Skip no-op layer % set_params() calls * Test getting and setting parameters * Check that the size of parameters match those of the layer * Print number of parameters in layer % print_info() * Tidy up the example Co-authored-by: Christopher Zapart <[email protected]>
1 parent 1b0646a commit 8fb2ef3

14 files changed

+530
-1
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ examples, in increasing level of complexity:
207207
dense model from a Keras HDF5 file and running the inference.
208208
6. [cnn_from_keras](example/cnn_from_keras.f90): Creating a pre-trained
209209
convolutional model from a Keras HDF5 file and running the inference.
210+
7. [get_set_network_params](example/get_set_network_params.f90): Getting and
211+
setting parameters (weights and biases) of a network.
210212

211213
The examples also show you the extent of the public API that's meant to be
212214
used in applications, i.e. anything from the `nf` module.

example/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ foreach(execid
33
cnn_from_keras
44
dense_mnist
55
dense_from_keras
6+
get_set_network_params
67
simple
78
sine
89
)

example/get_set_network_params.f90

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
program get_set_network_params

  !! Train a small dense network to fit a sine function, then copy its
  !! parameters (weights and biases) into a second, freshly constructed
  !! network of the same shape and verify both produce identical output.

  use nf, only: dense, input, network

  implicit none

  type(network) :: net1, net2
  real :: x(1), y(1)
  real, parameter :: pi = 4 * atan(1.)
  integer, parameter :: num_iterations = 100000
  integer, parameter :: test_size = 30
  real :: xtest(test_size), ytest(test_size)
  real :: ypred1(test_size), ypred2(test_size)
  integer :: i, n

  print '("Getting and setting network parameters")'
  print '(60("="))'
  print *
  print '(a)', 'First, let''s instantiate small dense network net1'
  print '(a)', 'of shape (1,5,1) and fit it to a sine function:'
  print *

  net1 = network([ &
    input(1), &
    dense(5), &
    dense(1) &
  ])

  call net1 % print_info()

  ! Evenly spaced test inputs over one period; targets scaled to [0, 1]
  ! to match the network's output range.
  xtest = [((i - 1) * 2 * pi / test_size, i = 1, test_size)]
  ytest = (sin(xtest) + 1) / 2

  do n = 0, num_iterations

    ! Draw one random training sample per iteration.
    call random_number(x)
    x = x * 2 * pi
    y = (sin(x) + 1) / 2

    call net1 % forward(x)
    call net1 % backward(y)
    call net1 % update(1.)

    ! Report the mean squared error on the test set every 10000 iterations.
    if (mod(n, 10000) == 0) then
      ypred1 = [(net1 % predict([xtest(i)]), i = 1, test_size)]
      print '(a,i0,1x,f9.6)', 'Number of iterations, loss: ', &
        n, sum((ypred1 - ytest)**2) / size(ypred1)
    end if

  end do

  print *
  print '(a)', 'Now, let''s see how many network parameters there are'
  print '(a)', 'by printing the result of net1 % get_num_params():'
  print *
  print '("net1 % get_num_params() = ", i0)', net1 % get_num_params()
  print *
  print '(a)', 'We can see the values of the network parameters'
  print '(a)', 'by printing the result of net1 % get_params():'
  print *
  print '("net1 % get_params() = ", *(g0,1x))', net1 % get_params()
  print *
  print '(a)', 'Now, let''s create another network of the same shape and set'
  print '(a)', 'the parameters from the original network to it'
  print '(a)', 'by calling call net2 % set_params(net1 % get_params()):'

  net2 = network([ &
    input(1), &
    dense(5), &
    dense(1) &
  ])

  ! Set the parameters of net1 to net2
  call net2 % set_params(net1 % get_params())

  print *
  print '(a)', 'We can check that the second network now has the same'
  print '(a)', 'parameters as net1:'
  print *
  print '("net2 % get_params() = ", *(g0,1x))', net2 % get_params()

  ypred1 = [(net1 % predict([xtest(i)]), i = 1, test_size)]
  ypred2 = [(net2 % predict([xtest(i)]), i = 1, test_size)]

  print *
  print '(a)', 'We can also check that the two networks produce the same output:'
  print *
  print '("net1 output: ", *(g0,1x))', ypred1
  print '("net2 output: ", *(g0,1x))', ypred2

  print *
  ! Fixed: the original passed a LOGICAL output item to an 'a' edit
  ! descriptor, which is a type mismatch; print it with 'l1' instead.
  print '(a,1x,l1)', 'Original and cloned network outputs match:', &
    all(ypred1 == ypred2)

end program get_set_network_params

fpm.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name = "neural-fortran"
2-
version = "0.9.0"
2+
version = "0.10.0"
33
license = "MIT"
44
author = "Milan Curcic"
55
maintainer = "[email protected]"

src/nf/nf_conv2d_layer.f90

+27
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ module nf_conv2d_layer
3636
procedure :: init
3737
procedure :: forward
3838
procedure :: backward
39+
procedure :: get_num_params
40+
procedure :: get_params
41+
procedure :: set_params
3942
procedure :: set_activation
4043
procedure :: update
4144

@@ -82,6 +85,30 @@ pure module subroutine backward(self, input, gradient)
8285
!! Gradient (next layer)
8386
end subroutine backward
8487

88+
pure module function get_num_params(self) result(num_params)
89+
!! Get the number of parameters in the layer.
90+
class(conv2d_layer), intent(in) :: self
91+
!! A `conv2d_layer` instance
92+
integer :: num_params
93+
!! Number of parameters
94+
end function get_num_params
95+
96+
pure module function get_params(self) result(params)
97+
!! Get the parameters of the layer.
98+
class(conv2d_layer), intent(in) :: self
99+
!! A `conv2d_layer` instance
100+
real, allocatable :: params(:)
101+
!! Parameters to get
102+
end function get_params
103+
104+
module subroutine set_params(self, params)
105+
!! Set the parameters of the layer.
106+
class(conv2d_layer), intent(in out) :: self
107+
!! A `conv2d_layer` instance
108+
real, intent(in) :: params(:)
109+
!! Parameters to set
110+
end subroutine set_params
111+
85112
elemental module subroutine set_activation(self, activation)
86113
!! Set the activation functions.
87114
class(conv2d_layer), intent(in out) :: self

src/nf/nf_conv2d_layer_submodule.f90

+44
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,50 @@ pure module subroutine backward(self, input, gradient)
188188

189189
end subroutine backward
190190

191+
192+
pure module function get_num_params(self) result(num_params)
193+
class(conv2d_layer), intent(in) :: self
194+
integer :: num_params
195+
! Total parameter count: every kernel element plus every bias
num_params = product(shape(self % kernel)) + size(self % biases)
196+
end function get_num_params
197+
198+
199+
pure module function get_params(self) result(params)
200+
class(conv2d_layer), intent(in) :: self
201+
real, allocatable :: params(:)
202+
203+
! Flatten kernel and biases into a single rank-1 vector, kernel first,
! matching the layout expected by set_params below
params = [ &
204+
pack(self % kernel, .true.), &
205+
pack(self % biases, .true.) &
206+
]
207+
208+
end function get_params
209+
210+
211+
module subroutine set_params(self, params)
212+
class(conv2d_layer), intent(in out) :: self
213+
real, intent(in) :: params(:)
214+
215+
! Check that the number of parameters is correct.
216+
if (size(params) /= self % get_num_params()) then
217+
error stop 'conv2d % set_params: Number of parameters does not match'
218+
end if
219+
220+
! Reshape the kernel.
! The leading product(shape(kernel)) elements are the kernel weights.
221+
self % kernel = reshape( &
222+
params(:product(shape(self % kernel))), &
223+
shape(self % kernel) &
224+
)
225+
226+
! Reshape the biases.
! The remaining elements are the biases, one per filter.
227+
self % biases = reshape( &
228+
params(product(shape(self % kernel)) + 1:), &
229+
[self % filters] &
230+
)
231+
232+
end subroutine set_params
233+
234+
191235
elemental module subroutine set_activation(self, activation)
192236
class(conv2d_layer), intent(in out) :: self
193237
character(*), intent(in) :: activation

src/nf/nf_dense_layer.f90

+29
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ module nf_dense_layer
3636

3737
procedure :: backward
3838
procedure :: forward
39+
procedure :: get_num_params
40+
procedure :: get_params
41+
procedure :: set_params
3942
procedure :: init
4043
procedure :: set_activation
4144
procedure :: update
@@ -80,6 +83,32 @@ pure module subroutine forward(self, input)
8083
!! Input from the previous layer
8184
end subroutine forward
8285

86+
pure module function get_num_params(self) result(num_params)
87+
!! Return the number of parameters in this layer.
88+
class(dense_layer), intent(in) :: self
89+
!! Dense layer instance
90+
integer :: num_params
91+
!! Number of parameters in this layer
92+
end function get_num_params
93+
94+
pure module function get_params(self) result(params)
95+
!! Return the parameters of this layer.
96+
!! The parameters are ordered as weights first, biases second.
97+
class(dense_layer), intent(in) :: self
98+
!! Dense layer instance
99+
real, allocatable :: params(:)
100+
!! Parameters of this layer
101+
end function get_params
102+
103+
module subroutine set_params(self, params)
104+
!! Set the parameters of this layer.
105+
!! The parameters are ordered as weights first, biases second.
106+
class(dense_layer), intent(in out) :: self
107+
!! Dense layer instance
108+
real, intent(in) :: params(:)
109+
!! Parameters of this layer
110+
end subroutine set_params
111+
83112
module subroutine init(self, input_shape)
84113
!! Initialize the layer data structures.
85114
!!

src/nf/nf_dense_layer_submodule.f90

+46
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,52 @@ pure module subroutine forward(self, input)
5353
end subroutine forward
5454

5555

56+
pure module function get_num_params(self) result(num_params)
57+
class(dense_layer), intent(in) :: self
58+
integer :: num_params
59+
60+
! Number of weights (input_size * output_size) plus number of biases (output_size)
61+
num_params = self % input_size * self % output_size + self % output_size
62+
63+
end function get_num_params
64+
65+
66+
pure module function get_params(self) result(params)
67+
class(dense_layer), intent(in) :: self
68+
real, allocatable :: params(:)
69+
70+
! Flatten weights and biases into one rank-1 vector, weights first,
! matching the layout expected by set_params below
params = [ &
71+
pack(self % weights, .true.), &
72+
pack(self % biases, .true.) &
73+
]
74+
75+
end function get_params
76+
77+
78+
module subroutine set_params(self, params)
79+
class(dense_layer), intent(in out) :: self
80+
real, intent(in) :: params(:)
81+
82+
! check if the number of parameters is correct
83+
if (size(params) /= self % get_num_params()) then
84+
error stop 'Error: number of parameters does not match'
85+
end if
86+
87+
! reshape the weights
! (the leading input_size*output_size elements of params)
88+
self % weights = reshape( &
89+
params(:self % input_size * self % output_size), &
90+
[self % input_size, self % output_size] &
91+
)
92+
93+
! reshape the biases
! (the remaining output_size elements of params)
94+
self % biases = reshape( &
95+
params(self % input_size * self % output_size + 1:), &
96+
[self % output_size] &
97+
)
98+
99+
end subroutine set_params
100+
101+
56102
module subroutine init(self, input_shape)
57103
class(dense_layer), intent(in out) :: self
58104
integer, intent(in) :: input_shape(:)

src/nf/nf_layer.f90

+27
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ module nf_layer
2525
contains
2626

2727
procedure :: forward
28+
procedure :: get_num_params
29+
procedure :: get_params
30+
procedure :: set_params
2831
procedure :: init
2932
procedure :: print_info
3033
procedure :: update
@@ -117,6 +120,30 @@ impure elemental module subroutine print_info(self)
117120
!! Layer instance
118121
end subroutine print_info
119122

123+
elemental module function get_num_params(self) result(num_params)
124+
!! Returns the number of parameters in this layer.
125+
class(layer), intent(in) :: self
126+
!! Layer instance
127+
integer :: num_params
128+
!! Number of parameters in this layer
129+
end function get_num_params
130+
131+
pure module function get_params(self) result(params)
132+
!! Returns the parameters of this layer.
133+
class(layer), intent(in) :: self
134+
!! Layer instance
135+
real, allocatable :: params(:)
136+
!! Parameters of this layer
137+
end function get_params
138+
139+
module subroutine set_params(self, params)
140+
!! Sets the parameters of this layer.
141+
class(layer), intent(in out) :: self
142+
!! Layer instance
143+
real, intent(in) :: params(:)
144+
!! Parameters to set
145+
end subroutine set_params
146+
120147
impure elemental module subroutine update(self, learning_rate)
121148
!! Update the weights and biases on the layer using the stored
122149
!! gradients (from backward passes), and flush those same stored

0 commit comments

Comments
 (0)