Skip to content

Commit fe02beb

Browse files
committed
embedding_layer: make integer input generics
1 parent 73799bd commit fe02beb

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

src/nf/nf_network.f90

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,19 @@ module nf_network
3232

3333
procedure, private :: evaluate_batch_1d
3434
procedure, private :: forward_1d
35+
procedure, private :: forward_1d_int
3536
procedure, private :: forward_2d
3637
procedure, private :: forward_3d
3738
procedure, private :: predict_1d
39+
procedure, private :: predict_1d_int
3840
procedure, private :: predict_2d
3941
procedure, private :: predict_3d
4042
procedure, private :: predict_batch_1d
4143
procedure, private :: predict_batch_3d
4244

4345
generic :: evaluate => evaluate_batch_1d
44-
generic :: forward => forward_1d, forward_2d, forward_3d
45-
generic :: predict => predict_1d, predict_2d, predict_3d
46+
generic :: forward => forward_1d, forward_1d_int, forward_2d, forward_3d
47+
generic :: predict => predict_1d, predict_1d_int, predict_2d, predict_3d
4648
generic :: predict_batch => predict_batch_1d, predict_batch_3d
4749

4850
end type network
@@ -95,6 +97,12 @@ module subroutine forward_1d(self, input)
9597
!! 1-d input data
9698
end subroutine forward_1d
9799

100+
module subroutine forward_1d_int(self, input)
101+
!! Same as `forward_1d` except `integer`
102+
class(network), intent(in out) :: self
103+
integer, intent(in) :: input(:)
104+
end subroutine forward_1d_int
105+
98106
module subroutine forward_2d(self, input)
99107
!! Apply a forward pass through the network.
100108
!!
@@ -137,6 +145,13 @@ module function predict_1d(self, input) result(res)
137145
!! Output of the network
138146
end function predict_1d
139147

148+
module function predict_1d_int(self, input) result(res)
149+
!! Same as `predict_1d` except `integer`
150+
class(network), intent(in out) :: self
151+
integer, intent(in) :: input(:)
152+
real, allocatable :: res(:)
153+
end function predict_1d_int
154+
140155
module function predict_2d(self, input) result(res)
141156
!! Return the output of the network given the input 1-d array.
142157
class(network), intent(in out) :: self

src/nf/nf_network_submodule.f90

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,6 @@ module subroutine forward_1d(self, input)
211211
select type(input_layer => self % layers(1) % p)
212212
type is(input1d_layer)
213213
call input_layer % set(input)
214-
type is(embedding_layer)
215-
call input_layer % forward(nint(input))
216214
end select
217215

218216
do n = 2, size(self % layers)
@@ -221,6 +219,21 @@ module subroutine forward_1d(self, input)
221219

222220
end subroutine forward_1d
223221

222+
module subroutine forward_1d_int(self, input)
223+
class(network), intent(in out) :: self
224+
integer, intent(in) :: input(:)
225+
integer :: n
226+
227+
select type(input_layer => self % layers(1) % p)
228+
type is(embedding_layer)
229+
call input_layer % forward(input)
230+
end select
231+
232+
do n = 2, size(self % layers)
233+
call self % layers(n) % forward(self % layers(n - 1))
234+
end do
235+
236+
end subroutine forward_1d_int
224237

225238
module subroutine forward_2d(self, input)
226239
class(network), intent(in out) :: self
@@ -285,6 +298,31 @@ module function predict_1d(self, input) result(res)
285298

286299
end function predict_1d
287300

301+
module function predict_1d_int(self, input) result(res)
302+
class(network), intent(in out) :: self
303+
integer, intent(in) :: input(:)
304+
real, allocatable :: res(:)
305+
integer :: n, num_layers
306+
307+
num_layers = size(self % layers)
308+
309+
call self % set_training_mode(.false.)
310+
call self % forward(input)
311+
call self % set_training_mode(.true.)
312+
313+
select type(output_layer => self % layers(num_layers) % p)
314+
type is(dense_layer)
315+
res = output_layer % output
316+
type is(dropout_layer)
317+
res = output_layer % output
318+
type is(flatten_layer)
319+
res = output_layer % output
320+
class default
321+
error stop 'network % output not implemented for ' // &
322+
trim(self % layers(num_layers) % name) // ' layer'
323+
end select
324+
325+
end function predict_1d_int
288326

289327
module function predict_2d(self, input) result(res)
290328
class(network), intent(in out) :: self

0 commit comments

Comments
 (0)