Skip to content

Commit ca219cd

Browse files
authored
More conv2d tests (#174)
* Test training a simple conv2d network
* Default activation ReLU for conv2d
* Adjust test for new default activation
* flatten() layer unnecessary; inserted automatically
* Add a connection to the reshape layer in network % backward(); possible culprit of CNN bug
* Test CNN with maxpool2d in the loop
* Intermediate complexity CNN training converges
* Update note in README regarding CNN training
1 parent b1b2cac commit ca219cd

5 files changed

+122
-11
lines changed

README.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,12 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
2929
|------------|------------------|------------------------|----------------------|--------------|---------------|
3030
| Input | `input` | n/a | 1, 3 | n/a | n/a |
3131
| Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 | ✅ | ✅ |
32-
| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ❌ |
32+
| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) |
3333
| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ |
3434
| Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ |
3535
| Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ |
3636

37-
**Note:** The training of convolutional layers has been discovered to be broken
38-
as of release 0.13.0. This will be fixed in a future (hopefully next) release.
37+
(*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset.
3938

4039
## Getting started
4140

src/nf/nf_layer_constructors_submodule.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use nf_input3d_layer, only: input3d_layer
99
use nf_maxpool2d_layer, only: maxpool2d_layer
1010
use nf_reshape_layer, only: reshape3d_layer
11-
use nf_activation, only: activation_function, sigmoid
11+
use nf_activation, only: activation_function, relu, sigmoid
1212

1313
implicit none
1414

@@ -27,7 +27,7 @@ pure module function conv2d(filters, kernel_size, activation) result(res)
2727
if (present(activation)) then
2828
allocate(activation_tmp, source=activation)
2929
else
30-
allocate(activation_tmp, source=sigmoid())
30+
allocate(activation_tmp, source=relu())
3131
end if
3232

3333
res % activation = activation_tmp % get_name()

src/nf/nf_network_submodule.f90

+4-2
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,14 @@ pure module subroutine backward(self, output)
305305
select type(next_layer => self % layers(n + 1) % p)
306306
type is(dense_layer)
307307
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
308-
type is(flatten_layer)
309-
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
310308
type is(conv2d_layer)
311309
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
310+
type is(flatten_layer)
311+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
312312
type is(maxpool2d_layer)
313313
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
314+
type is(reshape3d_layer)
315+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
314316
end select
315317
end if
316318

test/test_conv2d_layer.f90

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ program test_conv2d_layer
2424
write(stderr, '(a)') 'conv2d layer should not be marked as initialized yet.. failed'
2525
end if
2626

27-
if (.not. conv_layer % activation == 'sigmoid') then
27+
if (.not. conv_layer % activation == 'relu') then
2828
ok = .false.
29-
write(stderr, '(a)') 'conv2d layer is defaults to sigmoid activation.. failed'
29+
write(stderr, '(a)') 'conv2d layer defaults to relu activation.. failed'
3030
end if
3131

3232
input_layer = input([3, 32, 32])
@@ -62,7 +62,7 @@ program test_conv2d_layer
6262
call conv_layer % forward(input_layer)
6363
call conv_layer % get_output(output)
6464

65-
if (.not. all(abs(output - 0.5) < tolerance)) then
65+
if (.not. all(abs(output) < tolerance)) then
6666
ok = .false.
6767
write(stderr, '(a)') 'conv2d layer with zero input and sigmoid function must forward to all 0.5.. failed'
6868
end if

test/test_conv2d_network.f90

+111-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program test_conv2d_network
22

33
use iso_fortran_env, only: stderr => error_unit
4-
use nf, only: conv2d, input, network
4+
use nf, only: conv2d, input, network, dense, sgd, maxpool2d
55

66
implicit none
77

@@ -21,6 +21,7 @@ program test_conv2d_network
2121
ok = .false.
2222
end if
2323

24+
! Test for output shape
2425
allocate(sample_input(3, 32, 32))
2526
sample_input = 0
2627

@@ -32,6 +33,115 @@ program test_conv2d_network
3233
ok = .false.
3334
end if
3435

36+
deallocate(sample_input, output)
37+
38+
training1: block
39+
40+
type(network) :: cnn
41+
real :: y(1)
42+
real :: tolerance = 1e-5
43+
integer :: n
44+
integer, parameter :: num_iterations = 1000
45+
46+
! Test training of a minimal constant mapping
47+
allocate(sample_input(1, 5, 5))
48+
call random_number(sample_input)
49+
50+
cnn = network([ &
51+
input(shape(sample_input)), &
52+
conv2d(filters=1, kernel_size=3), &
53+
conv2d(filters=1, kernel_size=3), &
54+
dense(1) &
55+
])
56+
57+
y = [0.1234567]
58+
59+
do n = 1, num_iterations
60+
call cnn % forward(sample_input)
61+
call cnn % backward(y)
62+
call cnn % update(optimizer=sgd(learning_rate=1.))
63+
if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit
64+
end do
65+
66+
if (.not. n <= num_iterations) then
67+
write(stderr, '(a)') &
68+
'convolutional network 1 should converge in simple training.. failed'
69+
ok = .false.
70+
end if
71+
72+
end block training1
73+
74+
training2: block
75+
76+
type(network) :: cnn
77+
real :: x(1, 8, 8)
78+
real :: y(1)
79+
real :: tolerance = 1e-5
80+
integer :: n
81+
integer, parameter :: num_iterations = 1000
82+
83+
call random_number(x)
84+
y = [0.1234567]
85+
86+
cnn = network([ &
87+
input(shape(x)), &
88+
conv2d(filters=1, kernel_size=3), &
89+
maxpool2d(pool_size=2), &
90+
conv2d(filters=1, kernel_size=3), &
91+
dense(1) &
92+
])
93+
94+
do n = 1, num_iterations
95+
call cnn % forward(x)
96+
call cnn % backward(y)
97+
call cnn % update(optimizer=sgd(learning_rate=1.))
98+
if (all(abs(cnn % predict(x) - y) < tolerance)) exit
99+
end do
100+
101+
if (.not. n <= num_iterations) then
102+
write(stderr, '(a)') &
103+
'convolutional network 2 should converge in simple training.. failed'
104+
ok = .false.
105+
end if
106+
107+
end block training2
108+
109+
training3: block
110+
111+
type(network) :: cnn
112+
real :: x(1, 12, 12)
113+
real :: y(9)
114+
real :: tolerance = 1e-5
115+
integer :: n
116+
integer, parameter :: num_iterations = 5000
117+
118+
call random_number(x)
119+
y = [0.12345, 0.23456, 0.34567, 0.45678, 0.56789, 0.67890, 0.78901, 0.89012, 0.90123]
120+
121+
cnn = network([ &
122+
input(shape(x)), &
123+
conv2d(filters=1, kernel_size=3), & ! 1x12x12 input, 1x10x10 output
124+
maxpool2d(pool_size=2), & ! 1x10x10 input, 1x5x5 output
125+
conv2d(filters=1, kernel_size=3), & ! 1x5x5 input, 1x3x3 output
126+
dense(9) & ! 9 outputs
127+
])
128+
129+
do n = 1, num_iterations
130+
call cnn % forward(x)
131+
call cnn % backward(y)
132+
call cnn % update(optimizer=sgd(learning_rate=1.))
133+
if (all(abs(cnn % predict(x) - y) < tolerance)) exit
134+
end do
135+
136+
if (.not. n <= num_iterations) then
137+
write(stderr, '(a)') &
138+
'convolutional network 3 should converge in simple training.. failed'
139+
ok = .false.
140+
end if
141+
142+
end block training3
143+
144+
35145
if (ok) then
36146
print '(a)', 'test_conv2d_network: All tests passed.'
37147
else

0 commit comments

Comments
 (0)