
Commit ad592da: Locally connected still not working
1 parent 9ae50ee

7 files changed: +238 -108 lines

example/cnn_mnist_1d.f90 (+3 -3)

@@ -20,7 +20,7 @@ program cnn_mnist
 
   net = network([ &
     input(784), &
-    reshape_generalized([1,784]), &
+    reshape_generalized([28,28]), &
     locally_connected_1d(filters=8, kernel_size=3, activation=relu()), &
     maxpool1d(pool_size=2), &
     locally_connected_1d(filters=16, kernel_size=3, activation=relu()), &
@@ -36,8 +36,8 @@ program cnn_mnist
     training_images, &
     label_digits(training_labels), &
     batch_size=16, &
-    epochs=1, &
-    optimizer=sgd(learning_rate=0.1) &
+    epochs=5, &
+    optimizer=sgd(learning_rate=0.003) &
   )
 
   print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &
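
With the reshape target changed from [1,784] to [28,28], each MNIST image now enters the 1-d stack as a 28-channel sequence of width 28 rather than a single 784-wide row. A rough shape walk-through of the layers visible in this hunk, assuming 'valid' output positions (output width = input width - kernel_size + 1) and the usual width-halving maxpool; these widths are inferred, not printed by the program:

  input(784)                   -> 784 values
  reshape_generalized([28,28]) -> 28 channels x 28 wide
  locally_connected_1d(8, 3)   -> 8 filters x 26 wide
  maxpool1d(pool_size=2)       -> 8 filters x 13 wide
  locally_connected_1d(16, 3)  -> 16 filters x 11 wide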

src/nf/nf_layer_submodule.f90 (+52 -1)

@@ -8,6 +8,7 @@
   use nf_input2d_layer, only: input2d_layer
   use nf_input3d_layer, only: input3d_layer
   use nf_locally_connected_1d_layer, only: locally_connected_1d_layer
+  use nf_maxpool1d_layer, only: maxpool1d_layer
   use nf_maxpool2d_layer, only: maxpool2d_layer
   use nf_reshape_layer, only: reshape3d_layer
   use nf_reshape_layer_generalized, only: reshape_generalized_layer
@@ -61,7 +62,29 @@ pure module subroutine backward_2d(self, previous, gradient)
 
     ! Backward pass from a 2-d layer downstream currently implemented
     ! only for dense and flatten layers
-    ! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199
+
+    select type(this_layer => self % p)
+
+      type is(locally_connected_1d_layer)
+
+        select type(prev_layer => previous % p)
+          type is(maxpool1d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+          type is(locally_connected_1d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+        end select
+
+      type is(maxpool1d_layer)
+
+        select type(prev_layer => previous % p)
+          type is(maxpool1d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+          type is(locally_connected_1d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+        end select
+
+    end select
+
   end subroutine backward_2d
 
 
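The nested dispatch above is the pattern this submodule uses throughout: both the current layer and its upstream neighbor are stored as unlimited polymorphic components, so each must be resolved to a concrete type before the type-bound backward can be called. A minimal, self-contained sketch of the idiom; pool_t and local_t here are hypothetical stand-ins, not types from this repo:

  program select_type_demo
    implicit none

    type :: pool_t
    end type pool_t

    type :: local_t
    end type local_t

    class(*), allocatable :: current, previous

    allocate(local_t :: current)
    allocate(pool_t :: previous)

    ! The outer select resolves the current layer; the inner one resolves
    ! the upstream layer, mirroring backward_2d above.
    select type (current)
      type is (local_t)
        select type (previous)
          type is (pool_t)
            print *, 'locally connected layer, maxpool upstream'
          type is (local_t)
            print *, 'locally connected layer, locally connected upstream'
        end select
      type is (pool_t)
        print *, 'maxpool branch'
    end select
  end program select_type_demo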
@@ -152,6 +175,15 @@ pure module subroutine forward(self, input)
           type is(reshape3d_layer)
             call this_layer % forward(prev_layer % output)
         end select
+
+      type is(maxpool1d_layer)
+
+        select type(prev_layer => input % p)
+          type is(locally_connected_1d_layer)
+            call this_layer % forward(prev_layer % output)
+          type is(maxpool1d_layer)
+            call this_layer % forward(prev_layer % output)
+        end select
 
       type is(maxpool2d_layer)
 
@@ -211,6 +243,8 @@ pure module subroutine get_output_1d(self, output)
         allocate(output, source=this_layer % output)
       type is(flatten_layer)
         allocate(output, source=this_layer % output)
+      type is(reshape_generalized_layer)
+        allocate(output, source=this_layer % output)
       class default
         error stop '1-d output can only be read from an input1d, dense, or flatten layer.'
 
@@ -227,8 +261,12 @@ pure module subroutine get_output_2d(self, output)
 
       type is(input2d_layer)
         allocate(output, source=this_layer % output)
+      type is(maxpool1d_layer)
+        allocate(output, source=this_layer % output)
       type is(locally_connected_1d_layer)
         allocate(output, source=this_layer % output)
+      !type is(reshape_generalized_layer)
+      !  allocate(output, source=this_layer % output)
       class default
         error stop '1-d output can only be read from an input1d, dense, or flatten layer.'
 
@@ -279,6 +317,8 @@ impure elemental module subroutine init(self, input)
         self % layer_shape = shape(this_layer % output)
       type is(locally_connected_1d_layer)
         self % layer_shape = shape(this_layer % output)
+      type is(maxpool1d_layer)
+        self % layer_shape = shape(this_layer % output)
       type is(maxpool2d_layer)
         self % layer_shape = shape(this_layer % output)
       type is(flatten_layer)
@@ -324,6 +364,8 @@ elemental module function get_num_params(self) result(num_params)
         num_params = this_layer % get_num_params()
       type is (locally_connected_1d_layer)
         num_params = this_layer % get_num_params()
+      type is(maxpool1d_layer)
+        num_params = 0
       type is (maxpool2d_layer)
         num_params = 0
       type is (flatten_layer)
@@ -355,6 +397,8 @@ module function get_params(self) result(params)
         params = this_layer % get_params()
       type is (locally_connected_1d_layer)
         params = this_layer % get_params()
+      type is (maxpool1d_layer)
+        ! No parameters to get.
       type is (maxpool2d_layer)
         ! No parameters to get.
       type is (flatten_layer)
@@ -386,6 +430,8 @@ module function get_gradients(self) result(gradients)
         gradients = this_layer % get_gradients()
       type is (locally_connected_1d_layer)
         gradients = this_layer % get_gradients()
+      type is (maxpool1d_layer)
+        ! No gradients to get.
       type is (maxpool2d_layer)
         ! No gradients to get.
       type is (flatten_layer)
@@ -443,6 +489,11 @@ module subroutine set_params(self, params)
 
       type is (locally_connected_1d_layer)
        call this_layer % set_params(params)
+
+      type is (maxpool1d_layer)
+        ! No parameters to set.
+        write(stderr, '(a)') 'Warning: calling set_params() ' &
+          // 'on a zero-parameter layer; nothing to do.'
 
       type is (maxpool2d_layer)
         ! No parameters to set.

src/nf/nf_locally_connected_1d_submodule.f90 (+88 -26)

@@ -36,9 +36,9 @@ module subroutine init(self, input_shape)
     ! Kernel of shape filters x channels x kernel_size
     allocate(self % kernel(self % filters, self % channels, self % kernel_size))
 
-    ! Initialize the kernel with random values
+    ! Initialize the kernel with random values drawn from a normal distribution
     call random_normal(self % kernel)
-    self % kernel = self % kernel / self % kernel_size
+    self % kernel = self % kernel / self % kernel_size ** 2
 
     allocate(self % biases(self % filters))
     self % biases = 0
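
A quick worked check of the new initialization scale, taking kernel_size = 3 as in the example program (whether dividing by $K^2$ rather than $K$ is the intended fan-in normalization is not stated in the commit):

  $w \sim \mathcal{N}(0,1)/K^2 \;\Rightarrow\; \mathrm{sd}(w) = 1/9 \approx 0.11$ for $K = 3$, versus $1/3 \approx 0.33$ before this change.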
@@ -62,50 +62,112 @@ pure module subroutine forward(self, input)
     class(locally_connected_1d_layer), intent(in out) :: self
     real, intent(in) :: input(:,:)
     integer :: input_width, input_channels
-    integer :: i, n
+    integer :: i, n, i_out
+    integer :: iws, iwe
+    integer :: half_window
 
+    ! Get input dimensions
     input_channels = size(input, dim=1)
-    input_width = size(input, dim=2)
+    input_width    = size(input, dim=2)
 
-    do concurrent (i = 1:self % width)
-      do concurrent (n = 1:self % filters)
-        self % z(n,i) = sum(self % kernel(n,:,:)* input(:,i:i+self % kernel_size-1))
+    ! For a kernel of odd size, half_window = kernel_size / 2 (integer division)
+    half_window = self % kernel_size / 2
+
+    ! Loop over output indices rather than input indices.
+    do i_out = 1, self % width
+      ! Compute the corresponding center index in the input.
+      i = i_out + half_window
+
+      ! Define the window in the input corresponding to the filter kernel
+      iws = i - half_window
+      iwe = i + half_window
+
+      ! Compute the inner tensor product (sum of element-wise products)
+      ! for each filter across all channels and positions in the kernel.
+      do concurrent(n = 1:self % filters)
+        self % z(n, i_out) = sum(self % kernel(n, :, :) * input(:, iws:iwe))
       end do
-    end do
 
-    ! Add bias
-    self % z = self % z + reshape(self % biases, shape(self % z))
+      ! Add the bias for each filter.
+      self % z(:, i_out) = self % z(:, i_out) + self % biases
+    end do
 
-    ! Apply activation
+    ! Apply the activation function to get the final output.
     self % output = self % activation % eval(self % z)
-
   end subroutine forward
 
+
   pure module subroutine backward(self, input, gradient)
     implicit none
     class(locally_connected_1d_layer), intent(in out) :: self
-    real, intent(in) :: input(:,:)
-    real, intent(in) :: gradient(:,:)
+    real, intent(in) :: input(:,:)    ! shape: (channels, width)
+    real, intent(in) :: gradient(:,:) ! shape: (filters, width)
+
+    ! Local gradient arrays:
     real :: db(self % filters)
     real :: dw(self % filters, self % channels, self % kernel_size)
-    real :: gdz(self % filters, self % width)
+    real :: gdz(self % filters, size(input, 2))
+
     integer :: i, n, k
-
-    gdz = gradient * self % activation % eval_prime(self % z)
-
+    integer :: input_channels, input_width
+    integer :: istart, iend
+    integer :: iws, iwe
+    integer :: half_window
+
+    ! Get input dimensions.
+    input_channels = size(input, dim=1)
+    input_width = size(input, dim=2)
+
+    ! For an odd-sized kernel, half_window = kernel_size / 2.
+    half_window = self % kernel_size / 2
+
+    ! Define the valid output range so that the full input window is available.
+    istart = half_window + 1
+    iend = input_width - half_window
+
+    !---------------------------------------------------------------------
+    ! Compute the local gradient: gdz = (dL/dy) * sigma'(z)
+    ! We assume self%z stores the pre-activation values from the forward pass.
+    gdz = 0.0
+    gdz(:, istart:iend) = gradient(:, istart:iend) * self % activation % eval_prime(self % z(:, istart:iend))
+
+    !---------------------------------------------------------------------
+    ! Compute gradient with respect to biases:
+    ! dL/db(n) = sum_{i in valid range} gdz(n, i)
     do concurrent (n = 1:self % filters)
-      db(n) = sum(gdz(n,:))
+      db(n) = sum(gdz(n, istart:iend))
     end do
-
-    dw = 0
-    self % gradient = 0
-    do concurrent (n = 1:self % filters, k = 1:self % channels, i = 1:self % width)
-      dw(n,k,:) = dw(n,k,:) + input(k, i:i+self % kernel_size-1) * gdz(n, i)
+
+    ! Initialize weight gradient and input gradient accumulators.
+    dw = 0.0
+    self % gradient = 0.0 ! This array is assumed preallocated to shape (channels, width)
+
+    !---------------------------------------------------------------------
+    ! Accumulate gradients over valid output positions.
+    ! For each output position i, determine the corresponding input window indices.
+    do concurrent (n = 1:self % filters, &
+                   k = 1:self % channels, &
+                   i = istart:iend)
+      ! The input window corresponding to output index i:
+      iws = i - half_window
+      iwe = i + half_window
+
+      ! Weight gradient (dL/dw):
+      ! For each kernel element, the contribution is the product of the input in the window
+      ! and the local gradient at the output position i.
+      dw(n, k, :) = dw(n, k, :) + input(k, iws:iwe) * gdz(n, i)
+
+      ! Input gradient (dL/dx):
+      ! Distribute the effect of the output gradient back onto the input window,
+      ! weighted by the kernel weights.
+      self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, k, :) * gdz(n, i)
     end do
-
+
+    !---------------------------------------------------------------------
+    ! Accumulate the computed gradients into the layer's stored gradients.
     self % dw = self % dw + dw
     self % db = self % db + db
-
+
   end subroutine backward
 
   pure module function get_num_params(self) result(num_params)

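For reference, the computation the rewritten passes implement, in the code's own notation (K = kernel_size, h = half_window = K/2, C = channels, g = gdz); this restates the code above rather than deriving it independently.

Forward, for output position $i$ (window iws = $i$, iwe = $i + K - 1$, since the center is $i + h$):

  $z_{n,i} = b_n + \sum_{c=1}^{C} \sum_{j=1}^{K} w_{n,c,j}\, x_{c,\, i+j-1}, \qquad y_{n,i} = \sigma(z_{n,i})$

Backward, with $g_{n,i} = \frac{\partial L}{\partial y_{n,i}}\, \sigma'(z_{n,i})$ and windows centered on $i$ over the valid range istart..iend:

  $\frac{\partial L}{\partial b_n} = \sum_i g_{n,i}, \qquad \frac{\partial L}{\partial w_{n,c,j}} = \sum_i x_{c,\, i-h+j-1}\, g_{n,i}, \qquad \frac{\partial L}{\partial x_{c,m}} = \sum_{n}\sum_{j=1}^{K} w_{n,c,j}\, g_{n,\, m+h-j+1}$

Note that the two passes index their windows differently (forward starts the window at the output index, iws = i_out; backward centers it on the gradient index, iws = i - half_window), which may be related to the 'still not working' in the commit title.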
src/nf/nf_network_submodule.f90 (+3 -3)

@@ -73,9 +73,9 @@ module function network_from_layers(layers) result(res)
       type is(conv2d_layer)
         res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
         n = n + 1
-      !type is(locally_connected_1d_layer)
-        !res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
-        !n = n + 1
+      type is(locally_connected_1d_layer)
+        res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
+        n = n + 1
       type is(maxpool2d_layer)
         res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
         n = n + 1
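
Un-commenting this branch gives locally_connected_1d the same treatment conv2d_layer and maxpool2d_layer already get: the constructor splices a flatten() into the layer list in front of position n. A minimal, runnable sketch of the splice idiom itself, using a character array as a hypothetical stand-in for res % layers:

  program splice_demo
    implicit none
    character(24), allocatable :: layers(:)
    integer :: n

    layers = [character(24) :: 'input', 'locally_connected_1d', 'dense']
    n = 3  ! splice point: insert in front of the element at position n

    ! Same idiom as res % layers above: keep :n-1, insert, keep n: onward.
    layers = [character(24) :: layers(:n-1), 'flatten', layers(n:)]

    print '(a)', layers
  end program splice_demo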

test/CMakeLists.txt (+1)

@@ -2,6 +2,7 @@ foreach(execid
   input1d_layer
   input2d_layer
   input3d_layer
+  locally_connected_1d_layer
   parametric_activation
   dense_layer
   conv2d_layer
