Skip to content

Commit 5592b9a

Browse files
committed
CNN from Keras test is passing
1 parent 2ff343e commit 5592b9a

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

test/test_cnn_from_keras.f90

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,21 @@ program test_cnn_from_keras
2323
real, allocatable :: training_images(:,:), training_labels(:)
2424
real, allocatable :: validation_images(:,:), validation_labels(:)
2525
real, allocatable :: testing_images(:,:), testing_labels(:)
26+
real, allocatable :: input_reshaped(:,:,:,:)
2627
real :: acc
2728

2829
call load_mnist(training_images, training_labels, &
2930
validation_images, validation_labels, &
3031
testing_images, testing_labels)
3132

32-
acc = accuracy(net, reshape(testing_images, shape=[1,28,28,10000]), label_digits(testing_labels))
33-
print *, acc
33+
! Use only the first 1000 images to make the test short
34+
input_reshaped = reshape(testing_images(:,:1000), shape=[1,28,28,1000])
3435

35-
if (acc < 0.94) then
36+
acc = accuracy(net, input_reshaped, label_digits(testing_labels(:1000)))
37+
38+
if (acc < 0.97) then
3639
write(stderr, '(a)') &
37-
'Pre-trained network accuracy should be > 0.94.. failed'
40+
'Pre-trained network accuracy should be > 0.97.. failed'
3841
ok = .false.
3942
end if
4043

@@ -55,12 +58,12 @@ real function accuracy(net, x, y)
5558
real, intent(in) :: x(:,:,:,:), y(:,:)
5659
integer :: i, good
5760
good = 0
58-
do i = 1, size(x, dim=2)
61+
do i = 1, size(x, dim=4)
5962
if (all(maxloc(net % output(x(:,:,:,i))) == maxloc(y(:,i)))) then
6063
good = good + 1
6164
end if
6265
end do
63-
accuracy = real(good) / size(x, dim=2)
66+
accuracy = real(good) / size(x, dim=4)
6467
end function accuracy
6568

6669
end program test_cnn_from_keras

0 commit comments

Comments
 (0)