Skip to content

Commit 118f795

Browse files
jvdp1Vandenplas, Jeremiemilancurcic
authored
Intrinsic pack replaced by pointers in get_params and get_gradients (#183)
* Replace intrinsic pack by pointers * Dense layer: remove an avoidable reshape * conv2d: avoid intrinsics pack and reshape * replace a reshape by a pointer * clean conv2d_layer_submodule --------- Co-authored-by: Vandenplas, Jeremie <[email protected]> Co-authored-by: milancurcic <[email protected]>
1 parent a843c83 commit 118f795

8 files changed

+60
-46
lines changed

src/nf/nf_conv2d_layer.f90

+4-4
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,19 @@ pure module function get_num_params(self) result(num_params)
8989
!! Number of parameters
9090
end function get_num_params
9191

92-
pure module function get_params(self) result(params)
92+
module function get_params(self) result(params)
9393
!! Return the parameters (weights and biases) of this layer.
9494
!! The parameters are ordered as weights first, biases second.
95-
class(conv2d_layer), intent(in) :: self
95+
class(conv2d_layer), intent(in), target :: self
9696
!! A `conv2d_layer` instance
9797
real, allocatable :: params(:)
9898
!! Parameters to get
9999
end function get_params
100100

101-
pure module function get_gradients(self) result(gradients)
101+
module function get_gradients(self) result(gradients)
102102
!! Return the gradients of this layer.
103103
!! The gradients are ordered as weights first, biases second.
104-
class(conv2d_layer), intent(in) :: self
104+
class(conv2d_layer), intent(in), target :: self
105105
!! A `conv2d_layer` instance
106106
real, allocatable :: gradients(:)
107107
!! Gradients to get

src/nf/nf_conv2d_layer_submodule.f90

+18-11
Original file line numberDiff line numberDiff line change
@@ -189,24 +189,32 @@ pure module function get_num_params(self) result(num_params)
189189
end function get_num_params
190190

191191

192-
pure module function get_params(self) result(params)
193-
class(conv2d_layer), intent(in) :: self
192+
module function get_params(self) result(params)
193+
class(conv2d_layer), intent(in), target :: self
194194
real, allocatable :: params(:)
195195

196+
real, pointer :: w_(:) => null()
197+
198+
w_(1:size(self % kernel)) => self % kernel
199+
196200
params = [ &
197-
pack(self % kernel, .true.), &
201+
w_, &
198202
self % biases &
199203
]
200204

201205
end function get_params
202206

203207

204-
pure module function get_gradients(self) result(gradients)
205-
class(conv2d_layer), intent(in) :: self
208+
module function get_gradients(self) result(gradients)
209+
class(conv2d_layer), intent(in), target :: self
206210
real, allocatable :: gradients(:)
207211

212+
real, pointer :: dw_(:) => null()
213+
214+
dw_(1:size(self % dw)) => self % dw
215+
208216
gradients = [ &
209-
pack(self % dw, .true.), &
217+
dw_, &
210218
self % db &
211219
]
212220

@@ -219,7 +227,7 @@ module subroutine set_params(self, params)
219227

220228
! Check that the number of parameters is correct.
221229
if (size(params) /= self % get_num_params()) then
222-
error stop 'conv2d % set_params: Number of parameters does not match'
230+
error stop 'conv2d % set_params: Number of parameters does not match'
223231
end if
224232

225233
! Reshape the kernel.
@@ -229,10 +237,9 @@ module subroutine set_params(self, params)
229237
)
230238

231239
! Reshape the biases.
232-
self % biases = reshape( &
233-
params(product(shape(self % kernel)) + 1:), &
234-
[self % filters] &
235-
)
240+
associate(n => product(shape(self % kernel)))
241+
self % biases = params(n + 1 : n + self % filters)
242+
end associate
236243

237244
end subroutine set_params
238245

src/nf/nf_dense_layer.f90

+5-5
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,19 @@ pure module function get_num_params(self) result(num_params)
8787
!! Number of parameters in this layer
8888
end function get_num_params
8989

90-
pure module function get_params(self) result(params)
90+
module function get_params(self) result(params)
9191
!! Return the parameters (weights and biases) of this layer.
9292
!! The parameters are ordered as weights first, biases second.
93-
class(dense_layer), intent(in) :: self
93+
class(dense_layer), intent(in), target :: self
9494
!! Dense layer instance
9595
real, allocatable :: params(:)
9696
!! Parameters of this layer
9797
end function get_params
9898

99-
pure module function get_gradients(self) result(gradients)
99+
module function get_gradients(self) result(gradients)
100100
!! Return the gradients of this layer.
101101
!! The gradients are ordered as weights first, biases second.
102-
class(dense_layer), intent(in) :: self
102+
class(dense_layer), intent(in), target :: self
103103
!! Dense layer instance
104104
real, allocatable :: gradients(:)
105105
!! Gradients of this layer
@@ -110,7 +110,7 @@ module subroutine set_params(self, params)
110110
!! The parameters are ordered as weights first, biases second.
111111
class(dense_layer), intent(in out) :: self
112112
!! Dense layer instance
113-
real, intent(in) :: params(:)
113+
real, intent(in), target :: params(:)
114114
!! Parameters of this layer
115115
end subroutine set_params
116116

src/nf/nf_dense_layer_submodule.f90

+25-18
Original file line numberDiff line numberDiff line change
@@ -61,24 +61,32 @@ pure module function get_num_params(self) result(num_params)
6161
end function get_num_params
6262

6363

64-
pure module function get_params(self) result(params)
65-
class(dense_layer), intent(in) :: self
64+
module function get_params(self) result(params)
65+
class(dense_layer), intent(in), target :: self
6666
real, allocatable :: params(:)
6767

68+
real, pointer :: w_(:) => null()
69+
70+
w_(1:size(self % weights)) => self % weights
71+
6872
params = [ &
69-
pack(self % weights, .true.), &
73+
w_, &
7074
self % biases &
7175
]
7276

7377
end function get_params
7478

7579

76-
pure module function get_gradients(self) result(gradients)
77-
class(dense_layer), intent(in) :: self
80+
module function get_gradients(self) result(gradients)
81+
class(dense_layer), intent(in), target :: self
7882
real, allocatable :: gradients(:)
7983

84+
real, pointer :: dw_(:) => null()
85+
86+
dw_(1:size(self % dw)) => self % dw
87+
8088
gradients = [ &
81-
pack(self % dw, .true.), &
89+
dw_, &
8290
self % db &
8391
]
8492

@@ -87,24 +95,23 @@ end function get_gradients
8795

8896
module subroutine set_params(self, params)
8997
class(dense_layer), intent(in out) :: self
90-
real, intent(in) :: params(:)
98+
real, intent(in), target :: params(:)
99+
100+
real, pointer :: p_(:,:) => null()
91101

92102
! check if the number of parameters is correct
93103
if (size(params) /= self % get_num_params()) then
94104
error stop 'Error: number of parameters does not match'
95105
end if
96106

97-
! reshape the weights
98-
self % weights = reshape( &
99-
params(:self % input_size * self % output_size), &
100-
[self % input_size, self % output_size] &
101-
)
102-
103-
! reshape the biases
104-
self % biases = reshape( &
105-
params(self % input_size * self % output_size + 1:), &
106-
[self % output_size] &
107-
)
107+
associate(n => self % input_size * self % output_size)
108+
! reshape the weights
109+
p_(1:self % input_size, 1:self % output_size) => params(1 : n)
110+
self % weights = p_
111+
112+
! reshape the biases
113+
self % biases = params(n + 1 : n + self % output_size)
114+
end associate
108115

109116
end subroutine set_params
110117

src/nf/nf_layer.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,15 @@ elemental module function get_num_params(self) result(num_params)
129129
!! Number of parameters in this layer
130130
end function get_num_params
131131

132-
pure module function get_params(self) result(params)
132+
module function get_params(self) result(params)
133133
!! Returns the parameters of this layer.
134134
class(layer), intent(in) :: self
135135
!! Layer instance
136136
real, allocatable :: params(:)
137137
!! Parameters of this layer
138138
end function get_params
139139

140-
pure module function get_gradients(self) result(gradients)
140+
module function get_gradients(self) result(gradients)
141141
!! Returns the gradients of this layer.
142142
class(layer), intent(in) :: self
143143
!! Layer instance

src/nf/nf_layer_submodule.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ elemental module function get_num_params(self) result(num_params)
298298

299299
end function get_num_params
300300

301-
pure module function get_params(self) result(params)
301+
module function get_params(self) result(params)
302302
class(layer), intent(in) :: self
303303
real, allocatable :: params(:)
304304

@@ -323,7 +323,7 @@ pure module function get_params(self) result(params)
323323

324324
end function get_params
325325

326-
pure module function get_gradients(self) result(gradients)
326+
module function get_gradients(self) result(gradients)
327327
class(layer), intent(in) :: self
328328
real, allocatable :: gradients(:)
329329

src/nf/nf_network.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,15 @@ pure module integer function get_num_params(self)
172172
!! Network instance
173173
end function get_num_params
174174

175-
pure module function get_params(self) result(params)
175+
module function get_params(self) result(params)
176176
!! Get the network parameters (weights and biases).
177177
class(network), intent(in) :: self
178178
!! Network instance
179179
real, allocatable :: params(:)
180180
!! Network parameters to get
181181
end function get_params
182182

183-
pure module function get_gradients(self) result(gradients)
183+
module function get_gradients(self) result(gradients)
184184
class(network), intent(in) :: self
185185
!! Network instance
186186
real, allocatable :: gradients(:)

src/nf/nf_network_submodule.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ pure module function get_num_params(self)
526526
end function get_num_params
527527

528528

529-
pure module function get_params(self) result(params)
529+
module function get_params(self) result(params)
530530
class(network), intent(in) :: self
531531
real, allocatable :: params(:)
532532
integer :: n, nstart, nend
@@ -546,7 +546,7 @@ pure module function get_params(self) result(params)
546546
end function get_params
547547

548548

549-
pure module function get_gradients(self) result(gradients)
549+
module function get_gradients(self) result(gradients)
550550
class(network), intent(in) :: self
551551
real, allocatable :: gradients(:)
552552
integer :: n, nstart, nend

0 commit comments

Comments
 (0)