Skip to content

Commit 98d12b8

Browse files
committed
feat: Implementing reset state for RNN
1 parent 73487bd commit 98d12b8

6 files changed

+53
-0
lines changed

src/nf/nf_layer.f90

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ module nf_layer
3232
procedure :: set_params
3333
procedure :: init
3434
procedure :: print_info
35+
procedure :: reset
3536

3637
! Specific subroutines for different array ranks
3738
procedure, private :: backward_1d
@@ -153,6 +154,10 @@ module subroutine set_params(self, params)
153154
!! Parameters of this layer
154155
end subroutine set_params
155156

157+
module subroutine reset(self)
158+
class(layer), intent(in out) :: self
159+
end subroutine reset
160+
156161
end interface
157162

158163
end module nf_layer

src/nf/nf_layer_submodule.f90

+10
Original file line numberDiff line numberDiff line change
@@ -442,4 +442,14 @@ module subroutine set_params(self, params)
442442

443443
end subroutine set_params
444444

445+
module subroutine reset(self)
446+
class(layer), intent(in out) :: self
447+
448+
select type (this_layer => self % p)
449+
type is (rnn_layer)
450+
call this_layer % reset()
451+
end select
452+
453+
end subroutine reset
454+
445455
end submodule nf_layer_submodule

src/nf/nf_network.f90

+8
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ module nf_network
2323
procedure :: get_params
2424
procedure :: print_info
2525
procedure :: set_params
26+
procedure :: reset
2627
procedure :: train
2728
procedure :: update
2829

@@ -223,6 +224,13 @@ module subroutine update(self, optimizer, batch_size)
223224
!! Set to `size(input_data, dim=2)` for a batch gradient descent.
224225
end subroutine update
225226

227+
module subroutine reset(self)
228+
!! Reset network state
229+
!!
230+
!! Currently only affect RNN layer type
231+
class(network), intent(in out) :: self
232+
end subroutine reset
233+
226234
end interface
227235

228236
end module nf_network

src/nf/nf_network_submodule.f90

+14
Original file line numberDiff line numberDiff line change
@@ -681,4 +681,18 @@ module subroutine update(self, optimizer, batch_size)
681681

682682
end subroutine update
683683

684+
module subroutine reset(self)
685+
class(network), intent(in out) :: self
686+
integer :: n, num_layers
687+
688+
num_layers = size(self % layers)
689+
do n = 2, num_layers
690+
select type(this_layer => self % layers(n) % p)
691+
type is(rnn_layer)
692+
call self % layers(n) % reset()
693+
end select
694+
end do
695+
696+
end subroutine reset
697+
684698
end submodule nf_network_submodule

src/nf/nf_rnn_layer.f90

+9
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ module nf_rnn_layer
4040
procedure :: get_params
4141
procedure :: init
4242
procedure :: set_params
43+
procedure :: reset
4344

4445
end type rnn_layer
4546

@@ -126,6 +127,14 @@ module subroutine init(self, input_shape)
126127
!! Shape of the input layer
127128
end subroutine init
128129

130+
module subroutine reset(self)
131+
!! Reset layer state
132+
!!
133+
!! Currently reset state to zero but might be worth reconsidering it
134+
!! in the future.
135+
class(rnn_layer), intent(in out) :: self
136+
end subroutine reset
137+
129138
end interface
130139

131140
end module nf_rnn_layer

src/nf/nf_rnn_layer_submodule.f90

+7
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,11 @@ module subroutine init(self, input_shape)
166166

167167
end subroutine init
168168

169+
module subroutine reset(self)
170+
class(rnn_layer), intent(in out) :: self
171+
172+
self % state = 0
173+
174+
end subroutine reset
175+
169176
end submodule nf_rnn_layer_submodule

0 commit comments

Comments
 (0)