-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathdata_loader.py
105 lines (86 loc) · 2.9 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
def get_data_loaders(data_dir,
                     batch_size,
                     train_transform,
                     test_transform,
                     shuffle=True,
                     num_workers=4,
                     pin_memory=False,
                     shuffle_test=False):
    """
    Build train and test DataLoaders over the CIFAR-10 dataset.

    Adapted from: https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb

    Both splits are fetched into ``data_dir`` if not already present
    (``download=True``). If using CUDA, set ``pin_memory`` to True.

    Params
    ------
    - data_dir: path to the directory holding (or receiving) the dataset.
    - batch_size: how many samples per batch to load.
    - train_transform: pytorch transforms for the training set.
    - test_transform: pytorch transforms for the test set.
    - shuffle: whether the training loader reshuffles its data every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    - shuffle_test: whether the test loader is shuffled as well. Defaults to
      False so evaluation always visits samples in the same order; pass True
      to get the previous behaviour of shuffling both splits.

    Returns
    -------
    - train_loader: training set iterator.
    - test_loader: test set iterator.
    """
    # Load the datasets
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )
    test_dataset = datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=test_transform,
    )

    # Create loader objects. Only the training split is shuffled by
    # default: shuffling the test split changes nothing about the metrics
    # but makes per-sample outputs (e.g. plotted predictions) irreproducible.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=shuffle_test,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    return (train_loader, test_loader)
def plot_images(images, cls_true, cls_pred=None):
    """
    Plot up to nine images in a 3x3 grid, captioned with their classes.

    Adapted from https://github.com/Hvass-Labs/TensorFlow-Tutorials/

    Params
    ------
    - images: a 4-D batch indexed as ``images[i, :, :, :]``; each slice is
      passed straight to ``imshow``, so presumably (N, H, W, C) with C last
      — TODO confirm layout with callers.
    - cls_true: true class indices into the CIFAR-10 label names.
    - cls_pred: optional predicted class indices. When given, each caption
      shows true and predicted class names; otherwise it shows the true
      class name and its index.

    Only the first ``min(len(images), 9)`` images are drawn; spare grid
    cells are hidden instead of raising ``IndexError`` on short batches.
    """
    # CIFAR10 labels, indexed by class id
    label_names = [
        'airplane',
        'automobile',
        'bird',
        'cat',
        'deer',
        'dog',
        'frog',
        'horse',
        'ship',
        'truck'
    ]

    _, axes = plt.subplots(3, 3)
    # A batch smaller than the grid (e.g. the last batch of a loader)
    # would otherwise index past the end of `images`.
    n_shown = min(len(images), axes.size)

    for i, ax in enumerate(axes.flat):
        if i >= n_shown:
            # No image for this cell: hide the empty axes entirely.
            ax.axis('off')
            continue

        # plot img
        ax.imshow(images[i, :, :, :], interpolation='spline16')

        # show true & predicted classes
        cls_true_name = label_names[cls_true[i]]
        if cls_pred is None:
            xlabel = "{0} ({1})".format(cls_true_name, cls_true[i])
        else:
            cls_pred_name = label_names[cls_pred[i]]
            xlabel = "True: {0}\nPred: {1}".format(
                cls_true_name, cls_pred_name
            )
        ax.set_xlabel(xlabel)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()