@@ -880,7 +880,7 @@ def visualize_grid2(imgs, label, normalize=True):
# https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
# https://www.youtube.com/watch?v=uaaqyVS9-rM
# http://blog.shakirm.com/2015/10/machine-learning-trick-of-the-day-4-reparameterisation-tricks/
-
+# https://www.reddit.com/r/MLQuestions/comments/dl7mya/a_few_more_questions_about_vaes/

# we'll also have an example concerning words (in the NLP domain) and see how we can
# leverage VAEs in that domain as well. For now, let's see how we can implement this
@@ -895,7 +895,20 @@ def visualize_grid2(imgs, label, normalize=True):
# will help you grasp one aspect very well!
#
# now let's define our VAE model.
+
+
class VAE(nn.Module):
+
+
+    def conv(self, in_dim, out_dim, k_size=3, stride=2, padding=1, batch_norm=True, bias=False):
+        return nn.Sequential(nn.Conv2d(in_dim, out_dim, k_size, stride, padding, bias=bias),
+                             nn.BatchNorm2d(out_dim) if batch_norm else nn.Identity(),
+                             nn.ReLU())
+
+    def deconv(self, in_dim, out_dim, k_size=3, stride=2, padding=1, batch_norm=True, bias=False):
+        return nn.Sequential(nn.ConvTranspose2d(in_dim, out_dim, k_size, stride, padding, bias=bias),
+                             nn.BatchNorm2d(out_dim) if batch_norm else nn.Identity(),
+                             nn.ReLU())
    def __init__(self, embedding_size=100):
        super().__init__()

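The deconv helper above is used later with kernel size 4, stride 2, and the default padding of 1, which exactly doubles the spatial size (out = (in - 1) * stride - 2 * padding + kernel). A tiny illustrative check of that arithmetic, independent of this commit:

# Illustrative shape check (not part of the commit): ConvTranspose2d with
# kernel_size=4, stride=2, padding=1 doubles the spatial resolution,
# since out = (in - 1) * stride - 2 * padding + kernel_size.
import torch
import torch.nn as nn

x = torch.randn(1, 8, 1, 1)
up = nn.ConvTranspose2d(8, 16, kernel_size=4, stride=2, padding=1)
print(up(x).shape)  # torch.Size([1, 16, 2, 2])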
@@ -913,11 +926,26 @@ def __init__(self, embedding_size=100):
        # density.
        # We can sample from this distribution to get noisy values of the
        # representations z.
-
+
        self.fc1 = nn.Linear(28*28, 512)
-        self.fc1_mu = nn.Linear(512, self.embedding_size) # mean
+        self.encoder = nn.Sequential(self.conv(3, 768, stride=1),   # stride=1: only the pooling layers downsample (32->16->8->4->2->1)
+                                     self.conv(768, 512, stride=1),
+                                     self.conv(512, 256, stride=1),
+                                     nn.MaxPool2d(2, 2),  # 16
+                                     self.conv(256, 128, stride=1),
+                                     self.conv(128, 64, stride=1),
+                                     nn.MaxPool2d(2, 2),  # 8
+                                     self.conv(64, 32, stride=1),
+                                     nn.MaxPool2d(2, 2),  # 4
+                                     self.conv(32, 16, stride=1),
+                                     nn.MaxPool2d(2, 2),  # 2x2
+                                     self.conv(16, 8, stride=1),
+                                     nn.MaxPool2d(2, 2),  # 1x1
+                                     nn.Flatten())        # (N, 8, 1, 1) -> (N, 8) for the linear heads
+
+        self.fc1_mu = nn.Linear(8, self.embedding_size)  # mean
        # we use log since we want to prevent getting negative variance
-        self.fc1_std = nn.Linear(512, self.embedding_size) #logvariance
+        self.fc1_std = nn.Linear(8, self.embedding_size)  # log-variance

        # our decoder will accept a randomly sampled vector using
        # our mu and std.
@@ -939,14 +967,23 @@ def __init__(self, embedding_size=100):
        # log-likelihood log p_ϕ(x|z), whose units are nats. This measure tells us how
        # effectively the decoder has learned to reconstruct an input image x given
        # its latent representation z.
-        self.decoder = nn.Sequential( nn.Linear(self.embedding_size, 512),
-                                      nn.ReLU(),
-                                      nn.Linear(512, 28*28),
-                                      # in normal situations we wouldnt use sigmoid
-                                      # but since we want our values to be in [0,1]
-                                      # we use sigmoid. for loss we will then have
-                                      # to use, plain BCE (and specifically not BCEWithLogits)
-                                      nn.Sigmoid())
+        self.decoder = nn.Sequential(nn.Linear(self.embedding_size, 8*1*1), nn.Unflatten(1, (8, 1, 1)),  # reshape to (N, 8, 1, 1) before the transposed convs
+                                     self.deconv(8, 768, k_size=4, stride=2),
+                                     self.deconv(768, 512, k_size=4, stride=2),
+                                     self.deconv(512, 256, k_size=4, stride=2),
+                                     self.deconv(256, 128, k_size=4, stride=2),
+                                     nn.ConvTranspose2d(128, 3, 4, 2, 1),  # final layer kept plain: BN/ReLU here would restrict the Sigmoid's input to >= 0
+                                     # self.deconv(64, 32, k_size=4, stride=2),
+                                     # self.deconv(32, 3, k_size=4, stride=2),
+                                     nn.Sigmoid())
+        # self.decoder = nn.Sequential( nn.Linear(self.embedding_size, 512),
+        #                               nn.ReLU(),
+        #                               nn.Linear(512, 28*28),
+        #                               # in normal situations we wouldnt use sigmoid
+        #                               # but since we want our values to be in [0,1]
+        #                               # we use sigmoid. for loss we will then have
+        #                               # to use, plain BCE (and specifically not BCEWithLogits)
+        #                               nn.Sigmoid())



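The comments in this hunk describe reconstruction quality as a log-likelihood measured in nats, and the Sigmoid output is meant to pair with a plain BCE reconstruction term. To make that concrete, here is a minimal sketch of the standard VAE objective (reconstruction BCE plus the closed-form KL term); it is illustrative only and not part of this commit:

# Illustrative sketch of the usual VAE objective: negative reconstruction
# log-likelihood (plain BCE, valid because the decoder ends in Sigmoid) plus the
# closed-form KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I).
import torch
import torch.nn.functional as F

def vae_objective_sketch(recon, x, mu, logvar):
    recon_nll = F.binary_cross_entropy(recon, x, reduction='sum')   # in nats
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())   # KL(q || N(0, I))
    return recon_nll + kld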
@@ -1029,8 +1066,9 @@ def reparamtrization_trick(self, mu, logvar):
        return mu + eps * std
    #
    def encode(self, input):
-        input = input.view(input.size(0), -1)
-        output = F.relu(self.fc1(input))
+        # input = input.view(input.size(0), -1)
+        # output = F.relu(self.fc1(input))
+        output = self.encoder(input)
        # we don't use activations here
        mu = self.fc1_mu(output)
        log_var = self.fc1_std(output)
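Since the encoder, the reparameterisation trick, and the decoder are now all convolution-based, a quick shape check is a cheap way to confirm they line up for CIFAR10-sized inputs. The snippet below is a hedged sketch (not part of the commit) that only exercises the attributes visible in this diff:

# Hedged sanity check: push a fake CIFAR10 batch through the pieces defined above
# and print the intermediate shapes. Assumes the VAE class from this file is in scope.
import torch

model = VAE(embedding_size=2)
x = torch.randn(4, 3, 32, 32)                    # fake batch of 32x32 RGB images
feats = model.encoder(x)                         # expected (4, 8) after the final Flatten
mu, logvar = model.fc1_mu(feats), model.fc1_std(feats)
z = model.reparamtrization_trick(mu, logvar)     # z = mu + eps * std, expected (4, 2)
recon = model.decoder(z)                         # expected (4, 3, 32, 32)
print(feats.shape, z.shape, recon.shape)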
@@ -1186,6 +1224,11 @@ def loss_function(outputs, inputs, mu, logvar, reduction='mean', use_mse=False):
#%%
# now let's train:
epochs = 50
+dataset_train = datasets.CIFAR10('cifar10', train=True, download=True, transform=transforms.ToTensor())
+dataset_test = datasets.CIFAR10('cifar10', train=False, download=True, transform=transforms.ToTensor())
+
+dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True)
+dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=128, shuffle=False)

embeddingsize = 2
interval = 2000
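With the CIFAR10 datasets and loaders added above, a minimal training loop could look like the sketch below. It is illustrative only: it assumes the model's forward() returns (outputs, mu, logvar) in the order expected by the loss_function(outputs, inputs, mu, logvar, ...) signature shown in this hunk, and it picks Adam with lr=1e-3 as an arbitrary optimiser choice.

# Hedged training-loop sketch (not part of this commit). Assumes forward() returns
# (outputs, mu, logvar) matching the loss_function signature above.
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VAE(embedding_size=embeddingsize).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    running = 0.0
    for imgs, _ in dataloader_train:          # class labels are unused by the VAE
        imgs = imgs.to(device)
        outputs, mu, logvar = model(imgs)
        loss = loss_function(outputs, imgs, mu, logvar, reduction='mean')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running += loss.item()
    print(f'epoch {epoch}: train loss {running / len(dataloader_train):.4f}')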