# loss.py
"""
This module is for contructing loss function.
"""
import torch
import math
def calculate_distance_and_similariy_label(features, features_, labels, labels_, sqrt=True, pair_type='vector'):
    """
    The calculation is based on the following formulation.
    X: (N, M)
    Y: (P, M)
    Each row represents one sample.
    The squared pairwise distance between X and Y is
        D[i, j] = ||X[i]||^2 + ||Y[j]||^2 - 2 * <X[i], Y[j]>.
    Args:
        features: (N, M)
        features_: (N, M) for "vector", (P, M) for "matrix"
        labels: (N,)
        labels_: (N,) for "vector", (P,) for "matrix"
        sqrt: if True, return Euclidean distances; otherwise squared distances.
        pair_type: str
            "vector": generating N pairs (row i of features vs. row i of features_)
            "matrix": generating N*P pairs (every row of features vs. every row of features_)
    Returns:
        pairwise_distances: (N,) for "vector", (N, P) for "matrix"
        pairwise_similarity_labels: (N,) for "vector", (N, P) for "matrix"
    """
    def get_squared_features(features):
        """
        Element-wise operation: squared L2 norm of each row, kept as a column vector.
        """
        features_l2_norm = torch.sum(torch.pow(features, 2), dim=1, keepdim=True)
        return features_l2_norm
    # reshape labels for convenience
    if pair_type is None or pair_type == 'matrix':
        labels = labels.view(-1, 1)
        labels_ = labels_.view(-1, 1)
        # calculate pairwise distances
        squared_features = get_squared_features(features)
        squared_features_ = get_squared_features(features_).permute(1, 0)
        correlation_term = torch.mm(features, features_.permute(1, 0))
        pairwise_distances = squared_features + squared_features_ - 2. * correlation_term
        # calculate pairwise similarity labels
        num_labels = labels.size(0)
        num_labels_ = labels_.size(0)
        tiled_labels = labels.repeat(1, num_labels_)
        tiled_labels_ = labels_.repeat(num_labels, 1)
        pairwise_similarity_labels = torch.eq(tiled_labels.view(-1), tiled_labels_.view(-1))
        pairwise_similarity_labels = pairwise_similarity_labels.view(num_labels, num_labels_)
    elif pair_type == 'vector':
        pairwise_distances = torch.sum(torch.pow(features - features_, 2), dim=1)
        pairwise_similarity_labels = torch.eq(labels, labels_)
    if sqrt:
        # return sqrt(distance); the epsilon keeps the gradient stable near zero
        pairwise_distances = torch.sqrt(pairwise_distances + 1e-8)
    # .float() keeps the similarity labels on the same device as the inputs,
    # so the module also runs on CPU-only machines.
    return pairwise_distances, pairwise_similarity_labels.float()
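# A minimal usage sketch of the function above (the shapes and label values here are
# illustrative choices for the example, not anything prescribed by the module):
#   x, y = torch.randn(4, 8), torch.randn(6, 8)                      # N=4, P=6, M=8
#   lx, ly = torch.randint(0, 3, (4,)), torch.randint(0, 3, (6,))
#   d, s = calculate_distance_and_similariy_label(x, y, lx, ly, pair_type='matrix')
#   # d: (4, 6) Euclidean distances; s: (4, 6) float indicators of matching labels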
"""
def linear_transform(x, mean_value=0.5, std_value=2.):
y = (x - mean_value) / std_value
return y
"""
def contrastive_loss(pairwise_distances, pairwise_similarity_labels, margin=1):
    """
    Formulate the contrastive loss.
    """
    # positive pair loss
    positive_pair_loss = torch.pow(pairwise_distances, 2) * pairwise_similarity_labels
    positive_pair_loss = torch.mean(positive_pair_loss)
    # negative pair loss
    negative_pair_loss = (1. - pairwise_similarity_labels) * torch.pow(torch.clamp(margin - pairwise_distances, 0.0), 2)
    negative_pair_loss = torch.mean(negative_pair_loss)
    loss = positive_pair_loss + negative_pair_loss
    return loss, positive_pair_loss, negative_pair_loss
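# For reference, a restatement of what contrastive_loss computes, with d the pairwise
# distances and s the binary similarity labels:
#   L = mean(s * d^2) + mean((1 - s) * max(margin - d, 0)^2)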
def triplet_loss(pairwise_distances, pairwise_similarity_labels, margin=1):
    """Create the triplet loss."""
    anchor_positive_distance = torch.unsqueeze(pairwise_distances, 2)
    anchor_negative_distance = torch.unsqueeze(pairwise_distances, 1)
    triplet_loss = margin + anchor_positive_distance - anchor_negative_distance
    i_equal_j = torch.unsqueeze(pairwise_similarity_labels, 2)
    j_equal_k = torch.unsqueeze(pairwise_similarity_labels, 1)
    mask = i_equal_j * (1 - j_equal_k)
    effective_pair_num = torch.sum(mask)
    # hinge loss
    triplet_loss = torch.clamp(triplet_loss, 0.)
    # apply mask
    triplet_loss = triplet_loss * mask
    # average loss
    # triplet_loss = torch.sum(triplet_loss) / effective_pair_num
    triplet_loss = 3.0 * torch.mean(triplet_loss)
    return triplet_loss, triplet_loss, triplet_loss
def focal_triplet_loss(pairwise_distances, pairwise_similarity_labels, margin=1, mean_value=0.5, std_value=2.0):
    def linear_transform(x, mean_value=0.5, std_value=2.):
        y = (x - mean_value) / std_value
        return y
    anchor_positive_distance = torch.unsqueeze(pairwise_distances, 2)
    anchor_negative_distance = torch.unsqueeze(pairwise_distances, 1)
    triplet_loss = margin + anchor_positive_distance - anchor_negative_distance
    i_equal_j = torch.unsqueeze(pairwise_similarity_labels, 2)
    j_equal_k = torch.unsqueeze(pairwise_similarity_labels, 1)
    mask = i_equal_j * (1 - j_equal_k)
    effective_pair_num = torch.sum(mask)
    # hinge loss
    triplet_loss = torch.clamp(triplet_loss, 0.)
    # Get focal factor
    factor = 2. * torch.sigmoid(linear_transform(triplet_loss, mean_value, std_value))
    # apply mask and factor
    triplet_loss = factor * triplet_loss * mask
    # average loss
    # triplet_loss = torch.sum(triplet_loss) / effective_pair_num
    triplet_loss = 3.0 * torch.mean(triplet_loss)
    return triplet_loss, triplet_loss, triplet_loss
def focal_contrastive_loss(pairwise_distances, pairwise_similarity_labels, margin, mean_value=0.5, std_value=2.):
    def linear_transform(x, mean_value=0.5, std_value=2.):
        y = (x - mean_value) / std_value
        return y
    positive_factor = 2. * torch.sigmoid(linear_transform(pairwise_distances, mean_value, std_value))
    positive_pair_loss = torch.mean(positive_factor * torch.pow(pairwise_distances, 2) * pairwise_similarity_labels)
    negative_distances = torch.clamp(margin - pairwise_distances, 0.0)
    negative_factor = 2. * torch.sigmoid(linear_transform(negative_distances, mean_value, std_value))
    negative_pair_loss = torch.mean(negative_factor * (1. - pairwise_similarity_labels) *
                                    torch.pow(negative_distances, 2))
    loss = positive_pair_loss + negative_pair_loss
    return loss, positive_pair_loss, negative_pair_loss
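# The focal factor in both functions above is 2 * sigmoid((x - mean_value) / std_value).
# A quick worked example with the defaults (mean_value=0.5, std_value=2.0): at x = 0.5
# the factor is 2 * sigmoid(0) = 1, and it approaches 2 as x grows, so harder pairs and
# triplets (larger residual loss or distance) receive larger weights.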
def angular_loss(ap_distances, cn_distances, pairwise_similarity_labels, alpha):
    """
    Args:
        ap_distances: anchor and positive example distances.
        cn_distances: center and negative example distances.
        ...
    """
    coeff = 4 * math.tan(alpha / 180.0 * math.pi) ** 2
    ap_distances = torch.unsqueeze(ap_distances, 2)
    cn_distances = torch.unsqueeze(cn_distances, 1)
    batch_size = 64
    total_num = batch_size ** 3
    i_equal_j = torch.unsqueeze(pairwise_similarity_labels, 2)
    j_equal_k = torch.unsqueeze(pairwise_similarity_labels, 1)
    mask = i_equal_j * (1 - j_equal_k)
    loss = torch.pow(ap_distances, 2) - coeff * torch.pow(cn_distances, 2)
    loss = torch.clamp(loss, 0)
    loss = loss * mask
    loss = 10 * torch.mean(loss)
    """
    loss = torch.log(1 + torch.sum(mask * torch.exp(
        torch.pow(ap_distances, 2) - coeff*torch.pow(cn_distances, 2))))
    """
    return loss, loss, loss
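# Quick check of the coefficient above: coeff = 4 * tan(alpha)^2 with alpha in degrees,
# so alpha = 45 gives tan(alpha) = 1 and coeff = 4, i.e. per-triplet terms of the form
# max(||a - p||^2 - 4 * ||c - n||^2, 0) before masking and averaging.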
if __name__ == '__main__':
    # Smoke test with integer class labels, so that some pairs actually share a label.
    feature_1 = torch.randn(10, 5)
    label_1 = torch.randint(0, 3, (10,))
    feature_2 = torch.randn(12, 5)
    label_2 = torch.randint(0, 3, (12,))
    # The two feature sets have different sizes, so use the all-pairs ("matrix") mode.
    pair_dist, pair_sim_label = calculate_distance_and_similariy_label(
        feature_1, feature_2, label_1, label_2, pair_type='matrix')
    print(pair_dist.size(), pair_sim_label.size())
    # contrastive_loss returns (total, positive, negative) losses; unpack before printing.
    loss, positive_pair_loss, negative_pair_loss = contrastive_loss(pair_dist, pair_sim_label, 1.)
    print(loss.item(), positive_pair_loss.item(), negative_pair_loss.item())
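    # Additional sanity checks, as a sketch: the triplet-style losses expect a square
    # distance matrix built from a single batch, so compare feature_1 against itself;
    # alpha = 45 degrees below is an arbitrary choice for illustration.
    self_dist, self_sim = calculate_distance_and_similariy_label(
        feature_1, feature_1, label_1, label_1, pair_type='matrix')
    tri_loss, _, _ = triplet_loss(self_dist, self_sim, 1.)
    focal_tri_loss, _, _ = focal_triplet_loss(self_dist, self_sim, 1.)
    focal_con_loss, _, _ = focal_contrastive_loss(pair_dist, pair_sim_label, 1.)
    ang_loss, _, _ = angular_loss(self_dist, self_dist, self_sim, 45.)
    print(tri_loss.item(), focal_tri_loss.item(), focal_con_loss.item(), ang_loss.item())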