Skip to content

Commit 153a54c

Browse files
committed
refac(mod_mnist): mv procedure defs to submodule
This commit also makes a procedure that is referenced in only one place an internal procedure.
1 parent b1c8f38 commit 153a54c

File tree

2 files changed

+108
-72
lines changed

2 files changed

+108
-72
lines changed

src/mod_mnist.f90

Lines changed: 28 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ module mod_mnist
33
!! Procedures to work with MNIST dataset, usable with data format
44
!! as provided in this repo and not the original data format (idx).
55

6-
use iso_fortran_env, only: real32 !! TODO make MNIST work with arbitrary precision
7-
use mod_io, only: read_binary_file
86
use mod_kinds, only: ik, rk
97

108
implicit none
@@ -13,75 +11,33 @@ module mod_mnist
1311

1412
public :: label_digits, load_mnist, print_image
1513

16-
contains
17-
18-
pure function digits(x)
19-
!! Returns an array of 10 reals, with zeros everywhere
20-
!! and a one corresponding to the input number, for example:
21-
!! digits(0) = [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
22-
!! digits(1) = [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]
23-
!! digits(6) = [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]
24-
real(rk), intent(in) :: x
25-
real(rk) :: digits(10)
26-
digits = 0
27-
digits(int(x + 1)) = 1
28-
end function digits
29-
30-
pure function label_digits(labels) result(res)
31-
!! Converts an array of MNIST labels into a form
32-
!! that can be input to the network_type instance.
33-
real(rk), intent(in) :: labels(:)
34-
real(rk) :: res(10, size(labels))
35-
integer(ik) :: i
36-
do i = 1, size(labels)
37-
res(:,i) = digits(labels(i))
38-
end do
39-
end function label_digits
40-
41-
subroutine load_mnist(tr_images, tr_labels, te_images,&
42-
te_labels, va_images, va_labels)
43-
!! Loads the MNIST dataset into arrays.
44-
real(rk), allocatable, intent(in out) :: tr_images(:,:), tr_labels(:)
45-
real(rk), allocatable, intent(in out) :: te_images(:,:), te_labels(:)
46-
real(rk), allocatable, intent(in out), optional :: va_images(:,:), va_labels(:)
47-
integer(ik), parameter :: dtype = 4, image_size = 784
48-
integer(ik), parameter :: tr_nimages = 50000
49-
integer(ik), parameter :: te_nimages = 10000
50-
integer(ik), parameter :: va_nimages = 10000
51-
52-
call read_binary_file('data/mnist/mnist_training_images.dat',&
53-
dtype, image_size, tr_nimages, tr_images)
54-
call read_binary_file('data/mnist/mnist_training_labels.dat',&
55-
dtype, tr_nimages, tr_labels)
56-
57-
call read_binary_file('data/mnist/mnist_testing_images.dat',&
58-
dtype, image_size, te_nimages, te_images)
59-
call read_binary_file('data/mnist/mnist_testing_labels.dat',&
60-
dtype, te_nimages, te_labels)
61-
62-
if (present(va_images) .and. present(va_labels)) then
63-
call read_binary_file('data/mnist/mnist_validation_images.dat',&
64-
dtype, image_size, va_nimages, va_images)
65-
call read_binary_file('data/mnist/mnist_validation_labels.dat',&
66-
dtype, va_nimages, va_labels)
67-
end if
68-
69-
end subroutine load_mnist
70-
71-
subroutine print_image(images, labels, n)
72-
!! Prints a single image and label to screen.
73-
real(rk), intent(in) :: images(:,:), labels(:)
74-
integer(ik), intent(in) :: n
75-
real(rk) :: image(28, 28)
76-
character(len=1) :: char_image(28, 28)
77-
integer(ik) i, j
78-
image = reshape(images(:,n), [28, 28])
79-
char_image = '.'
80-
where (image > 0) char_image = '#'
81-
print *, labels(n)
82-
do j = 1, 28
83-
print *, char_image(:,j)
84-
end do
85-
end subroutine print_image
14+
interface
15+
16+
pure module function label_digits(labels) result(res)
17+
!! Converts an array of MNIST labels into a form
18+
!! that can be input to the network_type instance.
19+
implicit none
20+
real(rk), intent(in) :: labels(:)
21+
real(rk) :: res(10, size(labels))
22+
end function label_digits
23+
24+
module subroutine load_mnist(tr_images, tr_labels, te_images,&
25+
26+
te_labels, va_images, va_labels)
27+
!! Loads the MNIST dataset into arrays.
28+
implicit none
29+
real(rk), allocatable, intent(in out) :: tr_images(:,:), tr_labels(:)
30+
real(rk), allocatable, intent(in out) :: te_images(:,:), te_labels(:)
31+
real(rk), allocatable, intent(in out), optional :: va_images(:,:), va_labels(:)
32+
end subroutine load_mnist
33+
34+
module subroutine print_image(images, labels, n)
35+
!! Prints a single image and label to screen.
36+
implicit none
37+
real(rk), intent(in) :: images(:,:), labels(:)
38+
integer(ik), intent(in) :: n
39+
end subroutine print_image
40+
41+
end interface
8642

8743
end module mod_mnist

src/mod_mnist_submodule.f90

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
submodule(mod_mnist) mod_mnist_submodule
2+
3+
!! Procedures to work with MNIST dataset, usable with data format
4+
!! as provided in this repo and not the original data format (idx).
5+
6+
! TODO make MNIST work with arbitrary precision
7+
8+
use mod_io, only: read_binary_file
9+
use mod_kinds, only: ik, rk
10+
11+
implicit none
12+
13+
contains
14+
15+
pure module function label_digits(labels) result(res)
16+
real(rk), intent(in) :: labels(:)
17+
real(rk) :: res(10, size(labels))
18+
integer(ik) :: i
19+
do i = 1, size(labels)
20+
res(:,i) = digits(labels(i))
21+
end do
22+
contains
23+
pure function digits(x)
24+
!! Returns an array of 10 reals, with zeros everywhere
25+
!! and a one corresponding to the input number, for example:
26+
!! digits(0) = [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
27+
!! digits(1) = [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]
28+
!! digits(6) = [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]
29+
real(rk), intent(in) :: x
30+
real(rk) :: digits(10)
31+
digits = 0
32+
digits(int(x + 1)) = 1
33+
end function digits
34+
end function label_digits
35+
36+
module subroutine load_mnist(tr_images, tr_labels, te_images,&
37+
te_labels, va_images, va_labels)
38+
real(rk), allocatable, intent(in out) :: tr_images(:,:), tr_labels(:)
39+
real(rk), allocatable, intent(in out) :: te_images(:,:), te_labels(:)
40+
real(rk), allocatable, intent(in out), optional :: va_images(:,:), va_labels(:)
41+
integer(ik), parameter :: dtype = 4, image_size = 784
42+
integer(ik), parameter :: tr_nimages = 50000
43+
integer(ik), parameter :: te_nimages = 10000
44+
integer(ik), parameter :: va_nimages = 10000
45+
46+
call read_binary_file('data/mnist/mnist_training_images.dat',&
47+
dtype, image_size, tr_nimages, tr_images)
48+
call read_binary_file('data/mnist/mnist_training_labels.dat',&
49+
dtype, tr_nimages, tr_labels)
50+
51+
call read_binary_file('data/mnist/mnist_testing_images.dat',&
52+
dtype, image_size, te_nimages, te_images)
53+
call read_binary_file('data/mnist/mnist_testing_labels.dat',&
54+
dtype, te_nimages, te_labels)
55+
56+
if (present(va_images) .and. present(va_labels)) then
57+
call read_binary_file('data/mnist/mnist_validation_images.dat',&
58+
dtype, image_size, va_nimages, va_images)
59+
call read_binary_file('data/mnist/mnist_validation_labels.dat',&
60+
dtype, va_nimages, va_labels)
61+
end if
62+
63+
end subroutine load_mnist
64+
65+
module subroutine print_image(images, labels, n)
66+
real(rk), intent(in) :: images(:,:), labels(:)
67+
integer(ik), intent(in) :: n
68+
real(rk) :: image(28, 28)
69+
character(len=1) :: char_image(28, 28)
70+
integer(ik) i, j
71+
image = reshape(images(:,n), [28, 28])
72+
char_image = '.'
73+
where (image > 0) char_image = '#'
74+
print *, labels(n)
75+
do j = 1, 28
76+
print *, char_image(:,j)
77+
end do
78+
end subroutine print_image
79+
80+
end submodule mod_mnist_submodule

0 commit comments

Comments
 (0)