Skip to content

Commit fa73fb7

Browse files
committed
Test that we can chain input3d with dense using flatten
1 parent 8006f9b commit fa73fb7

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

test/test_flatten_layer.f90

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
program test_flatten_layer
22

33
use iso_fortran_env, only: stderr => error_unit
4-
use nf, only: flatten, input, layer
4+
use nf, only: dense, flatten, input, layer, network
55
use nf_flatten_layer, only: flatten_layer
66
use nf_input3d_layer, only: input3d_layer
77

88
implicit none
99

1010
type(layer) :: test_layer, input_layer
11+
type(network) :: net
1112
real, allocatable :: input_data(:,:,:), gradient(:,:,:)
1213
real, allocatable :: output(:)
1314
logical :: ok = .true.
@@ -66,6 +67,18 @@ program test_flatten_layer
6667
write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed'
6768
end if
6869

70+
net = network([ &
71+
input([1, 28, 28]), &
72+
flatten(), &
73+
dense(10) &
74+
])
75+
76+
! Test that the output layer receives 784 elements in the input
77+
if (.not. all(net % layers(3) % input_layer_shape == [784])) then
78+
ok = .false.
79+
write(stderr, '(a)') 'flatten layer correctly chains input3d to dense.. failed'
80+
end if
81+
6982
if (ok) then
7083
print '(a)', 'test_flatten_layer: All tests passed.'
7184
else

0 commit comments

Comments
 (0)