Skip to content

Commit 226b030

Browse files
Riccardo Orsi
authored and committed
Removed set_params as well
1 parent d36edc7 commit 226b030

15 files changed

+118
-251
lines changed

src/nf/nf_conv1d_layer.f90

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ module nf_conv1d_layer
3535
procedure :: get_num_params
3636
procedure :: get_params_ptr
3737
procedure :: init
38-
procedure :: set_params
3938

4039
end type conv1d_layer
4140

@@ -108,14 +107,6 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
108107
!! Pointer to the bias gradients
109108
end subroutine get_gradients_ptr
110109

111-
module subroutine set_params(self, params)
112-
!! Set the parameters of the layer.
113-
class(conv1d_layer), intent(in out) :: self
114-
!! A `conv1d_layer` instance
115-
real, intent(in) :: params(:)
116-
!! Parameters to set
117-
end subroutine set_params
118-
119110
end interface
120111

121112
end module nf_conv1d_layer

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -160,19 +160,4 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
160160
db_ptr => self % db
161161
end subroutine get_gradients_ptr
162162

163-
module subroutine set_params(self, params)
164-
class(conv1d_layer), intent(in out) :: self
165-
real, intent(in) :: params(:)
166-
167-
if (size(params) /= self % get_num_params()) then
168-
error stop 'conv1d_layer % set_params: Number of parameters does not match'
169-
end if
170-
171-
self % kernel = reshape(params(:product(shape(self % kernel))), shape(self % kernel))
172-
associate(n => product(shape(self % kernel)))
173-
self % biases = params(n + 1 : n + self % filters)
174-
end associate
175-
176-
end subroutine set_params
177-
178163
end submodule nf_conv1d_layer_submodule

src/nf/nf_conv2d_layer.f90

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ module nf_conv2d_layer
3636
procedure :: get_num_params
3737
procedure :: get_params_ptr
3838
procedure :: init
39-
procedure :: set_params
4039

4140
end type conv2d_layer
4241

@@ -109,14 +108,6 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
109108
!! Pointer to the bias gradients
110109
end subroutine get_gradients_ptr
111110

112-
module subroutine set_params(self, params)
113-
!! Set the parameters of the layer.
114-
class(conv2d_layer), intent(in out) :: self
115-
!! A `conv2d_layer` instance
116-
real, intent(in) :: params(:)
117-
!! Parameters to set
118-
end subroutine set_params
119-
120111
end interface
121112

122113
end module nf_conv2d_layer

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -206,27 +206,4 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
206206
db_ptr => self % db
207207
end subroutine get_gradients_ptr
208208

209-
210-
module subroutine set_params(self, params)
211-
class(conv2d_layer), intent(in out) :: self
212-
real, intent(in) :: params(:)
213-
214-
! Check that the number of parameters is correct.
215-
if (size(params) /= self % get_num_params()) then
216-
error stop 'conv2d % set_params: Number of parameters does not match'
217-
end if
218-
219-
! Reshape the kernel.
220-
self % kernel = reshape( &
221-
params(:product(shape(self % kernel))), &
222-
shape(self % kernel) &
223-
)
224-
225-
! Reshape the biases.
226-
associate(n => product(shape(self % kernel)))
227-
self % biases = params(n + 1 : n + self % filters)
228-
end associate
229-
230-
end subroutine set_params
231-
232209
end submodule nf_conv2d_layer_submodule

src/nf/nf_dense_layer.f90

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ module nf_dense_layer
3737
procedure :: get_num_params
3838
procedure :: get_params_ptr
3939
procedure :: init
40-
procedure :: set_params
4140

4241
end type dense_layer
4342

src/nf/nf_dense_layer_submodule.f90

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -78,30 +78,6 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
7878
db_ptr => self % db
7979
end subroutine get_gradients_ptr
8080

81-
82-
module subroutine set_params(self, params)
83-
class(dense_layer), intent(in out) :: self
84-
real, intent(in), target :: params(:)
85-
86-
real, pointer :: p_(:,:) => null()
87-
88-
! check if the number of parameters is correct
89-
if (size(params) /= self % get_num_params()) then
90-
error stop 'Error: number of parameters does not match'
91-
end if
92-
93-
associate(n => self % input_size * self % output_size)
94-
! reshape the weights
95-
p_(1:self % input_size, 1:self % output_size) => params(1 : n)
96-
self % weights = p_
97-
98-
! reshape the biases
99-
self % biases = params(n + 1 : n + self % output_size)
100-
end associate
101-
102-
end subroutine set_params
103-
104-
10581
module subroutine init(self, input_shape)
10682
class(dense_layer), intent(in out) :: self
10783
integer, intent(in) :: input_shape(:)

src/nf/nf_layer_submodule.f90

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,12 @@ module function get_params(self) result(params)
691691
params = this_layer % get_params()
692692
type is (embedding_layer)
693693
params = this_layer % get_params()
694+
694695
type is (layernorm_layer)
695-
params = this_layer % get_params()
696+
call this_layer % get_params_ptr(w_ptr, b_ptr)
697+
allocate(params(size(w_ptr) + size(b_ptr)))
698+
params(1:size(w_ptr)) = w_ptr
699+
params(size(w_ptr)+1:) = b_ptr
696700
class default
697701
error stop 'Unknown layer type.'
698702
end select
@@ -703,6 +707,8 @@ end function get_params
703707
module subroutine set_params(self, params)
704708
class(layer), intent(in out) :: self
705709
real, intent(in) :: params(:)
710+
real, pointer :: w_ptr(:)
711+
real, pointer :: b_ptr(:)
706712

707713
! Check that the number of parameters is correct.
708714
! This check will still pass if the size(params) == 0 and the layer is a
@@ -736,37 +742,55 @@ module subroutine set_params(self, params)
736742
// 'on a zero-parameter layer; nothing to do.'
737743

738744
type is (dense_layer)
739-
call this_layer % set_params(params)
745+
call this_layer % get_params_ptr(w_ptr, b_ptr)
746+
747+
w_ptr = params(1:size(w_ptr))
748+
b_ptr = params(size(w_ptr)+1:)
740749

741750
type is (dropout_layer)
742751
! No parameters to set.
743752
write(stderr, '(a)') 'Warning: calling set_params() ' &
744753
// 'on a zero-parameter layer; nothing to do.'
745754

746755
type is (conv1d_layer)
747-
call this_layer % set_params(params)
756+
call this_layer % get_params_ptr(w_ptr, b_ptr)
757+
758+
w_ptr = params(1:size(w_ptr))
759+
b_ptr = params(size(w_ptr)+1:)
748760

749761
type is (conv2d_layer)
750-
call this_layer % set_params(params)
762+
call this_layer % get_params_ptr(w_ptr, b_ptr)
763+
764+
w_ptr = params(1:size(w_ptr))
765+
b_ptr = params(size(w_ptr)+1:)
751766

752767
type is (locally_connected2d_layer)
753-
call this_layer % set_params(params)
768+
call this_layer % get_params_ptr(w_ptr, b_ptr)
769+
770+
w_ptr = params(1:size(w_ptr))
771+
b_ptr = params(size(w_ptr)+1:)
754772

755773
type is (maxpool1d_layer)
756774
! No parameters to set.
757775
write(stderr, '(a)') 'Warning: calling set_params() ' &
758776
// 'on a zero-parameter layer; nothing to do.'
759777

760778
type is (linear2d_layer)
761-
call this_layer % set_params(params)
779+
call this_layer % get_params_ptr(w_ptr, b_ptr)
780+
781+
w_ptr = params(1:size(w_ptr))
782+
b_ptr = params(size(w_ptr)+1:)
762783

763784
type is (self_attention_layer)
764785
call this_layer % set_params(params)
765786
type is (embedding_layer)
766787
call this_layer % set_params(params)
767788

768789
type is (layernorm_layer)
769-
call this_layer % set_params(params)
790+
call this_layer % get_params_ptr(w_ptr, b_ptr)
791+
792+
w_ptr = params(1:size(w_ptr))
793+
b_ptr = params(size(w_ptr)+1:)
770794

771795
type is (maxpool2d_layer)
772796
! No parameters to set.

src/nf/nf_layernorm.f90

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@ module nf_layernorm_layer
3737
procedure :: backward
3838
procedure :: init
3939
procedure :: get_num_params
40-
procedure :: get_params
4140
procedure :: get_params_ptr
4241
procedure :: get_gradients
4342
procedure :: get_gradients_ptr
44-
procedure :: set_params
4543
end type layernorm_layer
4644

4745
interface layernorm_layer
@@ -74,12 +72,6 @@ pure module function get_num_params(self) result(num_params)
7472
end function get_num_params
7573

7674

77-
module function get_params(self) result(params)
78-
class(layernorm_layer), intent(in), target :: self
79-
real, allocatable :: params(:)
80-
end function get_params
81-
82-
8375
module subroutine get_params_ptr(self, g_ptr, b_ptr)
8476
class(layernorm_layer), intent(in), target :: self
8577
real, pointer, intent(out) :: g_ptr(:), b_ptr(:)
@@ -98,9 +90,5 @@ module subroutine get_gradients_ptr(self, dg_ptr, db_ptr)
9890
end subroutine get_gradients_ptr
9991

10092

101-
module subroutine set_params(self, params)
102-
class(layernorm_layer), intent(in out) :: self
103-
real, intent(in), target :: params(:)
104-
end subroutine set_params
10593
end interface
10694
end module nf_layernorm_layer

src/nf/nf_layernorm_submodule.f90

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,6 @@ pure module function get_num_params(self) result(num_params)
108108

109109
end function get_num_params
110110

111-
112-
module function get_params(self) result(params)
113-
class(layernorm_layer), intent(in), target :: self
114-
real, allocatable :: params(:)
115-
params = [self % gamma, self % beta]
116-
end function get_params
117-
118-
119111
module subroutine get_params_ptr(self, g_ptr, b_ptr)
120112
class(layernorm_layer), intent(in), target :: self
121113
real, pointer, intent(out) :: g_ptr(:), b_ptr(:)
@@ -137,19 +129,5 @@ module subroutine get_gradients_ptr(self, dg_ptr, db_ptr)
137129
dg_ptr => self % d_gamma
138130
db_ptr => self % d_beta
139131
end subroutine get_gradients_ptr
140-
141-
142-
module subroutine set_params(self, params)
143-
class(layernorm_layer), intent(in out) :: self
144-
real, intent(in), target :: params(:)
145-
146-
! check if the number of parameters is correct
147-
if (size(params) /= self % get_num_params()) then
148-
error stop 'Error: number of parameters does not match'
149-
end if
150-
151-
self % gamma = params(1: self % model_dimension)
152-
self % beta = params(self % model_dimension + 1: 2 * self % model_dimension)
153-
154-
end subroutine set_params
132+
155133
end submodule nf_layernorm_layer_submodule

src/nf/nf_linear2d_layer.f90

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ module nf_linear2d_layer
2424
procedure :: forward
2525
procedure :: init
2626
procedure :: get_num_params
27-
procedure :: get_params
2827
procedure :: get_params_ptr
2928
procedure :: get_gradients
3029
procedure :: get_gradients_ptr
31-
procedure :: set_params
3230

3331
end type linear2d_layer
3432

@@ -61,11 +59,6 @@ pure module function get_num_params(self) result(num_params)
6159
integer :: num_params
6260
end function get_num_params
6361

64-
module function get_params(self) result(params)
65-
class(linear2d_layer), intent(in), target :: self
66-
real, allocatable :: params(:)
67-
end function get_params
68-
6962
module subroutine get_params_ptr(self, w_ptr, b_ptr)
7063
class(linear2d_layer), intent(in), target :: self
7164
real, pointer, intent(out) :: w_ptr(:), b_ptr(:)
@@ -81,9 +74,5 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
8174
real, pointer, intent(out) :: dw_ptr(:), db_ptr(:)
8275
end subroutine get_gradients_ptr
8376

84-
module subroutine set_params(self, params)
85-
class(linear2d_layer), intent(in out) :: self
86-
real, intent(in), target :: params(:)
87-
end subroutine set_params
8877
end interface
8978
end module nf_linear2d_layer

0 commit comments

Comments
 (0)