Skip to content

Commit ff2321d

Browse files
committed
Enable Conv2D, Flatten, and MaxPooling2D in the network constructor from Keras
1 parent d943b0c commit ff2321d

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

src/nf/nf_network_submodule.f90

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use nf_io_hdf5, only: get_hdf5_dataset
88
use nf_keras, only: get_keras_h5_layers, keras_layer
99
use nf_layer, only: layer
10-
use nf_layer_constructors, only: dense, input
10+
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d
1111
use nf_loss, only: quadratic_derivative
1212
use nf_optimizers, only: sgd
1313
use nf_parallel, only: tile_indices
@@ -68,6 +68,27 @@ module function network_from_keras(filename) result(res)
6868

6969
select case(keras_layers(n) % class)
7070

71+
case('Conv2D')
72+
73+
if (keras_layers(n) % kernel_size(1) &
74+
/= keras_layers(n) % kernel_size(2)) &
75+
error stop 'Non-square kernel in conv2d layer not supported.'
76+
77+
layers(n) = conv2d( &
78+
keras_layers(n) % filters, &
79+
!FIXME add support for non-square kernel
80+
keras_layers(n) % kernel_size(1) &
81+
)
82+
83+
case('Dense')
84+
layers(n) = dense( &
85+
keras_layers(n) % units(1), &
86+
keras_layers(n) % activation &
87+
)
88+
89+
case('Flatten')
90+
layers(n) = flatten()
91+
7192
case('InputLayer')
7293
if (size(keras_layers(n) % units) == 1) then
7394
! input1d
@@ -77,10 +98,20 @@ module function network_from_keras(filename) result(res)
7798
layers(n) = input(keras_layers(n) % units)
7899
end if
79100

80-
case('Dense')
81-
layers(n) = dense( &
82-
keras_layers(n) % units(1), &
83-
keras_layers(n) % activation &
101+
case('MaxPooling2D')
102+
103+
if (keras_layers(n) % pool_size(1) &
104+
/= keras_layers(n) % pool_size(2)) &
105+
error stop 'Non-square pool in maxpool2d layer not supported.'
106+
107+
if (keras_layers(n) % strides(1) &
108+
/= keras_layers(n) % strides(2)) &
109+
error stop 'Unequal strides in maxpool2d layer are not supported.'
110+
111+
layers(n) = maxpool2d( &
112+
!FIXME add support for non-square pool and stride
113+
keras_layers(n) % pool_size(1), &
114+
keras_layers(n) % strides(1) &
84115
)
85116

86117
case default

0 commit comments

Comments
 (0)