|
10 | 10 |
|
11 | 11 | implicit none
|
12 | 12 |
|
| 13 | + integer, parameter :: message_len = 128 |
| 14 | + |
13 | 15 | contains
|
14 | 16 |
|
| 17 | + subroutine download_and_uncompress() |
| 18 | + character(len=*), parameter :: download_mechanism = 'curl -LO ' |
| 19 | + character(len=*), parameter :: base_url='https://github.com/modern-fortran/neural-fortran/files/8498876/' |
| 20 | + character(len=*), parameter :: download_filename = 'mnist.tar.gz' |
| 21 | + character(len=*), parameter :: download_command = download_mechanism // base_url // download_filename |
| 22 | + character(len=*), parameter :: uncompress_file = 'tar xvzf ' // download_filename |
| 23 | + character(len=message_len) :: command_message |
| 24 | + character(len=:), allocatable :: error_message |
| 25 | + integer :: exit_status, command_status |
| 26 | + |
| 27 | + exit_status=0 |
| 28 | + call execute_command_line(command=download_command, wait=.true., & |
| 29 | + exitstat=exit_status, cmdstat=command_status, cmdmsg=command_message) |
| 30 | + |
| 31 | + if (any([exit_status, command_status] /= 0)) then |
| 32 | + error_message = 'command "' // download_command // '" failed' |
| 33 | + if (command_status /= 0) error_message = error_message // " with message " // trim(command_message) |
| 34 | + error stop error_message |
| 35 | + end if |
| 36 | + |
| 37 | + call execute_command_line(command=uncompress_file, wait=.true., & |
| 38 | + exitstat=exit_status, cmdstat=command_status, cmdmsg=command_message) |
| 39 | + |
| 40 | + if (any([exit_status, command_status] /= 0)) then |
| 41 | + error_message = 'command "' // uncompress_file // '" failed' |
| 42 | + if (command_status /= 0) error_message = error_message // " with message " // trim(command_message) |
| 43 | + error stop error_message |
| 44 | + end if |
| 45 | + |
| 46 | + end subroutine download_and_uncompress |
| 47 | + |
15 | 48 | pure module function label_digits(labels) result(res)
|
16 | 49 | real(rk), intent(in) :: labels(:)
|
17 | 50 | real(rk) :: res(10, size(labels))
|
@@ -42,21 +75,26 @@ module subroutine load_mnist(tr_images, tr_labels, te_images,&
|
42 | 75 | integer(ik), parameter :: tr_nimages = 50000
|
43 | 76 | integer(ik), parameter :: te_nimages = 10000
|
44 | 77 | integer(ik), parameter :: va_nimages = 10000
|
| 78 | + logical :: file_exists |
| 79 | + |
| 80 | + ! Check if MNIST data is present and download it if not. |
| 81 | + inquire(file='mnist_training_images.dat', exist=file_exists) |
| 82 | + if (.not. file_exists) call download_and_uncompress() |
45 | 83 |
|
46 |
| - call read_binary_file('data/mnist/mnist_training_images.dat',& |
| 84 | + call read_binary_file('mnist_training_images.dat',& |
47 | 85 | dtype, image_size, tr_nimages, tr_images)
|
48 |
| - call read_binary_file('data/mnist/mnist_training_labels.dat',& |
| 86 | + call read_binary_file('mnist_training_labels.dat',& |
49 | 87 | dtype, tr_nimages, tr_labels)
|
50 | 88 |
|
51 |
| - call read_binary_file('data/mnist/mnist_testing_images.dat',& |
| 89 | + call read_binary_file('mnist_testing_images.dat',& |
52 | 90 | dtype, image_size, te_nimages, te_images)
|
53 |
| - call read_binary_file('data/mnist/mnist_testing_labels.dat',& |
| 91 | + call read_binary_file('mnist_testing_labels.dat',& |
54 | 92 | dtype, te_nimages, te_labels)
|
55 | 93 |
|
56 | 94 | if (present(va_images) .and. present(va_labels)) then
|
57 |
| - call read_binary_file('data/mnist/mnist_validation_images.dat',& |
| 95 | + call read_binary_file('mnist_validation_images.dat',& |
58 | 96 | dtype, image_size, va_nimages, va_images)
|
59 |
| - call read_binary_file('data/mnist/mnist_validation_labels.dat',& |
| 97 | + call read_binary_file('mnist_validation_labels.dat',& |
60 | 98 | dtype, va_nimages, va_labels)
|
61 | 99 | end if
|
62 | 100 |
|
|
0 commit comments