1010 use nf_input3d_layer, only: input3d_layer
1111 use nf_maxpool2d_layer, only: maxpool2d_layer
1212 use nf_reshape_layer, only: reshape3d_layer
13+ use nf_linear2d_layer, only: linear2d_layer
1314 use nf_optimizers, only: optimizer_base_type
1415
1516contains
@@ -50,6 +51,8 @@ pure module subroutine backward_1d(self, previous, gradient)
5051 call this_layer % backward(prev_layer % output, gradient)
5152 type is (maxpool2d_layer)
5253 call this_layer % backward(prev_layer % output, gradient)
54+ type is (linear2d_layer)
55+ call this_layer % backward(prev_layer % output, gradient)
5356 end select
5457
5558 end select
@@ -63,9 +66,19 @@ pure module subroutine backward_2d(self, previous, gradient)
6366 class(layer), intent (in ) :: previous
6467 real , intent (in ) :: gradient(:,:)
6568
66- ! Backward pass from a 2-d layer downstream currently implemented
67- ! only for dense and flatten layers
68- ! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199
69+ select type (this_layer = > self % p)
70+
71+ type is (linear2d_layer)
72+
73+ select type (prev_layer = > previous % p)
74+ type is (input2d_layer)
75+ call this_layer % backward(prev_layer % output, gradient)
76+ type is (linear2d_layer)
77+ call this_layer % backward(prev_layer % output, gradient)
78+ end select
79+
80+ end select
81+
6982 end subroutine backward_2d
7083
7184
@@ -199,6 +212,8 @@ module subroutine forward(self, input)
199212 call this_layer % forward(prev_layer % output)
200213 type is (reshape3d_layer)
201214 call this_layer % forward(prev_layer % output)
215+ type is (linear2d_layer)
216+ call this_layer % forward(prev_layer % output)
202217 end select
203218
204219 type is (reshape3d_layer)
@@ -213,6 +228,16 @@ module subroutine forward(self, input)
213228 call this_layer % forward(prev_layer % output)
214229 end select
215230
231+ type is (linear2d_layer)
232+
233+ ! Upstream layers permitted: input2d, linear2d
234+ select type (prev_layer = > input % p)
235+ type is (input2d_layer)
236+ call this_layer % forward(prev_layer % output)
237+ type is (linear2d_layer)
238+ call this_layer % forward(prev_layer % output)
239+ end select
240+
216241 end select
217242
218243 end subroutine forward
@@ -248,8 +273,10 @@ pure module subroutine get_output_2d(self, output)
248273
249274 type is (input2d_layer)
250275 allocate (output, source= this_layer % output)
276+ type is (linear2d_layer)
277+ allocate (output, source= this_layer % output)
251278 class default
252- error stop ' 1 -d output can only be read from an input1d, dense, or flatten layer.'
279+ error stop ' 2 -d output can only be read from an input2d or linear2d layer.'
253280
254281 end select
255282
@@ -291,20 +318,22 @@ impure elemental module subroutine init(self, input)
291318 call this_layer % init(input % layer_shape)
292319 end select
293320
294- ! The shape of conv2d, dropout, flatten, or maxpool2d layers is not known
295- ! until we receive an input layer.
321+ ! The shape of conv2d, dropout, flatten, linear2d, or maxpool2d layers
322+ ! is not known until we receive an input layer.
296323 select type (this_layer = > self % p)
297324 type is (conv2d_layer)
298325 self % layer_shape = shape (this_layer % output)
299326 type is (dropout_layer)
300327 self % layer_shape = shape (this_layer % output)
301328 type is (flatten_layer)
302329 self % layer_shape = shape (this_layer % output)
330+ type is (linear2d_layer)
331+ self % layer_shape = shape (this_layer % output)
303332 type is (maxpool2d_layer)
304333 self % layer_shape = shape (this_layer % output)
305334 end select
306335
307- self % input_layer_shape = input % layer_shape
336+ self % input_layer_shape = input % layer_shape
308337 self % initialized = .true.
309338
310339 end subroutine init
@@ -349,6 +378,8 @@ elemental module function get_num_params(self) result(num_params)
349378 num_params = 0
350379 type is (reshape3d_layer)
351380 num_params = 0
381+ type is (linear2d_layer)
382+ num_params = this_layer % get_num_params()
352383 class default
353384 error stop ' Unknown layer type.'
354385 end select
@@ -378,6 +409,8 @@ module function get_params(self) result(params)
378409 ! No parameters to get.
379410 type is (reshape3d_layer)
380411 ! No parameters to get.
412+ type is (linear2d_layer)
413+ params = this_layer % get_params()
381414 class default
382415 error stop ' Unknown layer type.'
383416 end select
@@ -404,9 +437,11 @@ module function get_gradients(self) result(gradients)
404437 type is (maxpool2d_layer)
405438 ! No gradients to get.
406439 type is (flatten_layer)
407- ! No gradients to get.
440+ ! No parameters to get.
408441 type is (reshape3d_layer)
409442 ! No gradients to get.
443+ type is (linear2d_layer)
444+ gradients = this_layer % get_gradients()
410445 class default
411446 error stop ' Unknown layer type.'
412447 end select
@@ -459,6 +494,9 @@ module subroutine set_params(self, params)
459494 type is (conv2d_layer)
460495 call this_layer % set_params(params)
461496
497+ type is (linear2d_layer)
498+ call this_layer % set_params(params)
499+
462500 type is (maxpool2d_layer)
463501 ! No parameters to set.
464502 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
@@ -476,6 +514,7 @@ module subroutine set_params(self, params)
476514
477515 class default
478516 error stop ' Unknown layer type.'
517+
479518 end select
480519
481520 end subroutine set_params
0 commit comments