@@ -39,7 +39,7 @@ def forward(self, x):
39
39
class CNNEncoder (nn .Module ):
40
40
def __init__ (self , sizes ):
41
41
super ().__init__ ()
42
- self .out_seq = nn .Sequential (* [PoolingDownsampleBlock (size_in , size_out ) for size_in , size_out
42
+ self .out_seq = nn .Sequential (* [DownsampleBlock (size_in , size_out ) for size_in , size_out
43
43
in zip (sizes [0 :- 1 ], sizes [1 :])])
44
44
45
45
def forward (self , x ):
@@ -52,11 +52,11 @@ def __init__(self, sizes):
52
52
super ().__init__ ()
53
53
sizes = list (reversed (sizes ))
54
54
sizes_minus_last = sizes [0 :- 1 ]
55
- self .in_seq = nn .Sequential (* [UnPoolingUpsampleBlock (size_in , size_out , "relu" ) for size_in , size_out
55
+ self .in_seq = nn .Sequential (* [UpsampleBlock (size_in , size_out , "relu" ) for size_in , size_out
56
56
in zip (sizes_minus_last [0 :- 1 ], sizes_minus_last [1 :])])
57
57
58
- self .last = UnPoolingUpsampleBlock (
59
- sizes [- 2 ], sizes [- 1 ], activation = "relu " )
58
+ self .last = UpsampleBlock (
59
+ sizes [- 2 ], sizes [- 1 ], activation = "sigmoid " )
60
60
61
61
def forward (self , x ):
62
62
x = self .in_seq (x )
@@ -69,11 +69,14 @@ def __init__(self, size_in, size_out):
69
69
super ().__init__ ()
70
70
# Modify this to create new conv blocks
71
71
# Eg: Throw in pooling, throw in residual connections ... whatever you want
72
- self .conv_1 = nn .Conv2d (size_in , size_out , 3 , padding = "valid" )
72
+ self .conv_1 = nn .Conv2d (
73
+ size_in , size_out , kernel_size = 3 , stride = 2 , padding = 1 )
74
+ self .bn_1 = nn .BatchNorm2d (size_out )
73
75
self .act = nn .ReLU ()
74
76
75
77
def forward (self , x ):
76
78
x = self .conv_1 (x )
79
+ x = self .bn_1 (x )
77
80
return self .act (x )
78
81
79
82
@@ -82,47 +85,18 @@ def __init__(self, size_in, size_out, activation):
82
85
super ().__init__ ()
83
86
# Modify this to create new transpose conv blocks
84
87
# Eg: Throw in dropout, throw in batchnorm ... whatvever you want
85
- self .up_conv_1 = nn .ConvTranspose2d (size_in , size_out , 3 )
88
+ self .up_conv_1 = nn .ConvTranspose2d (
89
+ size_in , size_out , kernel_size = 3 , stride = 2 , padding = 1 , output_padding = 1 )
86
90
activations = nn .ModuleDict ([
87
91
["relu" , nn .ReLU ()],
88
92
["sigmoid" , nn .Sigmoid ()],
89
93
["tanh" , nn .Tanh ()]
90
94
])
95
+ self .bn_1 = nn .BatchNorm2d (size_out )
96
+
91
97
self .act = activations [activation ]
92
98
93
99
def forward (self , x ):
94
100
x = self .up_conv_1 (x )
101
+ x = self .bn_1 (x )
95
102
return self .act (x )
96
-
97
- class PoolingDownsampleBlock (nn .Module ):
98
- def __init__ (self , size_in , size_out ):
99
- super ().__init__ ()
100
- # Modify this to create new conv blocks
101
- # Eg: Throw in pooling, throw in residual connections ... whatever you want
102
- self .conv_1 = nn .Conv2d (size_in , size_out , 3 , padding = "valid" )
103
- self .pool = nn .Conv2d (size_out , size_out , 3 , padding = "valid" )
104
- #self.pool = nn.MaxPool2d(3, 1)
105
- self .act = nn .ReLU ()
106
- def forward (self , x ):
107
- x = self .conv_1 (x )
108
- x = self .pool (x )
109
- return self .act (x )
110
-
111
- class UnPoolingUpsampleBlock (nn .Module ):
112
- def __init__ (self , size_in , size_out , activation ):
113
- super ().__init__ ()
114
- # Modify this to create new transpose conv blocks
115
- # Eg: Throw in dropout, throw in batchnorm ... whatvever you want
116
- self .up_conv_1 = nn .ConvTranspose2d (size_in , size_out , 3 )
117
- self .up_conv_2 = nn .ConvTranspose2d (size_out , size_out , 3 )
118
-
119
- activations = nn .ModuleDict ([
120
- ["relu" , nn .ReLU ()],
121
- ["sigmoid" , nn .Sigmoid ()],
122
- ["tanh" , nn .Tanh ()]
123
- ])
124
- self .act = activations [activation ]
125
- def forward (self , x ):
126
- x = self .up_conv_1 (x )
127
- x = self .up_conv_2 (x )
128
- return self .act (x )
0 commit comments