@@ -3,8 +3,6 @@ module mod_mnist
3
3
! ! Procedures to work with MNIST dataset, usable with data format
4
4
! ! as provided in this repo and not the original data format (idx).
5
5
6
- use iso_fortran_env, only: real32 ! ! TODO make MNIST work with arbitrary precision
7
- use mod_io, only: read_binary_file
8
6
use mod_kinds, only: ik, rk
9
7
10
8
implicit none
@@ -13,75 +11,33 @@ module mod_mnist
13
11
14
12
public :: label_digits, load_mnist, print_image
15
13
16
- contains
17
-
18
- pure function digits (x )
19
- ! ! Returns an array of 10 reals, with zeros everywhere
20
- ! ! and a one corresponding to the input number, for example:
21
- ! ! digits(0) = [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
22
- ! ! digits(1) = [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]
23
- ! ! digits(6) = [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]
24
- real (rk), intent (in ) :: x
25
- real (rk) :: digits (10 )
26
- digits = 0
27
- digits (int (x + 1 )) = 1
28
- end function digits
29
-
30
- pure function label_digits (labels ) result(res)
31
- ! ! Converts an array of MNIST labels into a form
32
- ! ! that can be input to the network_type instance.
33
- real (rk), intent (in ) :: labels(:)
34
- real (rk) :: res(10 , size (labels))
35
- integer (ik) :: i
36
- do i = 1 , size (labels)
37
- res(:,i) = digits (labels(i))
38
- end do
39
- end function label_digits
40
-
41
- subroutine load_mnist (tr_images , tr_labels , te_images ,&
42
- te_labels , va_images , va_labels )
43
- ! ! Loads the MNIST dataset into arrays.
44
- real (rk), allocatable , intent (in out ) :: tr_images(:,:), tr_labels(:)
45
- real (rk), allocatable , intent (in out ) :: te_images(:,:), te_labels(:)
46
- real (rk), allocatable , intent (in out ), optional :: va_images(:,:), va_labels(:)
47
- integer (ik), parameter :: dtype = 4 , image_size = 784
48
- integer (ik), parameter :: tr_nimages = 50000
49
- integer (ik), parameter :: te_nimages = 10000
50
- integer (ik), parameter :: va_nimages = 10000
51
-
52
- call read_binary_file(' data/mnist/mnist_training_images.dat' ,&
53
- dtype, image_size, tr_nimages, tr_images)
54
- call read_binary_file(' data/mnist/mnist_training_labels.dat' ,&
55
- dtype, tr_nimages, tr_labels)
56
-
57
- call read_binary_file(' data/mnist/mnist_testing_images.dat' ,&
58
- dtype, image_size, te_nimages, te_images)
59
- call read_binary_file(' data/mnist/mnist_testing_labels.dat' ,&
60
- dtype, te_nimages, te_labels)
61
-
62
- if (present (va_images) .and. present (va_labels)) then
63
- call read_binary_file(' data/mnist/mnist_validation_images.dat' ,&
64
- dtype, image_size, va_nimages, va_images)
65
- call read_binary_file(' data/mnist/mnist_validation_labels.dat' ,&
66
- dtype, va_nimages, va_labels)
67
- end if
68
-
69
- end subroutine load_mnist
70
-
71
- subroutine print_image (images , labels , n )
72
- ! ! Prints a single image and label to screen.
73
- real (rk), intent (in ) :: images(:,:), labels(:)
74
- integer (ik), intent (in ) :: n
75
- real (rk) :: image(28 , 28 )
76
- character (len= 1 ) :: char_image(28 , 28 )
77
- integer (ik) i, j
78
- image = reshape (images(:,n), [28 , 28 ])
79
- char_image = ' .'
80
- where (image > 0 ) char_image = ' #'
81
- print * , labels(n)
82
- do j = 1 , 28
83
- print * , char_image(:,j)
84
- end do
85
- end subroutine print_image
14
+ interface
15
+
16
+ pure module function label_digits(labels) result(res)
17
+ ! ! Converts an array of MNIST labels into a form
18
+ ! ! that can be input to the network_type instance.
19
+ implicit none
20
+ real (rk), intent (in ) :: labels(:)
21
+ real (rk) :: res(10 , size (labels))
22
+ end function label_digits
23
+
24
+ module subroutine load_mnist (tr_images , tr_labels , te_images ,&
25
+
26
+ te_labels , va_images , va_labels )
27
+ ! ! Loads the MNIST dataset into arrays.
28
+ implicit none
29
+ real (rk), allocatable , intent (in out ) :: tr_images(:,:), tr_labels(:)
30
+ real (rk), allocatable , intent (in out ) :: te_images(:,:), te_labels(:)
31
+ real (rk), allocatable , intent (in out ), optional :: va_images(:,:), va_labels(:)
32
+ end subroutine load_mnist
33
+
34
+ module subroutine print_image (images , labels , n )
35
+ ! ! Prints a single image and label to screen.
36
+ implicit none
37
+ real (rk), intent (in ) :: images(:,:), labels(:)
38
+ integer (ik), intent (in ) :: n
39
+ end subroutine print_image
40
+
41
+ end interface
86
42
87
43
end module mod_mnist
0 commit comments