@@ -91,23 +91,23 @@ def vae_loss_function(self, recon_x, x):
91
91
parser .add_argument ('--epochs' , type = int , default = 10 , metavar = 'N' ,
92
92
help = 'number of epochs to train (default: 10)' )
93
93
parser .add_argument ('--no-cuda' , action = 'store_true' , default = False ,
94
- help = 'enables CUDA training' )
94
+ help = 'disables CUDA training' )
95
95
parser .add_argument ('--emb-size' , type = int , default = 10 , help = 'size of embedding (default 10)' )
96
96
args = parser .parse_args ()
97
+
97
98
args .cuda = not args .no_cuda and torch .cuda .is_available ()
99
+ device = torch .device ("cuda" if args .cuda else "cpu" )
98
100
99
101
torch .manual_seed (42 )
100
102
101
- device = torch .device ("cuda" if args .cuda else "cpu" )
102
-
103
103
mnist = untar_data (URLs .MNIST_TINY )
104
104
tfms = get_transforms (do_flip = False )
105
105
106
106
data = (ImageImageList .from_folder (mnist / 'train' )
107
107
.split_by_rand_pct (0.1 , seed = 42 )
108
108
.label_from_func (lambda x : x )
109
109
.transform (tfms )
110
- .databunch (num_workers = 0 , bs = 16 )
110
+ .databunch (num_workers = 0 , bs = args . batch_size )
111
111
.normalize (do_y = True ))
112
112
113
113
image_size = data .one_batch ()[0 ].shape [- 1 ]
@@ -120,6 +120,6 @@ def vae_loss_function(self, recon_x, x):
120
120
my_learner .show_results (rows = 4 )
121
121
122
122
print (f'Sampling 64 values and saving reconstruction. ' )
123
- sample = torch .randn (64 , 2 )
123
+ sample = torch .randn (64 , args . emb_size )
124
124
sample = vae .decode (sample ).cpu ()
125
125
save_image (sample .view (64 , 3 , 28 , 28 ), f'results/sample_{ str (args .epochs )} .png' )
0 commit comments