-
Notifications
You must be signed in to change notification settings - Fork 84
/
Copy pathdata_loaders.py
executable file
·127 lines (114 loc) · 6 KB
/
data_loaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
### data_loaders.py
import os
import numpy as np
import pandas as pd
from PIL import Image
from sklearn import preprocessing
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset # For custom datasets
from torchvision import datasets, transforms
################
# Dataset Loader
################
class PathgraphomicDatasetLoader(Dataset):
    """Per-sample loader combining image, graph and omic inputs.

    Indexing returns the 6-tuple ``(x_path, x_grph, x_omic, e, t, g)``.
    Only the modalities selected by ``mode`` are loaded; every modality
    that is not selected is returned as the integer placeholder ``0`` so
    the tuple always has the same length.

    Supported modes: ``path``/``pathpath``, ``graph``/``graphgraph``,
    ``omic``/``omicomic``, ``pathomic``, ``graphomic``, ``pathgraph`` and
    ``pathgraphomic``.
    """

    # Modes that require each modality. Tuples (not sets) so that the
    # membership test is a plain ``==`` comparison, as in the original chain.
    _PATH_MODES = ("path", "pathpath", "pathomic", "pathgraph", "pathgraphomic")
    _GRPH_MODES = ("graph", "graphgraph", "graphomic", "pathgraph", "pathgraphomic")
    _OMIC_MODES = ("omic", "omicomic", "pathomic", "graphomic", "pathgraphomic")

    def __init__(self, opt, data, split, mode='omic'):
        """
        Args:
            opt: options object; only ``opt.input_size_path`` is read here
                and is used as the ``RandomCrop`` size.
            data: mapping such that ``data[split]`` provides the keys
                ``'x_path'`` (entries passed to ``Image.open``),
                ``'x_grph'`` (entries passed to ``torch.load``),
                ``'x_omic'``, ``'e'`` (overall survival event),
                ``'t'`` (overall survival in months) and ``'g'``
                (converted to a LongTensor; presumably a class/grade
                label -- confirm against the caller).
            split: key selecting the subset of ``data`` to serve.
            mode: which modalities ``__getitem__`` loads (see class doc).
        """
        subset = data[split]
        self.X_path = subset['x_path']
        self.X_grph = subset['x_grph']
        self.X_omic = subset['x_omic']
        self.e = subset['e']
        self.t = subset['t']
        self.g = subset['g']
        self.mode = mode
        # Random augmentation followed by tensor conversion and normalisation.
        self.transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5),
            transforms.RandomCrop(opt.input_size_path),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.01),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        """Return ``(x_path, x_grph, x_omic, e, t, g)`` for sample ``index``.

        Raises:
            ValueError: if ``self.mode`` is not one of the supported modes
                (previously this silently returned ``None``).
        """
        mode = self.mode
        wants_path = mode in self._PATH_MODES
        wants_grph = mode in self._GRPH_MODES
        wants_omic = mode in self._OMIC_MODES
        if not (wants_path or wants_grph or wants_omic):
            supported = sorted(set(self._PATH_MODES + self._GRPH_MODES + self._OMIC_MODES))
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {supported}"
            )

        single_e = torch.tensor(self.e[index]).type(torch.FloatTensor)
        single_t = torch.tensor(self.t[index]).type(torch.FloatTensor)
        single_g = torch.tensor(self.g[index]).type(torch.LongTensor)

        # 0 is the placeholder for every modality the mode does not use.
        single_X_path = 0
        single_X_grph = 0
        single_X_omic = 0

        if wants_path:
            # The context manager releases the file handle as soon as the
            # RGB copy has been made.
            with Image.open(self.X_path[index]) as image:
                rgb_image = image.convert('RGB')
            single_X_path = self.transforms(rgb_image)
        if wants_grph:
            # NOTE(review): torch.load unpickles its input; the files listed
            # in x_grph must come from a trusted source.
            single_X_grph = torch.load(self.X_grph[index])
        if wants_omic:
            single_X_omic = torch.tensor(self.X_omic[index]).type(torch.FloatTensor)

        return (single_X_path, single_X_grph, single_X_omic, single_e, single_t, single_g)

    def __len__(self):
        """Number of samples, taken from the length of the ``x_path`` list."""
        return len(self.X_path)
class PathgraphomicFastDatasetLoader(Dataset):
    """Per-sample loader that reads the path input from in-memory values.

    Unlike an image-file loader, the ``x_path`` entries are converted
    straight to a float tensor (with a leading singleton dimension
    squeezed away) instead of being opened and augmented; presumably they
    are precomputed image features -- confirm against the caller.

    Indexing returns the 6-tuple ``(x_path, x_grph, x_omic, e, t, g)``.
    Modalities not selected by ``mode`` are returned as the integer
    placeholder ``0``.

    Supported modes: ``path``/``pathpath``, ``graph``/``graphgraph``,
    ``omic``/``omicomic``, ``pathomic``, ``graphomic``, ``pathgraph`` and
    ``pathgraphomic``.
    """

    # Modes that require each modality. Tuples (not sets) so that the
    # membership test is a plain ``==`` comparison, as in the original chain.
    _PATH_MODES = ("path", "pathpath", "pathomic", "pathgraph", "pathgraphomic")
    _GRPH_MODES = ("graph", "graphgraph", "graphomic", "pathgraph", "pathgraphomic")
    _OMIC_MODES = ("omic", "omicomic", "pathomic", "graphomic", "pathgraphomic")

    def __init__(self, opt, data, split, mode='omic'):
        """
        Args:
            opt: options object; accepted for interface parity but not
                read by this class.
            data: mapping such that ``data[split]`` provides the keys
                ``'x_path'``, ``'x_grph'`` (entries passed to
                ``torch.load``), ``'x_omic'``, ``'e'`` (overall survival
                event), ``'t'`` (overall survival in months) and ``'g'``
                (converted to a LongTensor; presumably a class/grade
                label -- confirm against the caller).
            split: key selecting the subset of ``data`` to serve.
            mode: which modalities ``__getitem__`` loads (see class doc).
        """
        subset = data[split]
        self.X_path = subset['x_path']
        self.X_grph = subset['x_grph']
        self.X_omic = subset['x_omic']
        self.e = subset['e']
        self.t = subset['t']
        self.g = subset['g']
        self.mode = mode

    def __getitem__(self, index):
        """Return ``(x_path, x_grph, x_omic, e, t, g)`` for sample ``index``.

        Raises:
            ValueError: if ``self.mode`` is not one of the supported modes
                (previously this silently returned ``None``).
        """
        mode = self.mode
        wants_path = mode in self._PATH_MODES
        wants_grph = mode in self._GRPH_MODES
        wants_omic = mode in self._OMIC_MODES
        if not (wants_path or wants_grph or wants_omic):
            supported = sorted(set(self._PATH_MODES + self._GRPH_MODES + self._OMIC_MODES))
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {supported}"
            )

        single_e = torch.tensor(self.e[index]).type(torch.FloatTensor)
        single_t = torch.tensor(self.t[index]).type(torch.FloatTensor)
        single_g = torch.tensor(self.g[index]).type(torch.LongTensor)

        # 0 is the placeholder for every modality the mode does not use.
        single_X_path = 0
        single_X_grph = 0
        single_X_omic = 0

        if wants_path:
            # squeeze(0) drops a leading dimension of size 1, if present.
            single_X_path = torch.tensor(self.X_path[index]).type(torch.FloatTensor).squeeze(0)
        if wants_grph:
            # NOTE(review): torch.load unpickles its input; the files listed
            # in x_grph must come from a trusted source.
            single_X_grph = torch.load(self.X_grph[index])
        if wants_omic:
            single_X_omic = torch.tensor(self.X_omic[index]).type(torch.FloatTensor)

        return (single_X_path, single_X_grph, single_X_omic, single_e, single_t, single_g)

    def __len__(self):
        """Number of samples, taken from the length of the ``x_path`` list."""
        return len(self.X_path)