Skip to content

Commit b199a41

Browse files
match values to command args
1 parent 12d7828 commit b199a41

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

results/sample_5.png

-958 Bytes
Loading

vae.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,23 @@ def vae_loss_function(self, recon_x, x):
9191
parser.add_argument('--epochs', type=int, default=10, metavar='N',
9292
help='number of epochs to train (default: 10)')
9393
parser.add_argument('--no-cuda', action='store_true', default=False,
94-
help='enables CUDA training')
94+
help='disables CUDA training')
9595
parser.add_argument('--emb-size', type=int, default=10, help='size of embedding (default 10)')
9696
args = parser.parse_args()
97+
9798
args.cuda = not args.no_cuda and torch.cuda.is_available()
99+
device = torch.device("cuda" if args.cuda else "cpu")
98100

99101
torch.manual_seed(42)
100102

101-
device = torch.device("cuda" if args.cuda else "cpu")
102-
103103
mnist = untar_data(URLs.MNIST_TINY)
104104
tfms = get_transforms(do_flip=False)
105105

106106
data = (ImageImageList.from_folder(mnist/'train')
107107
.split_by_rand_pct(0.1, seed=42)
108108
.label_from_func(lambda x: x)
109109
.transform(tfms)
110-
.databunch(num_workers=0, bs=16)
110+
.databunch(num_workers=0, bs=args.batch_size)
111111
.normalize(do_y=True))
112112

113113
image_size = data.one_batch()[0].shape[-1]
@@ -120,6 +120,6 @@ def vae_loss_function(self, recon_x, x):
120120
my_learner.show_results(rows=4)
121121

122122
print(f'Sampling 64 values and saving reconstruction. ')
123-
sample = torch.randn(64, 2)
123+
sample = torch.randn(64, args.emb_size)
124124
sample = vae.decode(sample).cpu()
125125
save_image(sample.view(64, 3, 28, 28), f'results/sample_{str(args.epochs)}.png')

0 commit comments

Comments
 (0)