1
1
program test_conv2d_network
2
2
3
3
use iso_fortran_env, only: stderr = > error_unit
4
- use nf, only: conv2d, input, network
4
+ use nf, only: conv2d, input, network, dense, sgd, maxpool2d
5
5
6
6
implicit none
7
7
@@ -21,6 +21,7 @@ program test_conv2d_network
21
21
ok = .false.
22
22
end if
23
23
24
+ ! Test for output shape
24
25
allocate (sample_input(3 , 32 , 32 ))
25
26
sample_input = 0
26
27
@@ -32,6 +33,115 @@ program test_conv2d_network
32
33
ok = .false.
33
34
end if
34
35
36
+ deallocate (sample_input, output)
37
+
38
+ training1: block
39
+
40
+ type (network) :: cnn
41
+ real :: y(1 )
42
+ real :: tolerance = 1e-5
43
+ integer :: n
44
+ integer , parameter :: num_iterations = 1000
45
+
46
+ ! Test training of a minimal constant mapping
47
+ allocate (sample_input(1 , 5 , 5 ))
48
+ call random_number (sample_input)
49
+
50
+ cnn = network([ &
51
+ input(shape (sample_input)), &
52
+ conv2d(filters= 1 , kernel_size= 3 ), &
53
+ conv2d(filters= 1 , kernel_size= 3 ), &
54
+ dense(1 ) &
55
+ ])
56
+
57
+ y = [0.1234567 ]
58
+
59
+ do n = 1 , num_iterations
60
+ call cnn % forward(sample_input)
61
+ call cnn % backward(y)
62
+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
63
+ if (all (abs (cnn % predict(sample_input) - y) < tolerance)) exit
64
+ end do
65
+
66
+ if (.not. n <= num_iterations) then
67
+ write (stderr, ' (a)' ) &
68
+ ' convolutional network 1 should converge in simple training.. failed'
69
+ ok = .false.
70
+ end if
71
+
72
+ end block training1
73
+
74
+ training2: block
75
+
76
+ type (network) :: cnn
77
+ real :: x(1 , 8 , 8 )
78
+ real :: y(1 )
79
+ real :: tolerance = 1e-5
80
+ integer :: n
81
+ integer , parameter :: num_iterations = 1000
82
+
83
+ call random_number (x)
84
+ y = [0.1234567 ]
85
+
86
+ cnn = network([ &
87
+ input(shape (x)), &
88
+ conv2d(filters= 1 , kernel_size= 3 ), &
89
+ maxpool2d(pool_size= 2 ), &
90
+ conv2d(filters= 1 , kernel_size= 3 ), &
91
+ dense(1 ) &
92
+ ])
93
+
94
+ do n = 1 , num_iterations
95
+ call cnn % forward(x)
96
+ call cnn % backward(y)
97
+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
98
+ if (all (abs (cnn % predict(x) - y) < tolerance)) exit
99
+ end do
100
+
101
+ if (.not. n <= num_iterations) then
102
+ write (stderr, ' (a)' ) &
103
+ ' convolutional network 2 should converge in simple training.. failed'
104
+ ok = .false.
105
+ end if
106
+
107
+ end block training2
108
+
109
+ training3: block
110
+
111
+ type (network) :: cnn
112
+ real :: x(1 , 12 , 12 )
113
+ real :: y(9 )
114
+ real :: tolerance = 1e-5
115
+ integer :: n
116
+ integer , parameter :: num_iterations = 5000
117
+
118
+ call random_number (x)
119
+ y = [0.12345 , 0.23456 , 0.34567 , 0.45678 , 0.56789 , 0.67890 , 0.78901 , 0.89012 , 0.90123 ]
120
+
121
+ cnn = network([ &
122
+ input(shape (x)), &
123
+ conv2d(filters= 1 , kernel_size= 3 ), & ! 1x12x12 input, 1x10x10 output
124
+ maxpool2d(pool_size= 2 ), & ! 1x10x10 input, 1x5x5 output
125
+ conv2d(filters= 1 , kernel_size= 3 ), & ! 1x5x5 input, 1x3x3 output
126
+ dense(9 ) & ! 9 outputs
127
+ ])
128
+
129
+ do n = 1 , num_iterations
130
+ call cnn % forward(x)
131
+ call cnn % backward(y)
132
+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
133
+ if (all (abs (cnn % predict(x) - y) < tolerance)) exit
134
+ end do
135
+
136
+ if (.not. n <= num_iterations) then
137
+ write (stderr, ' (a)' ) &
138
+ ' convolutional network 3 should converge in simple training.. failed'
139
+ ok = .false.
140
+ end if
141
+
142
+ end block training3
143
+
144
+
35
145
if (ok) then
36
146
print ' (a)' , ' test_conv2d_network: All tests passed.'
37
147
else
0 commit comments