@@ -6,20 +6,6 @@
 
 from .VariationalLayer import VariationalLayer
 
-def compute_log_alpha(log_sigma, theta):
-    r'''
-    Compute the log \alpha values from \theta and log \sigma^2.
-
-    The relationship between \sigma^2, \theta, and \alpha as defined in the
-    paper https://arxiv.org/abs/1701.05369 is \sigma^2 = \alpha * \theta^2.
-
-    This method calculates the log \alpha values based on this relation:
-    \log(\alpha) = 2*\log(\sigma) - 2*\log(\theta)
-    '''
-    log_alpha = log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(theta))
-    log_alpha = torch.clamp(log_alpha, -10, 10)  # clipping for numerical stability
-    return log_alpha
-
 # Linear Sparse Variational Dropout
 # See https://arxiv.org/pdf/1701.05369.pdf for details
 class LinearSVD(nn.Linear, VariationalLayer):
@@ -50,18 +36,32 @@ def __init__(self, in_features, out_features, p_threshold = 0.952572, bias=True)
         self.log_sigma.data.fill_(-5)  # Initialization based on the paper, Figure 1
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.log_alpha = self.compute_log_alpha(self.log_sigma, torch.abs(self.weight))
+
         if self.training:
             # LRT = local reparametrization trick (for details, see https://arxiv.org/pdf/1506.02557.pdf)
             lrt_mean = F.linear(x, self.weight, self.bias)
             lrt_std = torch.sqrt(F.linear(x * x, torch.exp(self.log_sigma * 2.0)) + 1e-8)
             eps = torch.normal(0, torch.ones_like(lrt_std))
             return lrt_mean + lrt_std * eps
 
-        self.log_alpha = compute_log_alpha(self.log_sigma, torch.abs(self.weight))
         return F.linear(x, self.weight * (self.log_alpha < self.log_alpha_threshold).float(), self.bias)
 
     def kl_reg(self):
-        k1, k2, k3 = torch.Tensor([0.63576]).cuda(), torch.Tensor([1.8732]).cuda(), torch.Tensor([1.48695]).cuda()
         k1, k2, k3 = torch.Tensor([0.63576]).cuda(), torch.Tensor([1.8732]).cuda(), torch.Tensor([1.48695]).cuda()
         kl = k1 * torch.sigmoid(k2 + k3 * self.log_alpha) - 0.5 * torch.log1p(torch.exp(-self.log_alpha))
         return -(torch.sum(kl))
+
+    def compute_log_alpha(self, log_sigma, theta):
+        r'''
+        Compute the log \alpha values from \theta and log \sigma^2.
+
+        The relationship between \sigma^2, \theta, and \alpha as defined in the
+        paper https://arxiv.org/abs/1701.05369 is \sigma^2 = \alpha * \theta^2.
+
+        This method calculates the log \alpha values based on this relation:
+        \log(\alpha) = 2*\log(\sigma) - 2*\log(\theta)
+        '''
+        log_alpha = log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(theta))
+        log_alpha = torch.clamp(log_alpha, -10, 10)  # clipping for numerical stability
+        return log_alpha
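
A note on the `p_threshold = 0.952572` default visible in the hunk header: in the sparse VD paper, a weight's effective binary dropout rate relates to \alpha via p = \alpha / (1 + \alpha), and weights with \log\alpha above roughly 3 are pruned. Below is a minimal sketch of that correspondence, assuming the constructor (not shown in this diff) derives `self.log_alpha_threshold` as \log(p / (1 - p)); the helper name is hypothetical:

```python
import math

# Hypothetical helper: convert a binary-dropout-rate threshold p into the
# corresponding threshold on log(alpha), using p = alpha / (1 + alpha).
def log_alpha_threshold_from_p(p_threshold: float) -> float:
    return math.log(p_threshold / (1.0 - p_threshold))

print(log_alpha_threshold_from_p(0.952572))  # ~3.0, the pruning threshold used in the paper
```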
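For anyone trying the refactor locally, here is a minimal usage sketch (not part of this diff): `forward` samples via the local reparametrization trick during training, and `kl_reg()` returns the layer's approximate KL term, which is added to the task loss as in the sparse VD objective. The model shape, data, and `kl_weight` value below are assumptions, and the sketch needs a GPU because `kl_reg` hard-codes `.cuda()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumes LinearSVD from this module is importable / in scope.
model = nn.Sequential(LinearSVD(784, 300), nn.ReLU(), LinearSVD(300, 10)).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def sparse_vd_loss(logits, targets, kl_weight):
    # Task term plus the summed KL approximation from every LinearSVD layer.
    kl = sum(m.kl_reg() for m in model.modules() if isinstance(m, LinearSVD))
    return F.cross_entropy(logits, targets) + kl_weight * kl

x = torch.randn(64, 784).cuda()
y = torch.randint(0, 10, (64,)).cuda()

model.train()
logits = model(x)                                  # training path -> LRT sampling
loss = sparse_vd_loss(logits, y, kl_weight=0.01)   # kl_weight is typically annealed from 0 to 1
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

At eval time (`model.eval()`), the second branch of `forward` applies the hard mask `self.log_alpha < self.log_alpha_threshold`, so weights whose `log_alpha` exceeds the threshold are zeroed out; moving the `compute_log_alpha` call to the top of `forward` in this change also keeps `self.log_alpha` up to date for `kl_reg()` in both modes.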