-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmutual_info.py
35 lines (31 loc) · 1.12 KB
/
mutual_info.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import torch
def logsumexp(value, dim=None, keepdim=False):
    """Compute ``value.exp().sum(dim, keepdim).log()`` without overflow.

    The per-slice maximum is subtracted before exponentiating (so very
    large entries do not overflow ``exp``) and added back at the end.

    Args:
        value: input tensor.
        dim: dimension to reduce over; must be given.
        keepdim: whether the reduced dimension is retained with size 1.

    Raises:
        ValueError: if ``dim`` is None.
    """
    if dim is None:
        raise ValueError('Must specify the dimension.')
    max_val, _ = torch.max(value, dim=dim, keepdim=True)
    shifted = value - max_val
    if not keepdim:
        max_val = max_val.squeeze(dim)
    sum_exp = torch.sum(torch.exp(shifted), dim=dim, keepdim=keepdim)
    return max_val + torch.log(sum_exp)
def log_density(sample, mu, logsigma):
    """Elementwise log-density of ``sample`` under N(mu, exp(logsigma)**2).

    Computes ``-0.5 * (((sample - mu)/sigma)**2 + 2*logsigma + log(2*pi))``
    per element; no reduction is applied. ``mu`` and ``logsigma`` are cast
    to the dtype of ``sample`` before use.
    """
    mu = mu.type_as(sample)
    logsigma = logsigma.type_as(sample)
    log_two_pi = torch.Tensor([np.log(2 * np.pi)]).type_as(sample.data)
    normalized = (sample - mu) * torch.exp(-logsigma)
    return -0.5 * (normalized.pow(2) + 2 * logsigma + log_two_pi)
def log_importance_weight_matrix(batch_size, dataset_size):
    """Build a (batch_size x batch_size) matrix of log importance weights
    for minibatch-based estimation over a dataset of ``dataset_size``
    points (minibatch stratified sampling).

    Requires ``batch_size >= 2``: ``M = batch_size - 1`` appears in two
    denominators below, so ``batch_size == 1`` divides by zero.
    """
    N = dataset_size
    M = batch_size - 1
    # Weight for the single stratified "remainder" entry per row.
    strat_weight = (N - M) / (N * M)
    # Default weight for every entry is 1/M.
    W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
    # Stride M+1 equals the row length (batch_size), so in the flattened
    # view these slices select whole COLUMNS, not the diagonal:
    #   [::M+1]  -> column 0, set to 1/N
    #   [1::M+1] -> column 1, set to strat_weight
    # NOTE(review): selecting the diagonal would need stride batch_size+1;
    # if the 1/N weight is meant for the j == i term, confirm this indexing
    # against the caller's ordering of the density matrix.
    W.view(-1)[::M+1] = 1 / N
    W.view(-1)[1::M+1] = strat_weight
    # Row M-1 gets strat_weight in column 0, overriding the 1/N set above.
    W[M-1, 0] = strat_weight
    # Return elementwise log of the weights.
    return W.log()