Skip to content

Commit 1b6b189

Browse files
committed
Refactor to store mnist data locally.
1 parent 09a8945 commit 1b6b189

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

content/tutorial-deep-learning-on-mnist.md

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,29 @@ filename = [["training_images", "train-images-idx3-ubyte.gz"], # 60,000 traini
6767
["test_labels", "t10k-labels-idx1-ubyte.gz"]] # 10,000 test labels.
6868
```
6969

70-
**2.** Download each of the 4 files in the list:
70+
**2.** Load the data. First check if the data is stored locally; if not, then
71+
download it.
7172

7273
```{code-cell} ipython3
7374
import requests
75+
import os
76+
77+
data_dir = "../_data"
78+
os.makedirs(data_dir, exist_ok=True)
7479
7580
base_url = "http://yann.lecun.com/exdb/mnist/"
7681
headers = {
7782
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0"
7883
}
7984
8085
for name in filename:
81-
print("Downloading file: " + name[1])
82-
resp = requests.get(base_url + name[1], headers=headers, stream=True)
83-
with open(name[1], "wb") as fh:
84-
for chunk in resp.iter_content(chunk_size=128):
85-
fh.write(chunk)
86+
fpath = os.path.join(data_dir, name[1])
87+
if not os.path.exists(fpath):
88+
print("Downloading file: " + name[1])
89+
resp = requests.get(base_url + name[1], headers=headers, stream=True)
90+
with open(fpath, "wb") as fh:
91+
for chunk in resp.iter_content(chunk_size=128):
92+
fh.write(chunk)
8693
```
8794

8895
**3.** Decompress the 4 files and create 4 [`ndarrays`](https://numpy.org/doc/stable/reference/arrays.ndarray.html), saving them into a dictionary. Each original image is of size 28x28 and neural networks normally expect a 1D vector input; therefore, you also need to reshape the images by multiplying 28 by 28 (784).
@@ -95,11 +102,11 @@ mnist_dataset = {}
95102
96103
# Images
97104
for name in filename[:2]:
98-
with gzip.open(name[1], 'rb') as mnist_file:
105+
with gzip.open(os.path.join(data_dir, name[1]), 'rb') as mnist_file:
99106
mnist_dataset[name[0]] = np.frombuffer(mnist_file.read(), np.uint8, offset=16).reshape(-1, 28*28)
100107
# Labels
101108
for name in filename[-2:]:
102-
with gzip.open(name[1], 'rb') as mnist_file:
109+
with gzip.open(os.path.join(data_dir, name[1]), 'rb') as mnist_file:
103110
mnist_dataset[name[0]] = np.frombuffer(mnist_file.read(), np.uint8, offset=8)
104111
```
105112

0 commit comments

Comments
 (0)