-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset_DM.py
109 lines (91 loc) · 3.11 KB
/
dataset_DM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import cv2
import numpy as np
import os
import pandas as pd
from torch.utils.data import Dataset
class TrainDataset(Dataset):
    """Training pairs of a source ('CF') image and a stack of OCT slices.

    Samples are listed in ``<path>/train.csv``, which must provide the columns
    ``PID``, ``CF_path`` and ``OCT_path``. Source images are read from
    ``<path>/CF/<CF_path>`` and the slices from
    ``<path>/OCT/<OCT_path>/<OCT_path>_<i>.jpg`` for ``i`` in ``0..5``.

    Args:
        path: Root directory of the training split.
        width: Side length (pixels) that both source and target are resized to.
    """

    # Number of grayscale slices stacked along the target's last axis.
    NUM_OCT_SLICES = 6
    # Samples with idx <= this value get their slices cropped to the first
    # 496 rows and the last 768 columns before resizing.
    # NOTE(review): presumably those files carry extra margins from a
    # different export layout - confirm against the data description.
    OCT_CROP_LAST_IDX = 728

    def __init__(self, path, width=512):
        self.path = path
        self.data = pd.read_csv(os.path.join(self.path, 'train.csv'))
        self.width = width

    def __len__(self):
        """Return the number of rows in ``train.csv``."""
        return len(self.data)

    @staticmethod
    def _center_crop_width(image):
        """Crop the columns of an H x W x C array to a centred band of about H.

        The bounds are clamped to the image, so an image that is narrower than
        it is tall keeps its full width instead of wrapping around through a
        negative slice start.
        """
        h, w = image.shape[:2]
        left = max(0, w // 2 - h // 2)
        right = min(w, w // 2 + h // 2)
        return image[:, left:right, :]

    def _load_source(self, image_path):
        """Read, centre-crop, RGB-convert, resize and scale a source image."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read source image: {image_path}")
        image = self._center_crop_width(image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.width, self.width))
        # Normalize source images to [0, 1].
        return image.astype(np.float32) / 255.0

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``jpg`` (float32 target, width x width x 6, scaled
            to [-1, 1]), ``txt`` (empty prompt string), ``hint`` (float32 RGB
            source, width x width x 3, scaled to [0, 1]), ``id`` (the given
            index), ``PID``, ``CF_path`` and ``OCT_path`` (values from the CSV).

        Raises:
            FileNotFoundError: If the source image or any slice cannot be read.
        """
        item = self.data.iloc[idx]
        PID = item['PID']
        CF_path = item['CF_path']
        OCT_path = item['OCT_path']

        source = self._load_source(os.path.join(self.path, 'CF', CF_path))

        target = np.zeros(
            (self.width, self.width, self.NUM_OCT_SLICES), dtype=np.float32
        )
        for i in range(self.NUM_OCT_SLICES):
            slice_path = os.path.join(
                self.path, 'OCT', OCT_path, f"{OCT_path}_{i}.jpg"
            )
            target_i = cv2.imread(slice_path, cv2.IMREAD_GRAYSCALE)
            if target_i is None:
                raise FileNotFoundError(f"Could not read OCT slice: {slice_path}")
            if idx <= self.OCT_CROP_LAST_IDX:
                target_i = target_i[:496, -768:]
            target_i = cv2.resize(target_i, (self.width, self.width))
            target[:, :, i] = target_i

        # Normalize target images to [-1, 1].
        target = target / 127.5 - 1.0

        return dict(
            jpg=target,
            txt="",
            hint=source,
            id=idx,
            PID=PID,
            CF_path=CF_path,
            OCT_path=OCT_path,
        )
class ValidDataset(Dataset):
    """Validation samples: a source ('CF') image with a placeholder target.

    Samples are listed in ``<path>/val.csv``, which must provide the columns
    ``PID`` and ``CF_path``. Source images are read from
    ``<path>/CF/<CF_path>``. No target images are loaded; ``jpg`` is an
    all-zero array of the training target's shape.

    Args:
        path: Root directory of the validation split.
        width: Side length (pixels) that the source image is resized to.
    """

    # Size of the placeholder target's last axis.
    NUM_OCT_SLICES = 6

    def __init__(self, path, width=512):
        self.path = path
        self.data = pd.read_csv(os.path.join(self.path, 'val.csv'))
        self.width = width

    def __len__(self):
        """Return the number of rows in ``val.csv``."""
        return len(self.data)

    @staticmethod
    def _center_crop_width(image):
        """Crop the columns of an H x W x C array to a centred band of about H.

        The bounds are clamped to the image, so an image that is narrower than
        it is tall keeps its full width instead of wrapping around through a
        negative slice start.
        """
        h, w = image.shape[:2]
        left = max(0, w // 2 - h // 2)
        right = min(w, w // 2 + h // 2)
        return image[:, left:right, :]

    def _load_source(self, image_path):
        """Read, centre-crop, RGB-convert, resize and scale a source image."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read source image: {image_path}")
        image = self._center_crop_width(image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.width, self.width))
        # Normalize source images to [0, 1].
        return image.astype(np.float32) / 255.0

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``jpg`` (float32 zeros, width x width x 6), ``txt``
            (empty prompt string), ``hint`` (float32 RGB source, width x width
            x 3, scaled to [0, 1]), ``id`` (the given index), ``PID`` and
            ``CF_path`` (values from the CSV).

        Raises:
            FileNotFoundError: If the source image cannot be read.
        """
        item = self.data.iloc[idx]
        PID = item['PID']
        CF_path = item['CF_path']

        source = self._load_source(os.path.join(self.path, 'CF', CF_path))

        # Placeholder target: nothing is loaded for validation. float32 keeps
        # the dtype in line with the training targets.
        target = np.zeros(
            (self.width, self.width, self.NUM_OCT_SLICES), dtype=np.float32
        )

        return dict(
            jpg=target,
            txt="",
            hint=source,
            id=idx,
            PID=PID,
            CF_path=CF_path,
        )
if __name__ == '__main__':
    # Manual smoke test: load the first training sample and print its fields.
    data_path = "/home/pod/shared-nvme/data/EyeOCT/train"
    train_dataset = TrainDataset(data_path)
    print(len(train_dataset))  # 750 according to the original note

    sample = train_dataset[0]
    print(sample['txt'])
    print(sample['jpg'].shape)
    # Remaining metadata fields, in the order they were originally printed.
    for field in ('id', 'PID', 'CF_path', 'OCT_path'):
        print(sample[field])