Skip to content

Commit 2d6fc52

Browse files
author
Jaan Altosaar
committed
format and lint with flake8 and black
1 parent c7b298e commit 2d6fc52

File tree

3 files changed

+306
-288
lines changed

3 files changed

+306
-288
lines changed

README.md

+7-30
Original file line numberDiff line numberDiff line change
@@ -9,39 +9,16 @@ Variational inference is used to fit the model to binarized MNIST handwritten di
99

1010
Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
1111

12-
Example output with importance sampling for estimating the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. Final marginal likelihood on the test set of `-97.10` nats.
12+
Example output with importance sampling for estimating the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. Final marginal likelihood on the test set was `-97.10` nats after 65k iterations.
1313

1414
```
1515
$ python train_variational_autoencoder_pytorch.py --variational mean-field
16-
step: 0 train elbo: -558.69
17-
step: 0 valid elbo: -391.84 valid log p(x): -363.25
18-
step: 5000 train elbo: -116.09
19-
step: 5000 valid elbo: -112.57 valid log p(x): -107.01
20-
step: 10000 train elbo: -105.82
21-
step: 10000 valid elbo: -108.49 valid log p(x): -102.62
22-
step: 15000 train elbo: -106.78
23-
step: 15000 valid elbo: -106.97 valid log p(x): -100.97
24-
step: 20000 train elbo: -108.43
25-
step: 20000 valid elbo: -106.23 valid log p(x): -100.04
26-
step: 25000 train elbo: -99.68
27-
step: 25000 valid elbo: -104.89 valid log p(x): -98.83
28-
step: 30000 train elbo: -96.71
29-
step: 30000 valid elbo: -104.50 valid log p(x): -98.34
30-
step: 35000 train elbo: -98.64
31-
step: 35000 valid elbo: -104.05 valid log p(x): -97.87
32-
step: 40000 train elbo: -93.60
33-
step: 40000 valid elbo: -104.10 valid log p(x): -97.68
34-
step: 45000 train elbo: -96.45
35-
step: 45000 valid elbo: -104.58 valid log p(x): -97.76
36-
step: 50000 train elbo: -101.63
37-
step: 50000 valid elbo: -104.72 valid log p(x): -97.81
38-
step: 55000 train elbo: -106.78
39-
step: 55000 valid elbo: -105.14 valid log p(x): -98.06
40-
step: 60000 train elbo: -100.58
41-
step: 60000 valid elbo: -104.13 valid log p(x): -97.30
42-
step: 65000 train elbo: -96.19
43-
step: 65000 valid elbo: -104.46 valid log p(x): -97.43
44-
step: 65000 test elbo: -103.31 test log p(x): -97.10
16+
step: 0 train elbo: -558.28
17+
step: 0 valid elbo: -392.78 valid log p(x): -359.91
18+
step: 10000 train elbo: -106.67
19+
step: 10000 valid elbo: -109.12 valid log p(x): -103.11
20+
step: 20000 train elbo: -107.28
21+
step: 20000 valid elbo: -105.65 valid log p(x): -99.74
4522
```
4623

4724

data.py

+54-30
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,63 @@
55
import os
66
import numpy as np
77
import h5py
8+
import torch
89

910

1011
def parse_binary_mnist(data_dir):
    """Load Larochelle's binarized MNIST splits from plain-text .amat files.

    Each .amat file holds one image per line as space-separated 0/1 pixel
    values. Reads the train/valid/test files from ``data_dir``.

    Args:
        data_dir: Directory containing
            ``binarized_mnist_{train,valid,test}.amat``.

    Returns:
        Tuple ``(train_data, validation_data, test_data)`` of float32
        numpy arrays, one row per image.

    Raises:
        FileNotFoundError: If any of the three .amat files is missing.
    """

    def _load_split(split):
        # One image per line; pixels are whitespace-separated integers.
        path = os.path.join(data_dir, "binarized_mnist_{}.amat".format(split))
        with open(path) as f:
            lines = f.readlines()
        return np.array(
            [[int(i) for i in line.split()] for line in lines]
        ).astype("float32")

    # The three splits differ only by filename suffix, so load them with
    # one helper instead of three copy-pasted with/open stanzas.
    train_data = _load_split("train")
    validation_data = _load_split("valid")
    test_data = _load_split("test")
    return train_data, validation_data, test_data
2325

2426

2527
def download_binary_mnist(fname):
    """Download binarized MNIST and cache all three splits in one HDF5 file.

    Fetches Hugo Larochelle's binarized MNIST .amat files (train, valid,
    test) into /tmp, parses them with ``parse_binary_mnist``, and writes
    the arrays to ``fname`` as HDF5 datasets named "train", "valid", and
    "test".

    Args:
        fname: Path of the HDF5 file to create (overwritten if it exists).
    """
    data_dir = "/tmp/"
    subdatasets = ["train", "valid", "test"]
    for subdataset in subdatasets:
        filename = "binarized_mnist_{}.amat".format(subdataset)
        url = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat".format(
            subdataset
        )
        local_filename = os.path.join(data_dir, filename)
        urllib.request.urlretrieve(url, local_filename)

    train, validation, test = parse_binary_mnist(data_dir)

    # Context manager guarantees the HDF5 handle is closed even if a
    # create_dataset call raises (the original's bare f.close() leaked
    # the handle on error).
    with h5py.File(fname, "w") as f:
        f.create_dataset("train", data=train)
        f.create_dataset("valid", data=validation)
        f.create_dataset("test", data=test)
    print(f"Saved binary MNIST data to: {fname}")
47+
48+
49+
def load_binary_mnist(fname, batch_size, test_batch_size, use_gpu):
    """Build PyTorch DataLoaders for the cached binarized MNIST HDF5 file.

    Args:
        fname: Path to the HDF5 file written by ``download_binary_mnist``.
        batch_size: Mini-batch size for the training loader.
        test_batch_size: Batch size for the validation and test loaders.
        use_gpu: If True, use 4 worker processes and pinned memory so
            host-to-device copies are faster.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the
        training loader shuffles.
    """
    # [::] materializes each dataset into an in-memory numpy array, so the
    # HDF5 file can be closed right away. The original opened the file and
    # never closed it, leaking the handle.
    with h5py.File(fname, "r") as f:
        x_train = f["train"][::]
        x_val = f["valid"][::]
        x_test = f["test"][::]

    kwargs = {"num_workers": 4, "pin_memory": True} if use_gpu else {}
    train = torch.utils.data.TensorDataset(torch.from_numpy(x_train))
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=True, **kwargs
    )
    validation = torch.utils.data.TensorDataset(torch.from_numpy(x_val))
    val_loader = torch.utils.data.DataLoader(
        validation, batch_size=test_batch_size, shuffle=False, **kwargs
    )
    test = torch.utils.data.TensorDataset(torch.from_numpy(x_test))
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=test_batch_size, shuffle=False, **kwargs
    )
    return train_loader, val_loader, test_loader

0 commit comments

Comments (0)