import math

import torch
import torch.nn as nn


class contrastive_loss(nn.Module):
    """NT-Xent contrastive loss over a batch of paired embeddings (xi, xj)."""

    def __init__(self, tau=1, normalize=False):
        super(contrastive_loss, self).__init__()
        self.tau = tau
        self.normalize = normalize

    def forward(self, xi, xj):
        x = torch.cat((xi, xj), dim=0)  # [2 * batch, dim]
        is_cuda = x.is_cuda

        # Pairwise similarities between all 2 * batch embeddings.
        sim_mat = torch.mm(x, x.T)
        if self.normalize:
            sim_mat_denom = torch.mm(torch.norm(x, dim=1).unsqueeze(1),
                                     torch.norm(x, dim=1).unsqueeze(1).T)
            sim_mat = sim_mat / sim_mat_denom.clamp(min=1e-16)
        sim_mat = torch.exp(sim_mat / self.tau)

        # The diagonal (self-similarity) must not appear in the denominator.
        # Rather than masking it out (noted as not differentiable here),
        # exp(1 / tau) -- the value of each diagonal entry when similarities
        # are normalized, since cos(x, x) = 1 -- is subtracted from the row
        # sums below.
        # diag_ind = torch.eye(xi.size(0) * 2).bool()
        # diag_ind = diag_ind.cuda() if is_cuda else diag_ind
        # sim_mat = sim_mat.masked_fill_(diag_ind, 0)

        # Numerator: similarity of each positive pair (xi_k, xj_k).
        if self.normalize:
            sim_mat_denom = torch.norm(xi, dim=1) * torch.norm(xj, dim=1)
            sim_match = torch.exp(torch.sum(xi * xj, dim=-1) / sim_mat_denom / self.tau)
        else:
            sim_match = torch.exp(torch.sum(xi * xj, dim=-1) / self.tau)
        sim_match = torch.cat((sim_match, sim_match), dim=0)

        # Remove the diagonal contribution exp(1 / tau) from each row sum.
        norm_sum = torch.exp(torch.ones(x.size(0)) / self.tau)
        norm_sum = norm_sum.cuda() if is_cuda else norm_sum
        loss = torch.mean(-torch.log(sim_match / (torch.sum(sim_mat, dim=-1) - norm_sum)))
        return loss
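
# Usage sketch (illustrative only; the batch size, embedding size and tau
# below are assumptions, not values from this repository):
#
#     criterion = contrastive_loss(tau=0.5, normalize=True)
#     xi = torch.randn(64, 128)  # embeddings of view 1
#     xj = torch.randn(64, 128)  # embeddings of view 2 of the same samples
#     loss = criterion(xi, xj)   # scalar tensor
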
def matrix_log_density_gaussian(x, mu, logvar):
    """Pairwise log density: out[i, j, d] = log N(x[i, d]; mu[j, d], exp(logvar[j, d]))."""
    batch_size, dim = x.shape
    x = x.view(batch_size, 1, dim)
    mu = mu.view(1, batch_size, dim)
    logvar = logvar.view(1, batch_size, dim)
    return log_density_gaussian(x, mu, logvar)


def log_density_gaussian(x, mu, logvar):
    """Element-wise log density of a diagonal Gaussian."""
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((x - mu) ** 2 * inv_var)
    return log_density
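
# Sanity check (illustrative, not part of the original file): with
# std = exp(0.5 * logvar), the closed form above matches
# torch.distributions.Normal, since
# log N(x; mu, std^2) = -0.5 * (log(2*pi) + logvar) - 0.5 * (x - mu)**2 / exp(logvar):
#
#     x, mu, logvar = torch.randn(3), torch.zeros(3), torch.zeros(3)
#     ref = torch.distributions.Normal(mu, torch.exp(0.5 * logvar)).log_prob(x)
#     assert torch.allclose(log_density_gaussian(x, mu, logvar), ref)
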
def log_importance_weight_matrix(batch_size, dataset_size):
    """Log importance weights for the minibatch stratified sampling (MSS)
    estimator of the aggregate posterior q(z), as in the beta-TCVAE
    reference implementation (Chen et al., 2018)."""
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
    W.view(-1)[::M + 1] = 1 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()
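
# Illustrative usage (the sizes here are assumptions, not from the repo):
#
#     log_iw = log_importance_weight_matrix(batch_size=3, dataset_size=10)
#     iw = log_iw.exp()  # [3, 3]; every entry is one of 1/N,
#                        # (N - M) / (N * M), or 1/M
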
def compute_mi(latent_sample, latent_dist):
    """Estimate the mutual information I[z; x] between latents and inputs."""
    log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(latent_sample,
                                                                         latent_dist,
                                                                         None,
                                                                         is_mss=False)
    # I[z;x] = KL[q(z,x) || q(x)q(z)] = E_x[ KL[q(z|x) || q(z)] ]
    mi_loss = (log_q_zCx - log_qz).mean()
    return mi_loss
def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True):
    batch_size, hidden_dim = latent_sample.shape

    # log q(z|x): density of each sample under its own posterior.
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # log p(z) under a standard normal prior (mean and log-variance 0) is not
    # needed for the MI estimate, so it is left commented out:
    # zeros = torch.zeros_like(latent_sample)
    # log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    # mat_log_qz[i, j, d] = log q(z_i[d] | x_j): density of each sample under
    # every posterior in the batch, used to estimate the aggregate q(z).
    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)
    if is_mss:
        # Use minibatch stratified sampling; n_data is the dataset size.
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
        mat_log_qz = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)

    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
    # log_prod_qzi = torch.logsumexp(mat_log_qz, dim=1, keepdim=False).sum(1)

    # log_pz and log_prod_qzi are unused by compute_mi, so None is returned
    # in their place.
    return None, log_qz, None, log_q_zCx
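

if __name__ == "__main__":
    # Minimal smoke test (illustrative assumptions: a batch of 8 samples with
    # a 4-d latent; none of these numbers come from the original repository).
    torch.manual_seed(0)
    mu = torch.randn(8, 4)
    logvar = torch.randn(8, 4)
    # Reparameterized sample from q(z|x).
    latent_sample = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
    mi = compute_mi(latent_sample, (mu, logvar))
    print("MI estimate:", mi.item())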