Skip to content

Commit 49b154f

Browse files
Implement DataLoader to provide two copies of image
Changes to be committed: modified: .gitignore new file: dataloader.py new file: main.py
1 parent c9b16ef commit 49b154f

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
env/
2+
.idea/
3+
data/
4+
__pycache__

dataloader.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import os
2+
3+
import torch
4+
import torchvision.transforms as transforms
5+
from PIL import Image
6+
import numpy as np
7+
8+
9+
class ImageDataset(torch.utils.data.Dataset):
10+
def __init__(self, path, transform=None, target_transform=None):
11+
folders = os.listdir(path)
12+
self.img_names = [os.path.join(path, folder, name) for folder in folders
13+
for name in os.listdir(os.path.join(path, folder))]
14+
self.transform = transform
15+
self.target_transform = target_transform
16+
17+
def __len__(self):
18+
return len(self.img_names)
19+
20+
def __getitem__(self, index):
21+
x = Image.open(self.img_names[index])
22+
xo = x.copy()
23+
24+
if self.transform:
25+
x = self.transform(x)
26+
if self.target_transform:
27+
xo = self.target_transform(xo)
28+
xo = np.array(xo)
29+
return x, xo
30+
31+
32+
def get_dataloader(path, scale_size=256, crop_size=224, batch_size=4, shuffle=True, num_workers=4):
33+
transformer = transforms.Compose([
34+
transforms.Resize(scale_size),
35+
transforms.CenterCrop(crop_size),
36+
transforms.ToTensor()
37+
])
38+
target_transformer = transforms.Compose([
39+
transforms.Resize(scale_size),
40+
transforms.CenterCrop(crop_size),
41+
])
42+
43+
dataset = ImageDataset(path, transformer, target_transformer)
44+
45+
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

main.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import matplotlib.pyplot as plt
2+
3+
import dataloader
4+
5+
6+
def main():
7+
datagen = dataloader.get_dataloader('data')
8+
for x, xo in datagen:
9+
print(xo.shape)
10+
axes = plt.gca()
11+
axes.axis('off')
12+
axes.imshow(xo[0].numpy())
13+
plt.show()
14+
break
15+
16+
17+
if __name__ == '__main__':
18+
main()

0 commit comments

Comments
 (0)