Skip to content

Commit eb70deb

Browse files
committed
Read kernel and bias data for Conv2D layer from Keras; tranpose values in the HDF5 reader
1 parent 87ebc11 commit eb70deb

File tree

4 files changed

+124
-6
lines changed

4 files changed

+124
-6
lines changed

src/nf/io/nf_io_hdf5.f90

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ module subroutine get_hdf5_dataset_real32_2d(filename, object_name, values)
4545
!! Array to store the dataset values into
4646
end subroutine get_hdf5_dataset_real32_2d
4747

48+
module subroutine get_hdf5_dataset_real32_4d(filename, object_name, values)
49+
!! Read a 4-d real32 array from an HDF5 dataset.
50+
character(*), intent(in) :: filename
51+
!! HDF5 file name
52+
character(*), intent(in) :: object_name
53+
!! Object (dataset) name
54+
real(real32), allocatable, intent(in out) :: values(:,:,:,:)
55+
!! Array to store the dataset values into
56+
end subroutine get_hdf5_dataset_real32_4d
57+
4858
end interface get_hdf5_dataset
4959

5060
end module nf_io_hdf5

src/nf/io/nf_io_hdf5_submodule.f90

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,53 @@ module subroutine get_hdf5_dataset_real32_2d(filename, object_name, values)
9999
call f % read(object_name, values)
100100
call f % close()
101101

102+
! Transpose the array to get from C to Fortran order
103+
values = transpose(values)
104+
102105
end subroutine get_hdf5_dataset_real32_2d
103106

107+
108+
module subroutine get_hdf5_dataset_real32_4d(filename, object_name, values)
109+
110+
character(*), intent(in) :: filename
111+
character(*), intent(in) :: object_name
112+
real(real32), allocatable, intent(in out) :: values(:,:,:,:)
113+
114+
type(hdf5_file) :: f
115+
integer(int64), allocatable :: dims(:)
116+
117+
call f % open(filename, 'r')
118+
call f % shape(object_name, dims)
119+
120+
! If values is already allocated, re-allocate only if incorrect shape
121+
if (allocated(values)) then
122+
if (.not. all(shape(values) == dims)) then
123+
deallocate(values)
124+
allocate(values(dims(1), dims(2), dims(3), dims(4)))
125+
end if
126+
else
127+
allocate(values(dims(1), dims(2), dims(3), dims(4)))
128+
end if
129+
130+
call f % read(object_name, values)
131+
call f % close()
132+
133+
! Transpose the array to get from C to Fortran order
134+
values = reverse_dim_order(values)
135+
136+
end subroutine get_hdf5_dataset_real32_4d
137+
138+
139+
pure function reverse_dim_order(x) result(res)
140+
real, intent(in) :: x(:,:,:,:)
141+
real, allocatable :: res(:,:,:,:)
142+
integer :: dims(4)
143+
integer :: i, j, k, l
144+
dims = shape(x)
145+
allocate(res(dims(4), dims(3), dims(2), dims(1)))
146+
do concurrent(i = 1:dims(1), j = 1:dims(2), k = 1:dims(3), l = 1:dims(4))
147+
res(l,k,j,i) = x(i,j,k,l)
148+
end do
149+
end function reverse_dim_order
150+
104151
end submodule nf_io_hdf5_submodule

src/nf/nf_network_submodule.f90

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
submodule(nf_network) nf_network_submodule
22

3+
use nf_conv2d_layer, only: conv2d_layer
34
use nf_dense_layer, only: dense_layer
45
use nf_flatten_layer, only: flatten_layer
56
use nf_input1d_layer, only: input1d_layer
67
use nf_input3d_layer, only: input3d_layer
8+
use nf_maxpool2d_layer, only: maxpool2d_layer
79
use nf_io_hdf5, only: get_hdf5_dataset
810
use nf_keras, only: get_keras_h5_layers, keras_layer
911
use nf_layer, only: layer
@@ -131,6 +133,17 @@ module function network_from_keras(filename) result(res)
131133

132134
select type(this_layer => res % layers(n) % p)
133135

136+
type is(conv2d_layer)
137+
! Read biases from file
138+
object_name = '/model_weights/' // layer_name // '/' &
139+
// layer_name // '/bias:0'
140+
call get_hdf5_dataset(filename, object_name, this_layer % biases)
141+
142+
! Read weights from file
143+
object_name = '/model_weights/' // layer_name // '/' &
144+
// layer_name // '/kernel:0'
145+
call get_hdf5_dataset(filename, object_name, this_layer % kernel)
146+
134147
type is(dense_layer)
135148

136149
! Read biases from file
@@ -143,12 +156,13 @@ module function network_from_keras(filename) result(res)
143156
// layer_name // '/kernel:0'
144157
call get_hdf5_dataset(filename, object_name, this_layer % weights)
145158

146-
! TODO Multidimensional arrays are stored in HDF5 in C-order.
147-
! TODO Here we transpose the array to get to the Fortran order.
148-
! TODO There may be a way to do this without re-allocating.
149-
! TODO It probably doesn't matter much since we do this once.
150-
! TODO Figure it out later.
151-
this_layer % weights = transpose(this_layer % weights)
159+
type is(flatten_layer)
160+
! Nothing to do
161+
continue
162+
163+
type is(maxpool2d_layer)
164+
! Nothing to do
165+
continue
152166

153167
class default
154168
error stop 'Internal error in network_from_keras(); ' &

test/test_cnn_from_keras.f90

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,51 @@ program test_cnn_from_keras
1616

1717
net = network(test_data_path)
1818

19+
block
20+
21+
use nf, only: load_mnist, label_digits
22+
23+
real, allocatable :: training_images(:,:), training_labels(:)
24+
real, allocatable :: validation_images(:,:), validation_labels(:)
25+
real, allocatable :: testing_images(:,:), testing_labels(:)
26+
real :: acc
27+
28+
call load_mnist(training_images, training_labels, &
29+
validation_images, validation_labels, &
30+
testing_images, testing_labels)
31+
32+
acc = accuracy(net, reshape(testing_images, shape=[1,28,28,10000]), label_digits(testing_labels))
33+
print *, acc
34+
35+
if (acc < 0.94) then
36+
write(stderr, '(a)') &
37+
'Pre-trained network accuracy should be > 0.94.. failed'
38+
ok = .false.
39+
end if
40+
41+
end block
42+
43+
if (ok) then
44+
print '(a)', 'test_cnn_from_keras: All tests passed.'
45+
else
46+
write(stderr, '(a)') &
47+
'test_cnn_from_keras: One or more tests failed.'
48+
stop 1
49+
end if
50+
51+
contains
52+
53+
real function accuracy(net, x, y)
54+
type(network), intent(in out) :: net
55+
real, intent(in) :: x(:,:,:,:), y(:,:)
56+
integer :: i, good
57+
good = 0
58+
do i = 1, size(x, dim=2)
59+
if (all(maxloc(net % output(x(:,:,:,i))) == maxloc(y(:,i)))) then
60+
good = good + 1
61+
end if
62+
end do
63+
accuracy = real(good) / size(x, dim=2)
64+
end function accuracy
65+
1966
end program test_cnn_from_keras

0 commit comments

Comments
 (0)