-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathdatasets.py
executable file
·90 lines (68 loc) · 2.55 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Datasets"""
import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import glob
import PIL
import random
import scipy.io as scio
class FFHQ_Mat(Dataset):
    """FFHQ images paired with per-image ``.mat`` files.

    Each item is ``(img, mat)``: ``img`` is the transformed image tensor and
    ``mat`` is the dict returned by ``scipy.io.loadmat`` for the file in
    ``posepath`` whose name is the image's base name with its extension
    replaced by ``.mat``.

    Args:
        dataset_path: glob pattern matching the image files.
        posepath: directory holding one ``<image stem>.mat`` file per image.
        img_size: side length of the square output image.
        **kwargs: ignored, so a shared config dict can be passed through.

    Raises:
        AssertionError: if ``dataset_path`` matches no files.
    """

    def __init__(self, dataset_path, posepath, img_size, **kwargs):
        super().__init__()
        self.data = glob.glob(dataset_path)
        self.posepath = posepath
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Single-element mean/std broadcasts over every channel; for
                # 8-bit images this maps ToTensor's [0, 1] range to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
                # interpolation=0 is torchvision's integer code for nearest-neighbour.
                transforms.Resize((img_size, img_size), interpolation=0),
            ]
        )
        assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset"

    def __len__(self):
        return len(self.data)

    def _mat_path(self, image_name):
        """Return the ``.mat`` path in ``posepath`` paired with ``image_name``.

        Only the final extension is swapped, so a stem that itself contains
        "png" (e.g. ``png_01.png``) keeps its name. The base name is taken
        with ``os.path`` so OS-native path separators are handled too.
        """
        stem, _ = os.path.splitext(os.path.basename(image_name))
        return os.path.join(self.posepath, stem + '.mat')

    def __getitem__(self, index):
        # Import the submodule explicitly: ``import PIL`` alone does not
        # guarantee that ``PIL.Image`` has been loaded.
        from PIL import Image

        image_name = self.data[index]
        mat = scio.loadmat(self._mat_path(image_name))
        # The context manager closes the file handle Image.open keeps open;
        # ToTensor copies the pixels out before the block exits.
        with Image.open(image_name) as pil_img:
            img = self.transform(pil_img)
        return img, mat
class VAE_Mat(Dataset):
    """Shuffled union of two collections of ``.mat`` files.

    Every item is the dict returned by ``scipy.io.loadmat`` for one file.

    Args:
        ryspath: glob pattern for the first ("ryser") set of ``.mat`` files.
        ffhqpath: glob pattern for the second ("ffhq") set of ``.mat`` files.
        **kwargs: ignored, so a shared config dict can be passed through.

    Raises:
        AssertionError: if either pattern matches no files (the ryser
            pattern is checked first).
    """

    def __init__(self, ryspath, ffhqpath, **kwargs):
        super().__init__()
        self.rydata = glob.glob(ryspath)
        self.ffhqdata = glob.glob(ffhqpath)
        assert self.rydata, "Can't find ryser data; make sure you specify the path to your dataset"
        assert self.ffhqdata, "Can't find ffhq data; make sure you specify the path to your dataset"
        combined = [*self.rydata, *self.ffhqdata]
        # Shuffled once, here, with the module-level ``random`` state. A
        # reproducible order needs a seeded ``random`` AND a stable glob order
        # (glob returns files in arbitrary, filesystem-dependent order).
        random.shuffle(combined)
        self.data = combined

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return scio.loadmat(self.data[index])
def get_dataset(name, subsample=None, batch_size=1, **kwargs):
    """Instantiate the dataset class called ``name`` and wrap it in a DataLoader.

    ``name`` is looked up in this module's globals (e.g. ``"FFHQ_Mat"``) and
    the result is called with ``**kwargs``. The loader runs in the main
    process (no workers), keeps the dataset's own order and drops a trailing
    partial batch.

    Args:
        name: name of a dataset class defined in this module.
        subsample: accepted for interface compatibility; currently unused.
        batch_size: samples per batch.
        **kwargs: forwarded to the dataset constructor.

    Returns:
        ``(dataloader, len(dataset))`` -- the length is the full dataset
        size, not the number of batches.

    Raises:
        KeyError: if ``name`` is not a module-level name.
    """
    dataset_cls = globals()[name]
    dataset = dataset_cls(**kwargs)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        pin_memory=False,
        num_workers=0,
    )
    dataloader = torch.utils.data.DataLoader(dataset, **loader_options)
    return dataloader, len(dataset)
def get_dataset_distributed(name, world_size, rank, batch_size, **kwargs):
    """Instantiate dataset ``name`` and shard it across ``world_size`` ranks.

    ``name`` is looked up in this module's globals and called with
    ``**kwargs``. A DistributedSampler built with explicit ``num_replicas``
    and ``rank`` selects this rank's share of the indices; it is left at its
    default settings. The DataLoader uses 4 worker processes.

    NOTE: with the sampler's default shuffling, callers typically need to call
    ``dataloader.sampler.set_epoch(epoch)`` each epoch to vary the order.

    Args:
        name: name of a dataset class defined in this module.
        world_size: total number of ranks sharing the dataset.
        rank: index of the current rank.
        batch_size: samples per batch on this rank.
        **kwargs: forwarded to the dataset constructor.

    Returns:
        ``(dataloader, len(dataset))`` -- the length is the full dataset
        size, not this rank's share.

    Raises:
        KeyError: if ``name`` is not a module-level name.
    """
    dataset = globals()[name](**kwargs)
    shard_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=shard_sampler,
        num_workers=4,
        pin_memory=False,
    )
    return dataloader, len(dataset)