import os

import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

class ImageDataset(torch.utils.data.Dataset):
    """Dataset yielding (transformed image, original image as ndarray) pairs.

    Images are discovered as ``path/<folder>/<file>`` — one level of
    subdirectories under ``path``. No class labels are produced: the second
    element of each item is the (optionally target-transformed) original
    image converted to a NumPy array, e.g. for reconstruction-style training.

    Args:
        path: Root directory containing one subdirectory per image group.
        transform: Optional callable applied to the input image.
        target_transform: Optional callable applied to the target copy
            before it is converted to a NumPy array.
    """

    def __init__(self, path, transform=None, target_transform=None):
        # Only descend into real subdirectories so a stray regular file
        # directly under `path` (e.g. .DS_Store) does not raise
        # NotADirectoryError from the inner os.listdir.
        folders = [entry for entry in os.listdir(path)
                   if os.path.isdir(os.path.join(path, entry))]
        self.img_names = [os.path.join(path, folder, name)
                          for folder in folders
                          for name in os.listdir(os.path.join(path, folder))]
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, index):
        # Open under a context manager so PIL's lazily-held file handle is
        # released promptly (avoids fd exhaustion over long epochs), and
        # force RGB so grayscale/RGBA files don't break batch collation.
        with Image.open(self.img_names[index]) as img:
            x = img.convert("RGB")
        xo = x.copy()

        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            xo = self.target_transform(xo)
        xo = np.array(xo)
        return x, xo
| 30 | + |
| 31 | + |
def get_dataloader(path, scale_size=256, crop_size=224, batch_size=4, shuffle=True, num_workers=4):
    """Construct a DataLoader over an ``ImageDataset`` rooted at *path*.

    Both the input and the target pipelines resize to ``scale_size`` and
    center-crop to ``crop_size``; only the input pipeline additionally
    converts to a tensor, so the target stays a PIL image for the dataset
    to turn into a NumPy array.
    """
    shared_steps = [
        transforms.Resize(scale_size),
        transforms.CenterCrop(crop_size),
    ]
    input_pipeline = transforms.Compose(shared_steps + [transforms.ToTensor()])
    target_pipeline = transforms.Compose(list(shared_steps))

    return torch.utils.data.DataLoader(
        ImageDataset(path, input_pipeline, target_pipeline),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )