-
Notifications
You must be signed in to change notification settings - Fork 662
/
Copy pathmargin_loss.py
106 lines (89 loc) · 3.3 KB
/
margin_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from ..reducers import DivisorReducer
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction
class MarginLoss(BaseMetricLossFunction):
def __init__(
self,
margin=0.2,
nu=0,
beta=1.2,
triplets_per_anchor="all",
learn_beta=False,
num_classes=None,
**kwargs
):
super().__init__(**kwargs)
self.margin = margin
self.nu = nu
self.learn_beta = learn_beta
self.initialize_beta(beta, num_classes)
self.triplets_per_anchor = triplets_per_anchor
self.add_to_recordable_attributes(
list_of_names=["margin", "nu", "beta"], is_stat=False
)
def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
c_f.labels_or_indices_tuple_required(labels, indices_tuple)
indices_tuple = lmu.convert_to_triplets(
indices_tuple, labels, ref_labels, self.triplets_per_anchor
)
anchor_idx, positive_idx, negative_idx = indices_tuple
if len(anchor_idx) == 0:
return self.zero_losses()
beta = (
self.beta
if len(self.beta) == 1
else self.beta[labels[anchor_idx].to("cpu")]
) # When labels are on gpu gives error
beta = c_f.to_device(beta, device=embeddings.device, dtype=embeddings.dtype)
mat = self.distance(embeddings, ref_emb)
d_ap = mat[anchor_idx, positive_idx]
d_an = mat[anchor_idx, negative_idx]
pos_loss = torch.nn.functional.relu(
self.distance.margin(d_ap, beta) + self.margin
)
neg_loss = torch.nn.functional.relu(
self.distance.margin(beta, d_an) + self.margin
)
num_pos_pairs = torch.sum(pos_loss > 0.0)
num_neg_pairs = torch.sum(neg_loss > 0.0)
divisor = num_pos_pairs + num_neg_pairs
margin_loss = pos_loss + neg_loss
loss_dict = {
"margin_loss": {
"losses": margin_loss,
"indices": indices_tuple,
"reduction_type": "triplet",
"divisor": divisor,
},
"beta_reg_loss": self.compute_reg_loss(beta, anchor_idx, divisor),
}
return loss_dict
def compute_reg_loss(self, beta, anchor_idx, divisor):
if self.learn_beta:
loss = beta * self.nu
if len(self.beta) == 1:
return {
"losses": loss,
"indices": None,
"reduction_type": "already_reduced",
}
else:
return {
"losses": loss,
"indices": anchor_idx,
"reduction_type": "element",
"divisor": divisor,
}
return self.zero_loss()
def _sub_loss_names(self):
return ["margin_loss", "beta_reg_loss"]
def get_default_reducer(self):
return DivisorReducer()
def initialize_beta(self, beta, num_classes):
self.beta = torch.tensor([float(beta)])
if num_classes:
self.beta = torch.ones(num_classes) * self.beta
if self.learn_beta:
self.beta = torch.nn.Parameter(self.beta)