-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
93 lines (75 loc) · 3.08 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn as nn
from utils import intersection_over_union
class YoloLoss(nn.Module):
    """Sum-squared-error detection loss in the style of the YOLO (v1) paper.

    Each grid cell's channel vector is laid out as::

        [C class scores, (confidence, 4 box values) * B]

    The loss is the weighted sum of four terms: box regression for the
    responsible predicted box, objectness for that box, "no object"
    confidence for every predicted box, and class scores.

    Args:
        S: Grid size; predictions are reshaped to ``(-1, S, S, C + B * 5)``.
        B: Number of boxes predicted per grid cell.
        C: Number of class channels.
    """

    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        # Summed (not averaged) squared error, so every term below is a
        # plain sum over all cells and channels it covers.
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        # These are from the YOLO paper: they set how much the no-object
        # confidence term and the box-coordinate term contribute to the
        # total loss (see the weighted sum at the end of ``forward``).
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, target):
        """Return the scalar loss for a batch.

        Args:
            predictions: Tensor reshapeable to ``(-1, S, S, C + B * 5)``.
            target: Tensor indexed as ``(..., channel)`` holding the class
                scores in ``[..., :C]``, the objectness flag in ``[..., C]``
                and the ground-truth box in ``[..., C + 1:C + 5]``. Only
                these first ``C + 5`` channels are read. Presumably shaped
                ``(BATCH_SIZE, S, S, >= C + 5)`` to match the reshaped
                predictions -- confirm against the dataset code.

        Returns:
            Zero-dimensional tensor: ``lambda_coord * box_loss + object_loss
            + lambda_noobj * no_object_loss + class_loss``.
        """
        num_classes = self.C
        predictions = predictions.reshape(
            -1, self.S, self.S, num_classes + self.B * 5
        )

        # Ground-truth objectness (kept as a trailing size-1 axis so it
        # broadcasts over channels) and the single ground-truth box.
        target_conf = target[..., num_classes:num_classes + 1]
        target_box = target[..., num_classes + 1:num_classes + 5]
        exists_box = target_conf  # obj_i indicator for each cell

        # Split the B predicted (confidence, box) groups out of the channels.
        pred_confs = []
        pred_boxes = []
        for b in range(self.B):
            offset = num_classes + 5 * b
            pred_confs.append(predictions[..., offset:offset + 1])
            pred_boxes.append(predictions[..., offset + 1:offset + 5])

        # Score EVERY predicted box against the same ground-truth box. The
        # previous code compared the second box with target[..., 26:30],
        # which is not where the ground-truth box lives.
        # NOTE(review): assumes intersection_over_union returns a tensor that
        # broadcasts against the 4-channel box slices (e.g. a trailing
        # size-1 axis) -- confirm in utils.
        ious = torch.stack(
            [intersection_over_union(box, target_box) for box in pred_boxes],
            dim=0,
        )
        _, bestbox = torch.max(ious, dim=0)

        # Pick, per cell, the box (and its confidence) with the highest IoU.
        best_box = pred_boxes[0]
        best_conf = pred_confs[0]
        for b in range(1, self.B):
            is_best = bestbox == b
            best_box = torch.where(is_best, pred_boxes[b], best_box)
            best_conf = torch.where(is_best, pred_confs[b], best_conf)

        # ---- Box coordinate loss (only cells that contain an object) ----
        box_predictions = exists_box * best_box
        box_targets = exists_box * target_box

        # The last two box values are compared in square-root space, as in
        # the YOLO paper (presumably width/height -- confirm the box format
        # used by utils). Predictions may be negative, so take the root of
        # the magnitude and restore the sign. The epsilon sits OUTSIDE
        # abs() so the argument of sqrt is never exactly zero, which keeps
        # the gradient finite.
        pred_wh = box_predictions[..., 2:4]
        pred_wh = torch.sign(pred_wh) * torch.sqrt(torch.abs(pred_wh) + 1e-6)
        box_predictions = torch.cat([box_predictions[..., :2], pred_wh], dim=-1)
        box_targets = torch.cat(
            [box_targets[..., :2], torch.sqrt(box_targets[..., 2:4])], dim=-1
        )

        # With reduction="sum" the result does not depend on tensor shape,
        # so no flattening is needed before calling the criterion.
        box_loss = self.mse(box_predictions, box_targets)

        # ---- Object loss: confidence of the responsible box ----
        object_loss = self.mse(exists_box * best_conf, exists_box * target_conf)

        # ---- No-object loss: confidence of every box in empty cells ----
        no_box = 1 - exists_box
        no_object_loss = sum(
            self.mse(no_box * conf, no_box * target_conf) for conf in pred_confs
        )

        # ---- Class loss (only cells that contain an object) ----
        class_loss = self.mse(
            exists_box * predictions[..., :num_classes],
            exists_box * target[..., :num_classes],
        )

        return (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )