# Based on https://github.com/lucidrains/siren-pytorch
import torch
from torch import nn
from math import sqrt


class Sine(nn.Module):
    """Sine activation with scaling.

    Args:
        w0 (float): Omega_0 parameter from SIREN paper.
    """
    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)
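

# Minimal sketch of the Sine activation (the inputs below are illustrative):
# w0 scales the argument before the sine, so a larger w0 yields a
# higher-frequency response; the output is always bounded in [-1, 1].
def _sine_example():
    sine = Sine(w0=30.)
    x = torch.linspace(-1., 1., steps=5)
    return sine(x)  # equivalent to torch.sin(30. * x)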


class SirenLayer(nn.Module):
    """Implements a single SIREN layer.

    Args:
        dim_in (int): Dimension of input.
        dim_out (int): Dimension of output.
        w0 (float): Omega_0 parameter from SIREN paper.
        c (float): c value from SIREN paper used for weight initialization.
        is_first (bool): Whether this is first layer of model.
        use_bias (bool): Whether to learn a bias in the linear layer.
        activation (torch.nn.Module): Activation function. If None, defaults to
            Sine activation.
    """
    def __init__(self, dim_in, dim_out, w0=30., c=6., is_first=False,
                 use_bias=True, activation=None):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        self.linear = nn.Linear(dim_in, dim_out, bias=use_bias)

        # Initialize layers following SIREN paper: the first layer uses
        # U(-1/dim_in, 1/dim_in); later layers use
        # U(-sqrt(c/dim_in)/w0, sqrt(c/dim_in)/w0)
        w_std = (1 / dim_in) if self.is_first else (sqrt(c / dim_in) / w0)
        nn.init.uniform_(self.linear.weight, -w_std, w_std)
        if use_bias:
            nn.init.uniform_(self.linear.bias, -w_std, w_std)

        self.activation = Sine(w0) if activation is None else activation

    def forward(self, x):
        out = self.linear(x)
        out = self.activation(out)
        return out
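

# Minimal usage sketch for SirenLayer (batch size and dimensions are
# illustrative assumptions). With is_first=True the weights are drawn from
# U(-1/dim_in, 1/dim_in), matching the initialization above.
def _siren_layer_example():
    layer = SirenLayer(dim_in=2, dim_out=64, w0=30., is_first=True)
    coords = torch.rand(16, 2) * 2. - 1.  # e.g. 2D coordinates in [-1, 1]
    return layer(coords)                  # shape (16, 64), values in [-1, 1]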


class Siren(nn.Module):
    """SIREN model.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        w0 (float): Omega_0 from SIREN paper.
        w0_initial (float): Omega_0 for first layer.
        use_bias (bool): Whether to learn biases in the linear layers.
        final_activation (torch.nn.Module): Activation function of the final
            layer. If None, defaults to identity.
    """
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0=30.,
                 w0_initial=30., use_bias=True, final_activation=None):
        super().__init__()
        layers = []
        for ind in range(num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            layers.append(SirenLayer(
                dim_in=layer_dim_in,
                dim_out=dim_hidden,
                w0=layer_w0,
                use_bias=use_bias,
                is_first=is_first
            ))

        self.net = nn.Sequential(*layers)

        final_activation = nn.Identity() if final_activation is None else final_activation
        self.last_layer = SirenLayer(dim_in=dim_hidden, dim_out=dim_out, w0=w0,
                                     use_bias=use_bias, activation=final_activation)

    def forward(self, x):
        x = self.net(x)
        return self.last_layer(x)
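

# Usage sketch for the Siren model (the hyperparameters below are assumptions,
# e.g. for mapping 2D pixel coordinates to RGB values):
def _siren_example():
    model = Siren(dim_in=2, dim_hidden=256, dim_out=3, num_layers=5,
                  w0=30., w0_initial=30.)
    coords = torch.rand(1024, 2) * 2. - 1.  # batch of (x, y) coordinates
    return model(coords)                    # shape (1024, 3)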


class MLP(nn.Module):
    """MLP model, optionally with SIREN layers at the start and/or end.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Dimension of hidden layers.
        dim_out (int): Dimension of output.
        num_layers (int): Number of layers.
        activation (torch.nn.Module): Activation function for hidden layers.
            Defaults to ReLU.
        siren_start (bool): If True, use a SirenLayer as the first layer.
        siren_end (bool): If True, use a SirenLayer as the last hidden layer.
    """
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers,
                 activation=nn.ReLU(), siren_start=False, siren_end=False):
        super().__init__()
        self.fc_layers = nn.ModuleList()

        # First layer: either a plain linear + activation block or a SIREN layer
        if not siren_start:
            self.fc_layers.extend([nn.Linear(dim_in, dim_hidden), activation])
        else:
            self.fc_layers.extend([
                SirenLayer(
                    dim_in=dim_in,
                    dim_out=dim_hidden,
                    w0=30,
                    use_bias=True,
                    is_first=True
                )])

        # Fully connected hidden layers with the chosen activation
        for ind in range(num_layers - 1):
            self.fc_layers.extend([nn.Linear(dim_hidden, dim_hidden), activation])

        # Last hidden layer: either a SIREN layer or a plain linear + activation block
        if siren_end:
            self.fc_layers.extend([
                SirenLayer(
                    dim_in=dim_hidden,
                    dim_out=dim_hidden,
                    w0=30,
                    use_bias=True,
                    is_first=False
                )])
        else:
            self.fc_layers.extend([nn.Linear(dim_hidden, dim_hidden), activation])

        # Add the last linear layer for regression
        self.encoder = nn.Sequential(*self.fc_layers, nn.Linear(dim_hidden, dim_out))

    def forward(self, x):
        x = self.encoder(x)
        return x
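

# Usage sketch for the MLP variant (hyperparameters are assumptions): with
# siren_start=True the first layer is a SirenLayer, the remaining hidden
# blocks are Linear + activation, and a final linear head produces the output.
if __name__ == "__main__":
    mlp = MLP(dim_in=2, dim_hidden=128, dim_out=1, num_layers=3,
              activation=nn.ReLU(), siren_start=True, siren_end=False)
    x = torch.rand(32, 2)
    y = mlp(x)
    print(y.shape)  # torch.Size([32, 1])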