Skip to content

Commit 925e306

Browse files
committed
Allow as 3-d output of networks; assign activation function to conv2d layers from Keras
1 parent 5592b9a commit 925e306

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/nf/nf_network_submodule.f90

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ module function network_from_keras(filename) result(res)
7979
layers(n) = conv2d( &
8080
keras_layers(n) % filters, &
8181
!FIXME add support for non-square kernel
82-
keras_layers(n) % kernel_size(1) &
82+
keras_layers(n) % kernel_size(1), &
83+
keras_layers(n) % activation &
8384
)
8485

8586
case('Dense')
@@ -162,7 +163,7 @@ module function network_from_keras(filename) result(res)
162163

163164
type is(maxpool2d_layer)
164165
! Nothing to do
165-
continue
166+
continue
166167

167168
class default
168169
error stop 'Internal error in network_from_keras(); ' &
@@ -258,6 +259,8 @@ module function output_1d(self, input) result(res)
258259
res = output_layer % output
259260
type is(flatten_layer)
260261
res = output_layer % output
262+
class default
263+
error stop 'network % output not implemented for this output layer'
261264
end select
262265

263266
end function output_1d
@@ -274,10 +277,15 @@ module function output_3d(self, input) result(res)
274277
call self % forward(input)
275278

276279
select type(output_layer => self % layers(num_layers) % p)
280+
type is(conv2d_layer)
281+
!FIXME flatten the result for now; find a better solution
282+
res = pack(output_layer % output, .true.)
277283
type is(dense_layer)
278284
res = output_layer % output
279285
type is(flatten_layer)
280286
res = output_layer % output
287+
class default
288+
error stop 'network % output not implemented for this output layer'
281289
end select
282290

283291
end function output_3d

0 commit comments

Comments
 (0)