@@ -35,12 +35,14 @@ module nf_layer
3535
3636 ! Specific subroutines for different array ranks
3737 procedure , private :: backward_1d
38+ procedure , private :: backward_2d
3839 procedure , private :: backward_3d
3940 procedure , private :: get_output_1d
41+ procedure , private :: get_output_2d
4042 procedure , private :: get_output_3d
4143
42- generic :: backward = > backward_1d, backward_3d
43- generic :: get_output = > get_output_1d, get_output_3d
44+ generic :: backward = > backward_1d, backward_2d, backward_3d
45+ generic :: get_output = > get_output_1d, get_output_2d, get_output_3d
4446
4547 end type layer
4648
@@ -59,6 +61,19 @@ pure module subroutine backward_1d(self, previous, gradient)
5961 ! ! Array of gradient values from the next layer
6062 end subroutine backward_1d
6163
64+ pure module subroutine backward_2d(self, previous, gradient)
65+ ! ! Apply a backward pass on the layer.
66+ ! ! This changes the internal state of the layer.
67+ ! ! This is normally called internally by the `network % backward`
68+ ! ! method.
69+ class(layer), intent (in out ) :: self
70+ ! ! Layer instance
71+ class(layer), intent (in ) :: previous
72+ ! ! Previous layer instance
73+ real , intent (in ) :: gradient(:, :)
74+ ! ! Array of gradient values from the next layer
75+ end subroutine backward_2d
76+
6277 pure module subroutine backward_3d(self, previous, gradient)
6378 ! ! Apply a backward pass on the layer.
6479 ! ! This changes the internal state of the layer.
@@ -95,6 +110,14 @@ pure module subroutine get_output_1d(self, output)
95110 ! ! Output values from this layer
96111 end subroutine get_output_1d
97112
113+ pure module subroutine get_output_2d(self, output)
114+ ! ! Returns the output values (activations) from this layer.
115+ class(layer), intent (in ) :: self
116+ ! ! Layer instance
117+ real , allocatable , intent (out ) :: output(:,:)
118+ ! ! Output values from this layer
119+ end subroutine get_output_2d
120+
98121 pure module subroutine get_output_3d(self, output)
99122 ! ! Returns the output values (activations) from a layer with a 3-d output
100123 ! ! (e.g. input3d, conv2d)
0 commit comments