Skip to content

Commit 148892b

Browse files
committed
Add CNN from Keras example; rename the dense from Keras example
1 parent 0941ebb commit 148892b

File tree

3 files changed

+78
-7
lines changed

3 files changed

+78
-7
lines changed

example/CMakeLists.txt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
foreach(execid cnn mnist mnist_from_keras simple sine)
1+
foreach(execid
2+
cnn
3+
cnn_from_keras
4+
dense_from_keras
5+
mnist
6+
simple
7+
sine
8+
)
29
add_executable(${execid} ${execid}.f90)
3-
target_link_libraries(${execid} PRIVATE neural h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})
10+
target_link_libraries(${execid} PRIVATE
11+
neural
12+
h5fortran::h5fortran
13+
jsonfortran::jsonfortran
14+
${LIBS}
15+
)
416
endforeach()

example/cnn_from_keras.f90

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
program cnn_from_keras
2+
3+
! This example demonstrates loading a convolutional model
4+
! pre-trained on the MNIST dataset from a Keras HDF5
5+
! file and running an inferrence on the testing dataset.
6+
7+
use nf, only: network, label_digits, load_mnist
8+
use nf_datasets, only: download_and_unpack, keras_cnn_mnist_url
9+
10+
implicit none
11+
12+
type(network) :: net
13+
real, allocatable :: training_images(:,:), training_labels(:)
14+
real, allocatable :: validation_images(:,:), validation_labels(:)
15+
real, allocatable :: testing_images(:,:), testing_labels(:)
16+
character(*), parameter :: keras_cnn_path = 'keras_cnn_mnist.h5'
17+
logical :: file_exists
18+
real :: acc
19+
20+
inquire(file=keras_cnn_path, exist=file_exists)
21+
if (.not. file_exists) call download_and_unpack(keras_cnn_mnist_url)
22+
23+
call load_mnist(training_images, training_labels, &
24+
validation_images, validation_labels, &
25+
testing_images, testing_labels)
26+
27+
print '("Loading a pre-trained CNN model from Keras")'
28+
print '(60("="))'
29+
30+
net = network(keras_cnn_path)
31+
32+
call net % print_info()
33+
34+
if (this_image() == 1) then
35+
acc = accuracy( &
36+
net, &
37+
reshape(testing_images(:,:), shape=[1,28,28,size(testing_images,2)]), &
38+
label_digits(testing_labels) &
39+
)
40+
print '(a,f5.2,a)', 'Accuracy: ', acc * 100, ' %'
41+
end if
42+
43+
contains
44+
45+
real function accuracy(net, x, y)
46+
type(network), intent(in out) :: net
47+
real, intent(in) :: x(:,:,:,:), y(:,:)
48+
integer :: i, good
49+
good = 0
50+
do i = 1, size(x, dim=4)
51+
if (all(maxloc(net % output(x(:,:,:,i))) == maxloc(y(:,i)))) then
52+
good = good + 1
53+
end if
54+
end do
55+
accuracy = real(good) / size(x, dim=4)
56+
end function accuracy
57+
58+
end program cnn_from_keras

example/mnist_from_keras.f90 renamed to example/dense_from_keras.f90

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
program mnist_from_keras
1+
program dense_from_keras
22

3-
! This example demonstrates loading a pre-trained MNIST model from Keras
4-
! from an HDF5 file and running an inferrence on the testing dataset.
3+
! This example demonstrates loading a dense model
4+
! pre-trained on the MNIST dataset from a Keras HDF5
5+
! file and running an inferrence on the testing dataset.
56

67
use nf, only: network, label_digits, load_mnist
78
use nf_datasets, only: download_and_unpack, keras_dense_mnist_url
@@ -22,7 +23,7 @@ program mnist_from_keras
2223
validation_images, validation_labels, &
2324
testing_images, testing_labels)
2425

25-
print '("Loading a pre-trained MNIST model from Keras")'
26+
print '("Loading a pre-trained dense model from Keras")'
2627
print '(60("="))'
2728

2829
net = network(keras_dense_path)
@@ -48,4 +49,4 @@ real function accuracy(net, x, y)
4849
accuracy = real(good) / size(x, dim=2)
4950
end function accuracy
5051

51-
end program mnist_from_keras
52+
end program dense_from_keras

0 commit comments

Comments
 (0)