forked from Usman-Rafique/Probabilistic_UNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_LIDC_data.py
executable file
·69 lines (56 loc) · 2.2 KB
/
load_LIDC_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import numpy as np
import os
import random
import pickle
#import dicom
class LIDC_IDRI(Dataset):
    """PyTorch dataset over pickled LIDC-IDRI samples.

    Every file in ``dataset_location`` whose name contains ``'.pickle'`` is
    unpickled and expected to hold a dict mapping some key to a record with
    the entries ``'image'``, ``'masks'`` and ``'series_uid'``. Records from
    all files are merged into one dict (a key seen in a later file replaces
    the earlier record) and flattened into three parallel lists:
    ``images``, ``labels`` and ``series_uid``.

    Each item is ``(image, label, series_uid)``. ``image`` is the stored image
    with a new leading axis of size 1, and ``label`` is one of that image's
    masks chosen uniformly at random. Both are returned as float32 tensors.
    """

    # Upper bound on a single read() call when loading a pickle file.
    # NOTE(review): presumably a workaround for platforms where a single
    # read() of 2 GiB or more fails; confirm before removing.
    _MAX_READ_BYTES = 2**31 - 1

    def __init__(self, dataset_location, transform=None):
        """Load and validate every pickle file in ``dataset_location``.

        Args:
            dataset_location: Directory to scan (not recursive), given as a
                str, bytes or path-like object. A trailing path separator is
                optional.
            transform: Optional callable applied in ``__getitem__`` to the
                image after the leading axis has been added. It may return
                either a numpy array or a tensor.

        Raises:
            AssertionError: If the image/label/uid lists disagree in length,
                or if any image or mask value lies outside ``[0, 1]``.
        """
        self.transform = transform
        # Per-instance storage. Class-level lists would be shared by every
        # instance, so a second dataset (e.g. a test split) would silently
        # contain the first one's samples as well.
        self.images = []
        self.labels = []
        self.series_uid = []

        root = os.fsdecode(dataset_location)
        data = {}
        for file in os.listdir(root):
            filename = os.fsdecode(file)
            if '.pickle' in filename:
                print("Loading file", filename)
                # os.path.join works with or without a trailing separator.
                file_path = os.path.join(root, filename)
                data.update(self._read_pickle(file_path))

        for value in data.values():
            self.images.append(value['image'].astype(float))
            self.labels.append(value['masks'])
            self.series_uid.append(value['series_uid'])

        assert (len(self.images) == len(self.labels) == len(self.series_uid))
        # Images and masks must already be normalised to [0, 1].
        # NOTE: these asserts are skipped when Python runs with -O.
        for img in self.images:
            assert np.max(img) <= 1 and np.min(img) >= 0
        for label in self.labels:
            assert np.max(label) <= 1 and np.min(label) >= 0

    @classmethod
    def _read_pickle(cls, file_path):
        """Read ``file_path`` in bounded chunks and return the unpickled object."""
        bytes_in = bytearray(0)
        input_size = os.path.getsize(file_path)
        with open(file_path, 'rb') as f_in:
            for _ in range(0, input_size, cls._MAX_READ_BYTES):
                bytes_in += f_in.read(cls._MAX_READ_BYTES)
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # file. Only point this dataset at trusted data.
        return pickle.loads(bytes_in)

    def __getitem__(self, index):
        """Return ``(image, label, series_uid)`` for sample ``index``.

        ``label`` is drawn at random from the sample's masks, so repeated calls
        with the same index may return different labels.
        """
        image = np.expand_dims(self.images[index], axis=0)
        masks = self.labels[index]
        # Pick one annotation uniformly at random. For the usual four masks
        # this matches the former randint(0, 3); any extra masks are no
        # longer ignored.
        label = masks[random.randrange(len(masks))].astype(float)
        if self.transform is not None:
            image = self.transform(image)
        series_uid = self.series_uid[index]
        # as_tensor accepts numpy arrays and tensors alike, so a transform
        # that already returns a tensor no longer breaks the conversion.
        image = torch.as_tensor(image, dtype=torch.float32)
        label = torch.as_tensor(label, dtype=torch.float32)
        return image, label, series_uid

    def __len__(self):
        """Return the number of loaded samples (PyTorch Dataset protocol)."""
        return len(self.images)