-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataloader.py
77 lines (73 loc) · 2.8 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import sys
import pickle
import tarfile
import numpy as np
from urllib.request import urlretrieve
from sklearn.preprocessing import OneHotEncoder
def read_batch(src):
    '''Unpickle a single batch file and return the object stored in it.

    Args:
        src: Path of the pickled batch file to open.

    Returns:
        Whatever object the pickle contains (the callers in this file index
        it with the 'data' and 'labels' keys).
    '''
    # On Python 3 the pickle is decoded with latin1; Python 2's pickle.load
    # takes no encoding argument, so none is passed there.
    if sys.version_info.major == 2:
        load_options = {}
    else:
        load_options = {'encoding': 'latin1'}
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at batch files obtained from a trusted source.
    with open(src, 'rb') as handle:
        return pickle.load(handle, **load_options)
def process_cifar(data_dir='./cifar-10-batches-py'):
    '''Read the extracted CIFAR-10 batch files into RAM.

    Args:
        data_dir: Directory containing the files ``data_batch_1`` ..
            ``data_batch_5`` and ``test_batch``. Defaults to the
            ``cifar-10-batches-py`` folder in the current working directory,
            which is the location this function has always read from.

    Returns:
        Tuple ``(x_train, x_test, y_train, y_test)``. The x values come from
        the 'data' entry of each batch and the y values from the 'labels'
        entry; the five training batches are concatenated in order and the
        test labels are converted to a numpy array.

    Raises:
        Whatever ``read_batch`` raises when a batch file is missing or
        cannot be unpickled (e.g. IOError/OSError for a missing file).
    '''
    print('Preparing train set...')
    train_batches = [
        read_batch(os.path.join(data_dir, 'data_batch_{0}'.format(idx)))
        for idx in range(1, 6)
    ]
    x_train = np.concatenate([batch['data'] for batch in train_batches])
    y_train = np.concatenate([batch['labels'] for batch in train_batches])

    print('Preparing test set...')
    test_batch = read_batch(os.path.join(data_dir, 'test_batch'))
    x_test = test_batch['data']
    y_test = np.asarray(test_batch['labels'])

    return x_train, x_test, y_train, y_test
def maybe_download_cifar(src="http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"):
    '''Load the training and testing data, downloading the archive if needed.

    First tries to read the already-extracted batch files via
    ``process_cifar``. If they are missing or unreadable, the archive at
    ``src`` is downloaded into the current directory, extracted there,
    deleted, and the batch files are read again.

    Args:
        src: URL of the CIFAR-10 python tarball.

    Returns:
        The ``(x_train, x_test, y_train, y_test)`` tuple produced by
        ``process_cifar``.

    Raises:
        Errors from the download, the extraction, or the second
        ``process_cifar`` call are propagated to the caller.
    '''
    try:
        return process_cifar()
    except (IOError, OSError, EOFError, pickle.UnpicklingError):
        # Only missing / truncated / corrupt batch files trigger a download.
        # A bare ``except`` here would also swallow KeyboardInterrupt and
        # SystemExit and start a large download the user tried to cancel.
        print('Data does not exist. Downloading ' + src)

    filename = src.split('/')[-1]
    filepath = os.path.join("./", filename)

    def _recall_func(num, block_size, total_size):
        # urlretrieve reports total_size <= 0 when the server does not send a
        # usable size; avoid dividing by it and show the byte count instead.
        if total_size > 0:
            percent = float(num * block_size) / float(total_size) * 100.0
            # The last block is usually partial, so clamp the display at 100%.
            percent = min(percent, 100.0)
            sys.stdout.write('\r>> downloading %s %.1f%%' % (filename, percent))
        else:
            sys.stdout.write('\r>> downloading %s %d bytes' % (filename, num * block_size))
        sys.stdout.flush()

    # NOTE(security): the archive is fetched without any integrity check and
    # its contents are later unpickled; only use a source you trust.
    fname, _headers = urlretrieve(src, filepath, _recall_func)
    file_info = os.stat(filepath)
    print('Successfully download.', filename, file_info.st_size, 'bytes')

    print('Extracting files...')
    with tarfile.open(fname) as tar:
        if hasattr(tarfile, 'data_filter'):
            # Extraction filters are available: refuse absolute paths, path
            # traversal and special files inside the downloaded archive.
            tar.extractall(filter='data')
        else:
            tar.extractall()
    os.remove(fname)

    return process_cifar()
def load_cifar(channel_first=True, one_hot=False):
    '''Load CIFAR-10 as scaled float32 images and int32 labels.

    Args:
        channel_first: If True (default) images are returned with shape
            (N, 3, 32, 32). If False they are returned channel-last with
            shape (N, 32, 32, 3).
        one_hot: If True the labels are one-hot encoded, with the encoder
            fitted on the training labels; otherwise the label arrays are
            returned as loaded.

    Returns:
        Tuple ``(x_train, x_test, y_train, y_test)``; the x arrays are
        float32 with pixel values divided by 255, the y arrays are int32.
    '''
    # Raw data
    x_train, x_test, y_train, y_test = maybe_download_cifar()

    # Scale pixel intensity
    x_train = x_train / 255.0
    x_test = x_test / 255.0

    # Reshape the flat rows into (N, channels, 32, 32)
    x_train = x_train.reshape(-1, 3, 32, 32)
    x_test = x_test.reshape(-1, 3, 32, 32)

    # Channel last
    if not channel_first:
        # Move the channel axis to the end while keeping the two spatial
        # axes in their original order. The CIFAR-10 description states each
        # channel is stored row-major, so swapping axes 1 and 3 would also
        # exchange rows and columns, i.e. transpose every image.
        x_train = np.transpose(x_train, (0, 2, 3, 1))
        x_test = np.transpose(x_test, (0, 2, 3, 1))

    # One-hot encode y
    if one_hot:
        # The encoder expects a 2-D (n_samples, n_features) input.
        y_train = np.expand_dims(y_train, axis=-1)
        y_test = np.expand_dims(y_test, axis=-1)
        # The `categorical_features` keyword no longer exists in current
        # scikit-learn; the default constructor encodes every column, which
        # is what 'all' requested.
        enc = OneHotEncoder()
        enc.fit(y_train)
        y_train = enc.transform(y_train).toarray()
        y_test = enc.transform(y_test).toarray()

    # dtypes
    x_train = x_train.astype(np.float32)
    x_test = x_test.astype(np.float32)
    y_train = y_train.astype(np.int32)
    y_test = y_test.astype(np.int32)

    return x_train, x_test, y_train, y_test