
Commit e51a746

Insert a flatten layer if a dense layer follows a layer with 3-d output (#118)
* Insert flatten if dense layer follows a layer with 3-d output
* Bump version
* Update copyright year
1 parent edd3f70 commit e51a746
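With this change, a flatten layer no longer needs to be written explicitly between a layer with 3-d output (input3d, conv2d, maxpool2d, or reshape3d) and a dense layer; the network constructor inserts one automatically. A minimal sketch of the new behavior, based on the updated example and the new test in this commit (a demo program, not part of the diff):

! Sketch of the behavior added by this commit: the explicit flatten() between
! a 3-d layer and a dense layer is no longer required; network() inserts it.
program flatten_insertion_demo
  use nf, only: network, input, conv2d, maxpool2d, dense
  implicit none

  type(network) :: net

  ! No flatten() here: a flatten layer is inserted automatically before dense.
  net = network([ &
    input([3, 32, 32]), &
    conv2d(filters=16, kernel_size=3, activation='relu'), &
    maxpool2d(pool_size=2), &
    dense(10, activation='softmax') &
  ])

  ! The inserted layer shows up in the layer list by name,
  ! e.g. net % layers(4) % name == 'flatten' (cf. test_insert_flatten.f90).
end program flatten_insertion_demo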

6 files changed, +96 −5 lines changed

LICENSE  +1 −1

@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) 2018-2022 neural-fortran contributors
+Copyright (c) 2018-2023 neural-fortran contributors

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

example/cnn_mnist.f90  −1

@@ -28,7 +28,6 @@ program cnn_mnist
     maxpool2d(pool_size=2), &
     conv2d(filters=16, kernel_size=3, activation='relu'), &
     maxpool2d(pool_size=2), &
-    flatten(), &
     dense(10, activation='softmax') &
   ])


fpm.toml  +2 −2

@@ -1,9 +1,9 @@
 name = "neural-fortran"
-version = "0.10.0"
+version = "0.11.0"
 license = "MIT"
 author = "Milan Curcic"
 maintainer = "[email protected]"
-copyright = "Copyright 2018-2022, neural-fortran contributors"
+copyright = "Copyright 2018-2023, neural-fortran contributors"

 [build]
 external-modules = "hdf5"

src/nf/nf_network_submodule.f90  +28 −1

@@ -45,10 +45,37 @@ module function network_from_layers(layers) result(res)

     res % layers = layers

+    ! If connecting a 3-d output layer to a 1-d input layer without a flatten
+    ! layer in between, insert a flatten layer.
+    n = 2
+    do while (n <= size(res % layers))
+      select type(this_layer => res % layers(n) % p)
+        type is(dense_layer)
+          select type(prev_layer => res % layers(n-1) % p)
+            type is(input3d_layer)
+              res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
+              n = n + 1
+            type is(conv2d_layer)
+              res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
+              n = n + 1
+            type is(maxpool2d_layer)
+              res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
+              n = n + 1
+            type is(reshape3d_layer)
+              res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
+              n = n + 1
+            class default
+              n = n + 1
+          end select
+        class default
+          n = n + 1
+      end select
+    end do
+
     ! Loop over each layer in order and call the init methods.
     ! This will allocate the data internal to each layer (e.g. weights, biases)
     ! according to the size of the previous layer.
-    do n = 2, size(layers)
+    do n = 2, size(res % layers)
       call res % layers(n) % init(res % layers(n - 1))
     end do

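The insertion above relies on Fortran's reallocation on assignment for allocatable arrays: the constructor [res % layers(:n-1), flatten(), res % layers(n:)] builds a new array with a flatten layer spliced in just before position n. A minimal, self-contained toy sketch of the same idiom (a hypothetical demo program, not part of this commit):

! Hypothetical demo, not part of the commit: illustrates the array-constructor
! insertion idiom used in network_from_layers above.
program insert_idiom_demo
  implicit none
  integer, allocatable :: a(:)
  integer :: n

  a = [1, 2, 4, 5]
  n = 3                      ! insert before the element currently at index 3
  a = [a(:n-1), 3, a(n:)]    ! same pattern as [layers(:n-1), flatten(), layers(n:)]

  print '(5(i0,1x))', a      ! prints: 1 2 3 4 5
end program insert_idiom_demo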
test/CMakeLists.txt  +1

@@ -5,6 +5,7 @@ foreach(execid
   conv2d_layer
   maxpool2d_layer
   flatten_layer
+  insert_flatten
   reshape_layer
   dense_network
   get_set_network_params

test/test_insert_flatten.f90  +64

@@ -0,0 +1,64 @@
+program test_insert_flatten
+
+  use iso_fortran_env, only: stderr => error_unit
+  use nf, only: network, input, conv2d, maxpool2d, flatten, dense, reshape
+
+  implicit none
+
+  type(network) :: net
+  logical :: ok = .true.
+
+  net = network([ &
+    input([3, 32, 32]), &
+    dense(10) &
+  ])
+
+  if (.not. net % layers(2) % name == 'flatten') then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer inserted after input3d.. failed'
+  end if
+
+  net = network([ &
+    input([3, 32, 32]), &
+    conv2d(filters=1, kernel_size=3), &
+    dense(10) &
+  ])
+
+  !call net % print_info()
+
+  if (.not. net % layers(3) % name == 'flatten') then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer inserted after conv2d.. failed'
+  end if
+
+  net = network([ &
+    input([3, 32, 32]), &
+    conv2d(filters=1, kernel_size=3), &
+    maxpool2d(pool_size=2, stride=2), &
+    dense(10) &
+  ])
+
+  if (.not. net % layers(4) % name == 'flatten') then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer inserted after maxpool2d.. failed'
+  end if
+
+  net = network([ &
+    input(4), &
+    reshape([1, 2, 2]), &
+    dense(4) &
+  ])
+
+  if (.not. net % layers(3) % name == 'flatten') then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer inserted after reshape.. failed'
+  end if
+
+  if (ok) then
+    print '(a)', 'test_insert_flatten: All tests passed.'
+  else
+    write(stderr, '(a)') 'test_insert_flatten: One or more tests failed.'
+    stop 1
+  end if
+
+end program test_insert_flatten
