dataset.py
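"""PyTorch Dataset for YOLOv3-style training on a hat-detection task.

Reads image and label file names from a CSV, loads YOLO-format label files
(class x y w h, with coordinates normalized to [0, 1]), and builds one target
tensor per prediction scale (13, 26, 52) whose entries per anchor are
[objectness, x, y, w, h, class].
"""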
import config
import numpy as np
import os
import pandas as pd
from torchvision.transforms import transforms
import torch
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader
from utils import (
    cells_to_bboxes,
    iou_width_height as iou,
    non_max_suppression as nms,
    plot_image,
)

# Allow PIL to load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class Compose(object):
    """Apply each transform to the image only; the normalized bboxes pass through unchanged."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, bboxes):
        for t in self.transforms:
            image = t(image)
        return image, bboxes


transform = Compose([transforms.Resize((416, 416)), transforms.ToTensor()])


class HatDataset(Dataset):
    def __init__(
        self,
        csv_file,
        img_dir,
        label_dir,
        anchors,
        image_size=416,
        S=[13, 26, 52],
        C=2,
        transform=None,
    ):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.image_size = image_size
        self.transform = transform
        self.S = S
        # Flatten the anchors of all three scales into a single (9, 2) tensor.
        self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2])
        self.num_anchors = self.anchors.shape[0]
        self.num_anchors_per_scale = self.num_anchors // 3
        self.C = C
        self.ignore_iou_thresh = 0.5

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        # Labels are stored in YOLO format: class x y w h (coordinates normalized to [0, 1]).
        # np.roll moves the class id to the end, giving [x, y, w, h, class].
        label_path = os.path.join(
            self.label_dir, self.annotations.iloc[index, 1])
        bboxes = np.roll(np.loadtxt(fname=label_path,
                                    delimiter=" ", ndmin=2), 4, axis=1).tolist()
        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image, bboxes = self.transform(image=image, bboxes=bboxes)

        # One target tensor per scale, each of shape (anchors_per_scale, S, S, 6)
        # with entries [objectness, x, y, w, h, class].
        targets = [torch.zeros((self.num_anchors // 3, S, S, 6))
                   for S in self.S]

        for box in bboxes:
            # Rank all nine anchors by IoU of width/height with this box.
            iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)
            anchor_indices = iou_anchors.argsort(descending=True, dim=0)
            x, y, width, height, class_label = box
            has_anchor = [False] * 3

            for anchor_idx in anchor_indices:
                scale_idx = anchor_idx // self.num_anchors_per_scale
                anchor_on_scale = anchor_idx % self.num_anchors_per_scale
                S = self.S[scale_idx]
                i, j = int(S * y), int(S * x)  # cell row and column
                anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]

                if not anchor_taken and not has_anchor[scale_idx]:
                    # Best free anchor on this scale: mark it responsible for the box.
                    targets[scale_idx][anchor_on_scale, i, j, 0] = 1
                    x_cell, y_cell = S * x - j, S * y - i  # offsets within the cell, in [0, 1)
                    width_cell, height_cell = (
                        width * S,
                        height * S,
                    )  # width and height measured in cell units
                    box_coordinates = torch.tensor(
                        [x_cell, y_cell, width_cell, height_cell]
                    )
                    targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
                    targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
                    has_anchor[scale_idx] = True
                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    # Anchor overlaps the box strongly but is not responsible: ignore it in the loss.
                    targets[scale_idx][anchor_on_scale, i, j, 0] = -1

        return image, tuple(targets)
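

# Minimal usage sketch (not part of the original module). The CSV layout
# (image filename in column 0, label filename in column 1), the directory
# paths, and the anchor values below are assumptions for illustration only;
# the anchors shown are the standard YOLOv3 anchors rescaled to [0, 1].
if __name__ == "__main__":
    anchors = [
        [(0.28, 0.22), (0.38, 0.48), (0.90, 0.78)],  # anchors for S = 13
        [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],  # anchors for S = 26
        [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],  # anchors for S = 52
    ]
    dataset = HatDataset(
        csv_file="train.csv",   # assumed path
        img_dir="images/",      # assumed path
        label_dir="labels/",    # assumed path
        anchors=anchors,
        transform=transform,    # the Resize + ToTensor pipeline defined above
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    images, targets = next(iter(loader))
    print(images.shape)                  # torch.Size([8, 3, 416, 416])
    print([t.shape for t in targets])    # shapes (8, 3, S, S, 6) for S = 13, 26, 52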