Skip to content

Commit 2160f97

Browse files
committed
get_params_ptr and get_gradients_ptr for conv1d, conv2d, and locally_connected1d
1 parent 9d68828 commit 2160f97

7 files changed

+122
-15
lines changed

src/nf/nf_conv1d_layer.f90

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ module nf_conv1d_layer
3232
procedure :: forward
3333
procedure :: backward
3434
procedure :: get_gradients
35+
procedure :: get_gradients_ptr
3536
procedure :: get_num_params
3637
procedure :: get_params
38+
procedure :: get_params_ptr
3739
procedure :: init
3840
procedure :: set_params
3941

@@ -97,6 +99,16 @@ module function get_params(self) result(params)
9799
!! Parameters to get
98100
end function get_params
99101

102+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
103+
!! Return pointers to the parameters (weights and biases) of this layer.
104+
class(conv1d_layer), intent(in), target :: self
105+
!! A `conv1d_layer` instance
106+
real, pointer, intent(out) :: w_ptr(:)
107+
!! Pointer to the kernel weights (flattened)
108+
real, pointer, intent(out) :: b_ptr(:)
109+
!! Pointer to the biases
110+
end subroutine get_params_ptr
111+
100112
module function get_gradients(self) result(gradients)
101113
!! Return the gradients of this layer.
102114
!! The gradients are ordered as weights first, biases second.
@@ -106,6 +118,16 @@ module function get_gradients(self) result(gradients)
106118
!! Gradients to get
107119
end function get_gradients
108120

121+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
122+
!! Return pointers to the gradients of this layer.
123+
class(conv1d_layer), intent(in), target :: self
124+
!! A `conv1d_layer` instance
125+
real, pointer, intent(out) :: dw_ptr(:)
126+
!! Pointer to the kernel weight gradients (flattened)
127+
real, pointer, intent(out) :: db_ptr(:)
128+
!! Pointer to the bias gradients
129+
end subroutine get_gradients_ptr
130+
109131
module subroutine set_params(self, params)
110132
!! Set the parameters of the layer.
111133
class(conv1d_layer), intent(in out) :: self

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ module function get_params(self) result(params)
152152
params = [ w_, self % biases]
153153
end function get_params
154154

155+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
156+
class(conv1d_layer), intent(in), target :: self
157+
real, pointer, intent(out) :: w_ptr(:)
158+
real, pointer, intent(out) :: b_ptr(:)
159+
w_ptr(1:size(self % kernel)) => self % kernel
160+
b_ptr => self % biases
161+
end subroutine get_params_ptr
162+
155163
module function get_gradients(self) result(gradients)
156164
class(conv1d_layer), intent(in), target :: self
157165
real, allocatable :: gradients(:)
@@ -160,6 +168,14 @@ module function get_gradients(self) result(gradients)
160168
gradients = [ dw_, self % db ]
161169
end function get_gradients
162170

171+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
172+
class(conv1d_layer), intent(in), target :: self
173+
real, pointer, intent(out) :: dw_ptr(:)
174+
real, pointer, intent(out) :: db_ptr(:)
175+
dw_ptr(1:size(self % dw)) => self % dw
176+
db_ptr => self % db
177+
end subroutine get_gradients_ptr
178+
163179
module subroutine set_params(self, params)
164180
class(conv1d_layer), intent(in out) :: self
165181
real, intent(in) :: params(:)

src/nf/nf_conv2d_layer.f90

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ module nf_conv2d_layer
3333
procedure :: forward
3434
procedure :: backward
3535
procedure :: get_gradients
36+
procedure :: get_gradients_ptr
3637
procedure :: get_num_params
3738
procedure :: get_params
39+
procedure :: get_params_ptr
3840
procedure :: init
3941
procedure :: set_params
4042

@@ -98,6 +100,16 @@ module function get_params(self) result(params)
98100
!! Parameters to get
99101
end function get_params
100102

103+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
104+
!! Return pointers to the parameters (weights and biases) of this layer.
105+
class(conv2d_layer), intent(in), target :: self
106+
!! A `conv2d_layer` instance
107+
real, pointer, intent(out) :: w_ptr(:)
108+
!! Pointer to the kernel weights (flattened)
109+
real, pointer, intent(out) :: b_ptr(:)
110+
!! Pointer to the biases
111+
end subroutine get_params_ptr
112+
101113
module function get_gradients(self) result(gradients)
102114
!! Return the gradients of this layer.
103115
!! The gradients are ordered as weights first, biases second.
@@ -107,6 +119,16 @@ module function get_gradients(self) result(gradients)
107119
!! Gradients to get
108120
end function get_gradients
109121

122+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
123+
!! Return pointers to the gradients of this layer.
124+
class(conv2d_layer), intent(in), target :: self
125+
!! A `conv2d_layer` instance
126+
real, pointer, intent(out) :: dw_ptr(:)
127+
!! Pointer to the kernel weight gradients (flattened)
128+
real, pointer, intent(out) :: db_ptr(:)
129+
!! Pointer to the bias gradients
130+
end subroutine get_gradients_ptr
131+
110132
module subroutine set_params(self, params)
111133
!! Set the parameters of the layer.
112134
class(conv2d_layer), intent(in out) :: self

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,15 @@ module function get_params(self) result(params)
204204

205205
end function get_params
206206

207+
208+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
209+
class(conv2d_layer), intent(in), target :: self
210+
real, pointer, intent(out) :: w_ptr(:)
211+
real, pointer, intent(out) :: b_ptr(:)
212+
w_ptr(1:size(self % kernel)) => self % kernel
213+
b_ptr => self % biases
214+
end subroutine get_params_ptr
215+
207216

208217
module function get_gradients(self) result(gradients)
209218
class(conv2d_layer), intent(in), target :: self
@@ -221,6 +230,15 @@ module function get_gradients(self) result(gradients)
221230
end function get_gradients
222231

223232

233+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
234+
class(conv2d_layer), intent(in), target :: self
235+
real, pointer, intent(out) :: dw_ptr(:)
236+
real, pointer, intent(out) :: db_ptr(:)
237+
dw_ptr(1:size(self % dw)) => self % dw
238+
db_ptr => self % db
239+
end subroutine get_gradients_ptr
240+
241+
224242
module subroutine set_params(self, params)
225243
class(conv2d_layer), intent(in out) :: self
226244
real, intent(in) :: params(:)

src/nf/nf_locally_connected1d_layer.f90

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ module nf_locally_connected1d_layer
3232
procedure :: forward
3333
procedure :: backward
3434
procedure :: get_gradients
35+
procedure :: get_gradients_ptr
3536
procedure :: get_num_params
3637
procedure :: get_params
38+
procedure :: get_params_ptr
3739
procedure :: init
3840
procedure :: set_params
3941

@@ -97,6 +99,12 @@ module function get_params(self) result(params)
9799
!! Parameters to get
98100
end function get_params
99101

102+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
103+
class(locally_connected1d_layer), intent(in), target :: self
104+
real, pointer, intent(out) :: w_ptr(:)
105+
real, pointer, intent(out) :: b_ptr(:)
106+
end subroutine get_params_ptr
107+
100108
module function get_gradients(self) result(gradients)
101109
!! Return the gradients of this layer.
102110
!! The gradients are ordered as weights first, biases second.
@@ -106,6 +114,12 @@ module function get_gradients(self) result(gradients)
106114
!! Gradients to get
107115
end function get_gradients
108116

117+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
118+
class(locally_connected1d_layer), intent(in), target :: self
119+
real, pointer, intent(out) :: dw_ptr(:)
120+
real, pointer, intent(out) :: db_ptr(:)
121+
end subroutine get_gradients_ptr
122+
109123
module subroutine set_params(self, params)
110124
!! Set the parameters of the layer.
111125
class(locally_connected1d_layer), intent(in out) :: self

src/nf/nf_locally_connected1d_layer_submodule.f90

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,28 @@ module function get_params(self) result(params)
128128
params = [self % kernel, self % biases]
129129
end function get_params
130130

131+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
132+
class(locally_connected1d_layer), intent(in), target :: self
133+
real, pointer, intent(out) :: w_ptr(:)
134+
real, pointer, intent(out) :: b_ptr(:)
135+
w_ptr(1:size(self % kernel)) => self % kernel
136+
b_ptr(1:size(self % biases)) => self % biases
137+
end subroutine get_params_ptr
138+
131139
module function get_gradients(self) result(gradients)
132140
class(locally_connected1d_layer), intent(in), target :: self
133141
real, allocatable :: gradients(:)
134142
gradients = [self % dw, self % db]
135143
end function get_gradients
136144

145+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
146+
class(locally_connected1d_layer), intent(in), target :: self
147+
real, pointer, intent(out) :: dw_ptr(:)
148+
real, pointer, intent(out) :: db_ptr(:)
149+
dw_ptr(1:size(self % dw)) => self % dw
150+
db_ptr(1:size(self % db)) => self % db
151+
end subroutine get_gradients_ptr
152+
137153
module subroutine set_params(self, params)
138154
class(locally_connected1d_layer), intent(in out) :: self
139155
real, intent(in) :: params(:)

src/nf/nf_network_submodule.f90

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -701,28 +701,27 @@ module subroutine update(self, optimizer, batch_size)
701701
call this_layer % get_gradients_ptr(dw, db)
702702
call self % optimizer % minimize(weights, dw / batch_size_)
703703
call self % optimizer % minimize(biases, db / batch_size_)
704-
type is(locally_connected1d_layer)
705-
!TODO
706-
type is(conv1d_layer)
707-
!TODO
708-
type is(conv2d_layer)
709-
!TODO
710-
end select
711-
end do
712-
713-
! Flush network gradients to zero.
714-
do n = 2, size(self % layers)
715-
select type(this_layer => self % layers(n) % p)
716-
type is(dense_layer)
717704
this_layer % dw = 0
718705
this_layer % db = 0
719-
type is(conv2d_layer)
706+
type is(conv1d_layer)
707+
call this_layer % get_params_ptr(weights, biases)
708+
call this_layer % get_gradients_ptr(dw, db)
709+
call self % optimizer % minimize(weights, dw / batch_size_)
710+
call self % optimizer % minimize(biases, db / batch_size_)
720711
this_layer % dw = 0
721712
this_layer % db = 0
722-
type is(conv1d_layer)
713+
type is(conv2d_layer)
714+
call this_layer % get_params_ptr(weights, biases)
715+
call this_layer % get_gradients_ptr(dw, db)
716+
call self % optimizer % minimize(weights, dw / batch_size_)
717+
call self % optimizer % minimize(biases, db / batch_size_)
723718
this_layer % dw = 0
724719
this_layer % db = 0
725720
type is(locally_connected1d_layer)
721+
call this_layer % get_params_ptr(weights, biases)
722+
call this_layer % get_gradients_ptr(dw, db)
723+
call self % optimizer % minimize(weights, dw / batch_size_)
724+
call self % optimizer % minimize(biases, db / batch_size_)
726725
this_layer % dw = 0
727726
this_layer % db = 0
728727
end select

0 commit comments

Comments
 (0)