Skip to content

Commit 1c968ce

Browse files
authored
Apply optimizer to model weights without data copy (#222)
* WIP optimizer refactor w/ pointers
* WIP optimizer optimization
* Sending the data to the optimizer without a copy now works for dense layers
* Get weights and weight gradients as 1d
* get_params_ptr and get_gradients_ptr for conv1d, conv2d, and locally_connected1d
* Define an optimizer instance per layer to preserve memory across layers
* Initialization of the network-wide optimizer is no longer needed now that we switched to per-layer optimizer instances
* Bookkeeping for velocity, rms_gradient, etc.; optimizer tests now pass
* Update optimizer flow for linear2d
* Update optimizer flow for layernorm
* Previous bookkeeping for successive calls to optim % minimize() assumed 2 calls per batch; this is now generalized to allow any number of calls until size(params) is exhausted
* Remove get_gradients from network, layer, dense, conv1d, conv2d
* Remove optimizer as a component of the network class
1 parent 402b84a commit 1c968ce

18 files changed

+325
-213
lines changed

src/nf/nf_conv1d_layer.f90

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ module nf_conv1d_layer
3131

3232
procedure :: forward
3333
procedure :: backward
34-
procedure :: get_gradients
34+
procedure :: get_gradients_ptr
3535
procedure :: get_num_params
3636
procedure :: get_params
37+
procedure :: get_params_ptr
3738
procedure :: init
3839
procedure :: set_params
3940

@@ -97,14 +98,25 @@ module function get_params(self) result(params)
9798
!! Parameters to get
9899
end function get_params
99100

100-
module function get_gradients(self) result(gradients)
101-
!! Return the gradients of this layer.
102-
!! The gradients are ordered as weights first, biases second.
101+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
102+
!! Return pointers to the parameters (weights and biases) of this layer.
103103
class(conv1d_layer), intent(in), target :: self
104104
!! A `conv1d_layer` instance
105-
real, allocatable :: gradients(:)
106-
!! Gradients to get
107-
end function get_gradients
105+
real, pointer, intent(out) :: w_ptr(:)
106+
!! Pointer to the kernel weights (flattened)
107+
real, pointer, intent(out) :: b_ptr(:)
108+
!! Pointer to the biases
109+
end subroutine get_params_ptr
110+
111+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
112+
!! Return pointers to the gradients of this layer.
113+
class(conv1d_layer), intent(in), target :: self
114+
!! A `conv1d_layer` instance
115+
real, pointer, intent(out) :: dw_ptr(:)
116+
!! Pointer to the kernel weight gradients (flattened)
117+
real, pointer, intent(out) :: db_ptr(:)
118+
!! Pointer to the bias gradients
119+
end subroutine get_gradients_ptr
108120

109121
module subroutine set_params(self, params)
110122
!! Set the parameters of the layer.

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,21 @@ module function get_params(self) result(params)
152152
params = [ w_, self % biases]
153153
end function get_params
154154

155-
module function get_gradients(self) result(gradients)
155+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
156156
class(conv1d_layer), intent(in), target :: self
157-
real, allocatable :: gradients(:)
158-
real, pointer :: dw_(:) => null()
159-
dw_(1:size(self % dw)) => self % dw
160-
gradients = [ dw_, self % db ]
161-
end function get_gradients
157+
real, pointer, intent(out) :: w_ptr(:)
158+
real, pointer, intent(out) :: b_ptr(:)
159+
w_ptr(1:size(self % kernel)) => self % kernel
160+
b_ptr => self % biases
161+
end subroutine get_params_ptr
162+
163+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
164+
class(conv1d_layer), intent(in), target :: self
165+
real, pointer, intent(out) :: dw_ptr(:)
166+
real, pointer, intent(out) :: db_ptr(:)
167+
dw_ptr(1:size(self % dw)) => self % dw
168+
db_ptr => self % db
169+
end subroutine get_gradients_ptr
162170

163171
module subroutine set_params(self, params)
164172
class(conv1d_layer), intent(in out) :: self

src/nf/nf_conv2d_layer.f90

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ module nf_conv2d_layer
3232

3333
procedure :: forward
3434
procedure :: backward
35-
procedure :: get_gradients
35+
procedure :: get_gradients_ptr
3636
procedure :: get_num_params
3737
procedure :: get_params
38+
procedure :: get_params_ptr
3839
procedure :: init
3940
procedure :: set_params
4041

@@ -98,14 +99,25 @@ module function get_params(self) result(params)
9899
!! Parameters to get
99100
end function get_params
100101

101-
module function get_gradients(self) result(gradients)
102-
!! Return the gradients of this layer.
103-
!! The gradients are ordered as weights first, biases second.
102+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
103+
!! Return pointers to the parameters (weights and biases) of this layer.
104104
class(conv2d_layer), intent(in), target :: self
105105
!! A `conv2d_layer` instance
106-
real, allocatable :: gradients(:)
107-
!! Gradients to get
108-
end function get_gradients
106+
real, pointer, intent(out) :: w_ptr(:)
107+
!! Pointer to the kernel weights (flattened)
108+
real, pointer, intent(out) :: b_ptr(:)
109+
!! Pointer to the biases
110+
end subroutine get_params_ptr
111+
112+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
113+
!! Return pointers to the gradients of this layer.
114+
class(conv2d_layer), intent(in), target :: self
115+
!! A `conv2d_layer` instance
116+
real, pointer, intent(out) :: dw_ptr(:)
117+
!! Pointer to the kernel weight gradients (flattened)
118+
real, pointer, intent(out) :: db_ptr(:)
119+
!! Pointer to the bias gradients
120+
end subroutine get_gradients_ptr
109121

110122
module subroutine set_params(self, params)
111123
!! Set the parameters of the layer.

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,21 +204,23 @@ module function get_params(self) result(params)
204204

205205
end function get_params
206206

207-
208-
module function get_gradients(self) result(gradients)
207+
208+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
209209
class(conv2d_layer), intent(in), target :: self
210-
real, allocatable :: gradients(:)
211-
212-
real, pointer :: dw_(:) => null()
210+
real, pointer, intent(out) :: w_ptr(:)
211+
real, pointer, intent(out) :: b_ptr(:)
212+
w_ptr(1:size(self % kernel)) => self % kernel
213+
b_ptr => self % biases
214+
end subroutine get_params_ptr
213215

214-
dw_(1:size(self % dw)) => self % dw
215216

216-
gradients = [ &
217-
dw_, &
218-
self % db &
219-
]
220-
221-
end function get_gradients
217+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
218+
class(conv2d_layer), intent(in), target :: self
219+
real, pointer, intent(out) :: dw_ptr(:)
220+
real, pointer, intent(out) :: db_ptr(:)
221+
dw_ptr(1:size(self % dw)) => self % dw
222+
db_ptr => self % db
223+
end subroutine get_gradients_ptr
222224

223225

224226
module subroutine set_params(self, params)

src/nf/nf_dense_layer.f90

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ module nf_dense_layer
3333

3434
procedure :: backward
3535
procedure :: forward
36-
procedure :: get_gradients
36+
procedure :: get_gradients_ptr
3737
procedure :: get_num_params
3838
procedure :: get_params
39+
procedure :: get_params_ptr
3940
procedure :: init
4041
procedure :: set_params
4142

@@ -96,14 +97,17 @@ module function get_params(self) result(params)
9697
!! Parameters of this layer
9798
end function get_params
9899

99-
module function get_gradients(self) result(gradients)
100-
!! Return the gradients of this layer.
101-
!! The gradients are ordered as weights first, biases second.
100+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
102101
class(dense_layer), intent(in), target :: self
103-
!! Dense layer instance
104-
real, allocatable :: gradients(:)
105-
!! Gradients of this layer
106-
end function get_gradients
102+
real, pointer, intent(out) :: w_ptr(:)
103+
real, pointer, intent(out) :: b_ptr(:)
104+
end subroutine get_params_ptr
105+
106+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
107+
class(dense_layer), intent(in), target :: self
108+
real, pointer, intent(out) :: dw_ptr(:)
109+
real, pointer, intent(out) :: db_ptr(:)
110+
end subroutine get_gradients_ptr
107111

108112
module subroutine set_params(self, params)
109113
!! Set the parameters of this layer.

src/nf/nf_dense_layer_submodule.f90

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,22 @@ module function get_params(self) result(params)
7777
end function get_params
7878

7979

80-
module function get_gradients(self) result(gradients)
80+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
8181
class(dense_layer), intent(in), target :: self
82-
real, allocatable :: gradients(:)
82+
real, pointer, intent(out) :: w_ptr(:)
83+
real, pointer, intent(out) :: b_ptr(:)
84+
w_ptr(1:size(self % weights)) => self % weights
85+
b_ptr => self % biases
86+
end subroutine get_params_ptr
8387

84-
real, pointer :: dw_(:) => null()
8588

86-
dw_(1:size(self % dw)) => self % dw
87-
88-
gradients = [ &
89-
dw_, &
90-
self % db &
91-
]
92-
93-
end function get_gradients
89+
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
90+
class(dense_layer), intent(in), target :: self
91+
real, pointer, intent(out) :: dw_ptr(:)
92+
real, pointer, intent(out) :: db_ptr(:)
93+
dw_ptr(1:size(self % dw)) => self % dw
94+
db_ptr => self % db
95+
end subroutine get_gradients_ptr
9496

9597

9698
module subroutine set_params(self, params)

src/nf/nf_layer.f90

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ module nf_layer
2222
integer, allocatable :: layer_shape(:)
2323
integer, allocatable :: input_layer_shape(:)
2424
logical :: initialized = .false.
25+
class(optimizer_base_type), allocatable :: optimizer
2526

2627
contains
2728

2829
procedure :: forward
2930
procedure :: get_num_params
3031
procedure :: get_params
31-
procedure :: get_gradients
3232
procedure :: set_params
3333
procedure :: init
3434
procedure :: print_info
@@ -160,14 +160,6 @@ module function get_params(self) result(params)
160160
!! Parameters of this layer
161161
end function get_params
162162

163-
module function get_gradients(self) result(gradients)
164-
!! Returns the gradients of this layer.
165-
class(layer), intent(in) :: self
166-
!! Layer instance
167-
real, allocatable :: gradients(:)
168-
!! Gradients of this layer
169-
end function get_gradients
170-
171163
module subroutine set_params(self, params)
172164
!! Set the parameters of this layer.
173165
class(layer), intent(in out) :: self

src/nf/nf_layer_submodule.f90

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -682,50 +682,6 @@ module function get_params(self) result(params)
682682

683683
end function get_params
684684

685-
module function get_gradients(self) result(gradients)
686-
class(layer), intent(in) :: self
687-
real, allocatable :: gradients(:)
688-
689-
select type (this_layer => self % p)
690-
type is (input1d_layer)
691-
! No gradients to get.
692-
type is (input2d_layer)
693-
! No gradients to get.
694-
type is (input3d_layer)
695-
! No gradients to get.
696-
type is (dense_layer)
697-
gradients = this_layer % get_gradients()
698-
type is (dropout_layer)
699-
! No gradients to get.
700-
type is (conv1d_layer)
701-
gradients = this_layer % get_gradients()
702-
type is (conv2d_layer)
703-
gradients = this_layer % get_gradients()
704-
type is (locally_connected1d_layer)
705-
gradients = this_layer % get_gradients()
706-
type is (maxpool1d_layer)
707-
! No gradients to get.
708-
type is (maxpool2d_layer)
709-
! No gradients to get.
710-
type is (flatten_layer)
711-
! No gradients to get.
712-
type is (reshape2d_layer)
713-
! No parameters to get.
714-
type is (reshape3d_layer)
715-
! No gradients to get.
716-
type is (linear2d_layer)
717-
gradients = this_layer % get_gradients()
718-
type is (self_attention_layer)
719-
gradients = this_layer % get_gradients()
720-
type is (embedding_layer)
721-
gradients = this_layer % get_gradients()
722-
type is (layernorm_layer)
723-
gradients = this_layer % get_gradients()
724-
class default
725-
error stop 'Unknown layer type.'
726-
end select
727-
728-
end function get_gradients
729685

730686
module subroutine set_params(self, params)
731687
class(layer), intent(in out) :: self

src/nf/nf_layernorm.f90

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ module nf_layernorm_layer
3838
procedure :: init
3939
procedure :: get_num_params
4040
procedure :: get_params
41+
procedure :: get_params_ptr
4142
procedure :: get_gradients
43+
procedure :: get_gradients_ptr
4244
procedure :: set_params
4345
end type layernorm_layer
4446

@@ -78,12 +80,24 @@ module function get_params(self) result(params)
7880
end function get_params
7981

8082

83+
module subroutine get_params_ptr(self, g_ptr, b_ptr)
84+
class(layernorm_layer), intent(in), target :: self
85+
real, pointer, intent(out) :: g_ptr(:), b_ptr(:)
86+
end subroutine get_params_ptr
87+
88+
8189
module function get_gradients(self) result(gradients)
8290
class(layernorm_layer), intent(in), target :: self
8391
real, allocatable :: gradients(:)
8492
end function get_gradients
8593

8694

95+
module subroutine get_gradients_ptr(self, dg_ptr, db_ptr)
96+
class(layernorm_layer), intent(in), target :: self
97+
real, pointer, intent(out) :: dg_ptr(:), db_ptr(:)
98+
end subroutine get_gradients_ptr
99+
100+
87101
module subroutine set_params(self, params)
88102
class(layernorm_layer), intent(in out) :: self
89103
real, intent(in), target :: params(:)

src/nf/nf_layernorm_submodule.f90

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,31 @@ end function get_num_params
112112
module function get_params(self) result(params)
113113
class(layernorm_layer), intent(in), target :: self
114114
real, allocatable :: params(:)
115+
params = [self % gamma, self % beta]
116+
end function get_params
115117

116-
params = [ &
117-
self % gamma, &
118-
self % beta &
119-
]
120118

121-
end function get_params
119+
module subroutine get_params_ptr(self, g_ptr, b_ptr)
120+
class(layernorm_layer), intent(in), target :: self
121+
real, pointer, intent(out) :: g_ptr(:), b_ptr(:)
122+
g_ptr => self % gamma
123+
b_ptr => self % beta
124+
end subroutine get_params_ptr
122125

123126

124127
module function get_gradients(self) result(gradients)
125128
class(layernorm_layer), intent(in), target :: self
126129
real, allocatable :: gradients(:)
130+
gradients = [self % d_gamma, self % d_beta]
131+
end function get_gradients
127132

128-
gradients = [ &
129-
self % d_gamma, &
130-
self % d_beta &
131-
]
132133

133-
end function get_gradients
134+
module subroutine get_gradients_ptr(self, dg_ptr, db_ptr)
135+
class(layernorm_layer), intent(in), target :: self
136+
real, pointer, intent(out) :: dg_ptr(:), db_ptr(:)
137+
dg_ptr => self % d_gamma
138+
db_ptr => self % d_beta
139+
end subroutine get_gradients_ptr
134140

135141

136142
module subroutine set_params(self, params)

0 commit comments

Comments
 (0)