Skip to content

Commit 4ad75bc

Browse files
Generic flatten (2d and 3d) (#202)
* Generic flatten() with 2-d and 3-d inputs
* Explicitly enable preprocessing for fpm builds
* Update README
* generic-flatten: use assumed-rank instead of generics
---------
Co-authored-by: Mikhail Voronov <[email protected]>
1 parent a28a9be commit 4ad75bc

File tree

7 files changed

+91
-19
lines changed

7 files changed

+91
-19
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
3333
| Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 |||
3434
| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 || ✅(*) |
3535
| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 |||
36-
| Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 |||
36+
| Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 |||
3737
| Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 |||
3838

3939
(*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset.

fpm.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ license = "MIT"
44
author = "Milan Curcic"
55
maintainer = "[email protected]"
66
copyright = "Copyright 2018-2025, neural-fortran contributors"
7+
8+
[preprocess]
9+
[preprocess.cpp]

src/nf/nf_flatten_layer.f90

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ module nf_flatten_layer
1818
integer, allocatable :: input_shape(:)
1919
integer :: output_size
2020

21-
real, allocatable :: gradient(:,:,:)
21+
real, allocatable :: gradient_2d(:,:)
22+
real, allocatable :: gradient_3d(:,:,:)
2223
real, allocatable :: output(:)
2324

2425
contains
@@ -40,23 +41,23 @@ end function flatten_layer_cons
4041
interface
4142

4243
pure module subroutine backward(self, input, gradient)
43-
!! Apply the backward pass to the flatten layer.
44-
!! This is a reshape operation from 1-d gradient to 3-d input.
44+
!! Apply the backward pass to the flatten layer for 2D and 3D input.
45+
!! This is a reshape operation from 1-d gradient to 2-d and 3-d input.
4546
class(flatten_layer), intent(in out) :: self
4647
!! Flatten layer instance
47-
real, intent(in) :: input(:,:,:)
48+
real, intent(in) :: input(..)
4849
!! Input from the previous layer
4950
real, intent(in) :: gradient(:)
5051
!! Gradient from the next layer
5152
end subroutine backward
5253

5354
pure module subroutine forward(self, input)
54-
!! Propagate forward the layer.
55+
!! Propagate forward the layer for 2D or 3D input.
5556
!! Calling this subroutine updates the values of a few data components
5657
!! of `flatten_layer` that are needed for the backward pass.
5758
class(flatten_layer), intent(in out) :: self
5859
!! Dense layer instance
59-
real, intent(in) :: input(:,:,:)
60+
real, intent(in) :: input(..)
6061
!! Input from the previous layer
6162
end subroutine forward
6263

src/nf/nf_flatten_layer_submodule.f90

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,30 @@ end function flatten_layer_cons
1717

1818
pure module subroutine backward(self, input, gradient)
1919
class(flatten_layer), intent(in out) :: self
20-
real, intent(in) :: input(:,:,:)
20+
real, intent(in) :: input(..)
2121
real, intent(in) :: gradient(:)
22-
self % gradient = reshape(gradient, shape(input))
22+
select rank(input)
23+
rank(2)
24+
self % gradient_2d = reshape(gradient, shape(input))
25+
rank(3)
26+
self % gradient_3d = reshape(gradient, shape(input))
27+
rank default
28+
error stop "Unsupported rank of input"
29+
end select
2330
end subroutine backward
2431

2532

2633
pure module subroutine forward(self, input)
2734
class(flatten_layer), intent(in out) :: self
28-
real, intent(in) :: input(:,:,:)
29-
self % output = pack(input, .true.)
35+
real, intent(in) :: input(..)
36+
select rank(input)
37+
rank(2)
38+
self % output = pack(input, .true.)
39+
rank(3)
40+
self % output = pack(input, .true.)
41+
rank default
42+
error stop "Unsupported rank of input"
43+
end select
3044
end subroutine forward
3145

3246

@@ -37,8 +51,13 @@ module subroutine init(self, input_shape)
3751
self % input_shape = input_shape
3852
self % output_size = product(input_shape)
3953

40-
allocate(self % gradient(input_shape(1), input_shape(2), input_shape(3)))
41-
self % gradient = 0
54+
if (size(input_shape) == 2) then
55+
allocate(self % gradient_2d(input_shape(1), input_shape(2)))
56+
self % gradient_2d = 0
57+
else if (size(input_shape) == 3) then
58+
allocate(self % gradient_3d(input_shape(1), input_shape(2), input_shape(3)))
59+
self % gradient_3d = 0
60+
end if
4261

4362
allocate(self % output(self % output_size))
4463
self % output = 0

src/nf/nf_layer_submodule.f90

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ pure module subroutine backward_1d(self, previous, gradient)
3737

3838
type is(flatten_layer)
3939

40-
! Upstream layers permitted: input3d, conv2d, maxpool2d
40+
! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d
4141
select type(prev_layer => previous % p)
42+
type is(input2d_layer)
43+
call this_layer % backward(prev_layer % output, gradient)
4244
type is(input3d_layer)
4345
call this_layer % backward(prev_layer % output, gradient)
4446
type is(conv2d_layer)
@@ -168,8 +170,10 @@ pure module subroutine forward(self, input)
168170

169171
type is(flatten_layer)
170172

171-
! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
173+
! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d, reshape3d
172174
select type(prev_layer => input % p)
175+
type is(input2d_layer)
176+
call this_layer % forward(prev_layer % output)
173177
type is(input3d_layer)
174178
call this_layer % forward(prev_layer % output)
175179
type is(conv2d_layer)

src/nf/nf_network_submodule.f90

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,20 @@ module subroutine backward(self, output, loss)
135135
select type(next_layer => self % layers(n + 1) % p)
136136
type is(dense_layer)
137137
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
138+
138139
type is(conv2d_layer)
139140
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
141+
140142
type is(flatten_layer)
141-
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
143+
if (size(self % layers(n) % layer_shape) == 2) then
144+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
145+
else
146+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
147+
end if
148+
142149
type is(maxpool2d_layer)
143150
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
151+
144152
type is(reshape3d_layer)
145153
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
146154
end select

test/test_flatten_layer.f90

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@ program test_flatten_layer
33
use iso_fortran_env, only: stderr => error_unit
44
use nf, only: dense, flatten, input, layer, network
55
use nf_flatten_layer, only: flatten_layer
6+
use nf_input2d_layer, only: input2d_layer
67
use nf_input3d_layer, only: input3d_layer
78

89
implicit none
910

1011
type(layer) :: test_layer, input_layer
1112
type(network) :: net
12-
real, allocatable :: gradient(:,:,:)
13+
real, allocatable :: gradient_3d(:,:,:), gradient_2d(:,:)
1314
real, allocatable :: output(:)
1415
logical :: ok = .true.
1516

17+
! Test 3D input
1618
test_layer = flatten()
1719

1820
if (.not. test_layer % name == 'flatten') then
@@ -59,14 +61,49 @@ program test_flatten_layer
5961
call test_layer % backward(input_layer, real([1, 2, 3, 4]))
6062

6163
select type(this_layer => test_layer % p); type is(flatten_layer)
62-
gradient = this_layer % gradient
64+
gradient_3d = this_layer % gradient_3d
6365
end select
6466

65-
if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
67+
if (.not. all(gradient_3d == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
6668
ok = .false.
6769
write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed'
6870
end if
6971

72+
! Test 2D input
73+
test_layer = flatten()
74+
input_layer = input(2, 3)
75+
call test_layer % init(input_layer)
76+
77+
if (.not. all(test_layer % layer_shape == [6])) then
78+
ok = .false.
79+
write(stderr, '(a)') 'flatten layer has an incorrect output shape for 2D input.. failed'
80+
end if
81+
82+
! Test forward pass - reshaping from 2-d to 1-d
83+
select type(this_layer => input_layer % p); type is(input2d_layer)
84+
call this_layer % set(reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))
85+
end select
86+
87+
call test_layer % forward(input_layer)
88+
call test_layer % get_output(output)
89+
90+
if (.not. all(output == [1, 2, 3, 4, 5, 6])) then
91+
ok = .false.
92+
write(stderr, '(a)') 'flatten layer correctly propagates forward for 2D input.. failed'
93+
end if
94+
95+
! Test backward pass - reshaping from 1-d to 2-d
96+
call test_layer % backward(input_layer, real([1, 2, 3, 4, 5, 6]))
97+
98+
select type(this_layer => test_layer % p); type is(flatten_layer)
99+
gradient_2d = this_layer % gradient_2d
100+
end select
101+
102+
if (.not. all(gradient_2d == reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))) then
103+
ok = .false.
104+
write(stderr, '(a)') 'flatten layer correctly propagates backward for 2D input.. failed'
105+
end if
106+
70107
net = network([ &
71108
input(1, 28, 28), &
72109
flatten(), &

0 commit comments

Comments (0)