
Commit 8ecc20e

Fixing an error in the log_alpha computation
1 parent 616d955 commit 8ecc20e

File tree

1 file changed: +16 -16 lines changed


Sparse/modules/variational/LinearSparseVariationalDropout.py (+16 -16)
@@ -6,20 +6,6 @@
 
 from .VariationalLayer import VariationalLayer
 
-def compute_log_alpha(log_sigma, theta):
-    r'''
-    Compute the log \alpha values from \theta and log \sigma^2.
-
-    The relationship between \sigma^2, \theta, and \alpha as defined in the
-    paper https://arxiv.org/abs/1701.05369 is \sigma^2 = \alpha * \theta^2.
-
-    This method calculates the log \alpha values based on this relation:
-    \log(\alpha) = 2*\log(\sigma) - 2*\log(\theta)
-    '''
-    log_alpha = log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(theta))
-    log_alpha = torch.clamp(log_alpha, -10, 10) # clipping for numerical stability
-    return log_alpha
-
 # Linear Sparse Variational Dropout
 # See https://arxiv.org/pdf/1701.05369.pdf for details
 class LinearSVD(nn.Linear, VariationalLayer):
@@ -50,18 +36,32 @@ def __init__(self, in_features, out_features, p_threshold = 0.952572, bias=True)
         self.log_sigma.data.fill_(-5) # Initialization based on the paper, Figure 1
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.log_alpha = self.compute_log_alpha(self.log_sigma, torch.abs(self.weight))
+
         if self.training:
             # LRT = local reparametrization trick (For details, see https://arxiv.org/pdf/1506.02557.pdf)
             lrt_mean = F.linear(x, self.weight, self.bias)
             lrt_std = torch.sqrt(F.linear(x * x, torch.exp(self.log_sigma * 2.0)) + 1e-8)
             eps = torch.normal(0, torch.ones_like(lrt_std))
             return lrt_mean + lrt_std * eps
 
-        self.log_alpha = compute_log_alpha(self.log_sigma, torch.abs(self.weight))
         return F.linear(x, self.weight * (self.log_alpha < self.log_alpha_threshold).float(), self.bias)
 
     def kl_reg(self):
-        k1, k2, k3 = torch.Tensor([0.63576]).cuda(), torch.Tensor([1.8732]).cuda(), torch.Tensor([1.48695]).cuda()
         k1, k2, k3 = torch.Tensor([0.63576]).cuda(), torch.Tensor([1.8732]).cuda(), torch.Tensor([1.48695]).cuda()
         kl = k1 * torch.sigmoid(k2 + k3 * self.log_alpha) - 0.5 * torch.log1p(torch.exp(-self.log_alpha))
         return -(torch.sum(kl))
+
+    def compute_log_alpha(self, log_sigma, theta):
+        r'''
+        Compute the log \alpha values from \theta and log \sigma^2.
+
+        The relationship between \sigma^2, \theta, and \alpha as defined in the
+        paper https://arxiv.org/abs/1701.05369 is \sigma^2 = \alpha * \theta^2.
+
+        This method calculates the log \alpha values based on this relation:
+        \log(\alpha) = 2*\log(\sigma) - 2*\log(\theta)
+        '''
+        log_alpha = log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(theta))
+        log_alpha = torch.clamp(log_alpha, -10, 10) # clipping for numerical stability
+        return log_alpha
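
As a quick illustration of the relation in the compute_log_alpha docstring (not part of the commit): if \sigma^2 = \alpha * \theta^2, then \log(\alpha) = \log(\sigma^2) - \log(\theta^2) = 2*\log(\sigma) - 2*\log(\theta). A minimal numerical sanity check of this identity, using arbitrary example values:

import torch

theta = torch.tensor([0.5, -1.5, 2.0])               # example weights
alpha = torch.tensor([0.1, 1.0, 4.0])                # example dropout rates
sigma = torch.sqrt(alpha) * torch.abs(theta)         # so that sigma^2 = alpha * theta^2
log_sigma = torch.log(sigma)

# Same expression as in compute_log_alpha (without the clamp)
log_alpha = log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(theta))
print(torch.allclose(log_alpha, torch.log(alpha), atol=1e-4))  # True, up to the 1e-16 guard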

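For context on why the log_alpha computation moved to the top of forward(): kl_reg() reads self.log_alpha, which previously was only assigned on the non-training path, so the attribute could be stale or missing while training. Below is a minimal training-loop sketch (not part of the commit) assuming hypothetical names (model, optimizer, loader, kl_weight), an assumed import path based on the file location, and a CUDA device since kl_reg() builds its constants with .cuda():

import torch
import torch.nn.functional as F
# Import path assumed from the file's location in the repository
from Sparse.modules.variational.LinearSparseVariationalDropout import LinearSVD

model = LinearSVD(784, 10).cuda()          # hypothetical layer sizes; on GPU to match kl_reg()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
kl_weight = 0.02                           # hypothetical scaling of the KL regularizer

model.train()
for x, y in loader:                        # loader is a hypothetical DataLoader
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    logits = model(x)                      # forward() now refreshes model.log_alpha on every call
    loss = F.cross_entropy(logits, y) + kl_weight * model.kl_reg()
    loss.backward()
    optimizer.step()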