
Commit 7660c5e

Added dataset configs and the ability to customize the dataset that you train on
1 parent 1f4a33b commit 7660c5e

File tree: 23 files changed (+128 / -208 lines)

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 # Project Specific
 data/
 test.py
+TODO.md
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

Experiments/ContractiveAutoencoder/config.json

Lines changed: 6 additions & 0 deletions
@@ -13,5 +13,11 @@
         "decoder_activation": "sigmoid",
         "final_activation": "sigmoid",
         "bias": true
+    },
+    "dataset_params": {
+        "name": "mnist",
+        "hyperparams":{
+            "batch_size":32
+        }
     }
 }
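
The new dataset_params block is consumed by the refactored training code below through keyword expansion, i.e. dataloader_func(**dataset_params["hyperparams"]). A minimal standalone sketch of that pattern, assuming a stub factory (the make_mnist_loaders name and its body are placeholders, not code from this repository):

# Stub dataloader factory; a real one would return (train_loader, val_loader)
# built from the actual MNIST dataset.
def make_mnist_loaders(batch_size=32):
    print(f"building MNIST loaders with batch_size={batch_size}")
    return None, None

# The "hyperparams" dict from config.json maps directly onto the factory's
# keyword arguments, so new dataset hyperparameters only require editing JSON.
dataset_params = {"name": "mnist", "hyperparams": {"batch_size": 32}}
train_loader, val_loader = make_mnist_loaders(**dataset_params["hyperparams"])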

Experiments/ContractiveAutoencoder/train.py

Lines changed: 10 additions & 8 deletions
@@ -14,14 +14,15 @@
 
 
 from models.dense_generator import Autoencoder
-from utils.datasets.mnist_dataloaders import create_dataloaders_mnist
 from utils.TorchUtils.training.StatsTracker import StatsTracker
 
 
 """
 The below code was adapted from
 https://stackoverflow.com/questions/58249160/how-to-implement-contractive-autoencoder-in-pytorch
 """
+
+
 def compute_forward_pass(model, x, optimizer, weight, device, update):
     # Flip on the grad switches for the GT tensor
     x.requires_grad_(True)
@@ -34,8 +35,9 @@ def compute_forward_pass(model, x, optimizer, weight, device, update):
     # the latent once more for the MSE term. We pass in downstream gradient of ones, because we only want a gradient
     # of dz/dx, so dot producting dy/dz set to a vector of ones with dz/dx returns just dz/dx
     latent.backward(torch.ones(latent.size()).to(device), retain_graph=True)
-
-    loss2 = torch.sqrt(torch.sum(torch.pow(x.grad, 2)))  # Comptue the frobenius norm on the gradients
+
+    # Comptue the frobenius norm on the gradients
+    loss2 = torch.sqrt(torch.sum(torch.pow(x.grad, 2)))
     x.grad.data.zero_()
     loss = reconstruction_loss + (weight*loss2)
     x.requires_grad_(False)
@@ -45,11 +47,13 @@ def compute_forward_pass(model, x, optimizer, weight, device, update):
         optimizer.step()
     return loss
 
+
 def compute_mse(model, x):
     latent, reconstruction = model(x)
     photometric_loss = MSELoss(reduction="sum")(reconstruction, x)
     return photometric_loss
 
+
 def train(model, train_loader, val_loader, device, epochs, lr, batch_size, weight):
     # Initialize autoencoder
 
@@ -68,7 +72,7 @@ def train(model, train_loader, val_loader, device, epochs, lr, batch_size, weight):
             photometric_loss = compute_forward_pass(
                 model, x, optimizer, weight, device, update=True)
             statsTracker.update_curr_losses(photometric_loss.item(), None)
-
+
         with torch.no_grad():
            model.eval()
            for x, _ in tqdm.tqdm(val_loader):
@@ -97,13 +101,11 @@ def train(model, train_loader, val_loader, device, epochs, lr, batch_size, weight):
     return statsTracker.best_model
 
 
-def run_experiment(fp, training_params, architecture_params, resume):
-    batch_size = training_params["batch_size"]
-
+def run_experiment(fp, training_params, architecture_params, dataset_params, dataloader_func, resume):
     device = (torch.device('cuda') if torch.cuda.is_available()
               else torch.device('cpu'))
 
-    train_loader, val_loader = create_dataloaders_mnist(batch_size=batch_size)
+    train_loader, val_loader = dataloader_func(**dataset_params["hyperparams"])
 
     autoencoder = Autoencoder(**(architecture_params)).to(device=device)

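The compute_forward_pass hunk above relies on a gradient trick: backpropagating a vector of ones through the latent fills x.grad with the gradient of the summed latent activations with respect to the input, and the L2 norm of that gradient is added to the reconstruction loss as the contractive penalty. A self-contained toy version of the trick, assuming placeholder layer sizes, batch size, and penalty weight (none of these are the repository's actual values):

import torch
import torch.nn as nn

# Toy encoder/decoder standing in for the repository's Autoencoder.
encoder = nn.Sequential(nn.Linear(784, 32), nn.Sigmoid())
decoder = nn.Sequential(nn.Linear(32, 784), nn.Sigmoid())

x = torch.rand(8, 784)
x.requires_grad_(True)  # track gradients w.r.t. the input, as in the hunk above

latent = encoder(x)
reconstruction = decoder(latent)
reconstruction_loss = nn.MSELoss(reduction="sum")(reconstruction, x)

# Backpropagate a vector of ones through the latent: x.grad now holds
# d(sum of latent units)/dx, whose L2 norm is used as the contractive penalty.
latent.backward(torch.ones_like(latent), retain_graph=True)
penalty = torch.sqrt(torch.sum(torch.pow(x.grad, 2)))

weight = 1e-4  # placeholder for the "weight" hyperparameter passed to train()
loss = reconstruction_loss + weight * penalty
x.grad.data.zero_()
x.requires_grad_(False)
print(f"reconstruction={reconstruction_loss.item():.3f}  penalty={penalty.item():.3f}  total={loss.item():.3f}")
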
Experiments/ContractiveAutoencoder/visualize.py

Lines changed: 4 additions & 4 deletions
@@ -2,13 +2,12 @@
 import torch
 from models.dense_generator import Autoencoder, Encoder
 from torch.nn import MSELoss
-from utils.datasets.mnist_dataloaders import create_dataloaders_mnist
 import numpy as np
 
 import matplotlib.pyplot as plt
 
 
-def visualize(fp, architecture_params, resume):
+def visualize(fp, architecture_params, dataloader_params, dataloader_func, resume):
     device = (torch.device('cuda') if torch.cuda.is_available()
               else torch.device('cpu'))
 
@@ -20,8 +19,9 @@ def visualize(fp, architecture_params, resume):
     # Autoencoder architecture
     print(autoencoder)
 
-    train_loader, val_loader = create_dataloaders_mnist(batch_size=4)
-
+    train_loader, val_loader = dataloader_func(
+        **dataloader_params["hyperparams"])
+
     # Sample random datapoint
     x, _ = next(iter(train_loader))
     x = x.to(device=device)

Experiments/DenoisingAutoencoder/config.json

Lines changed: 7 additions & 1 deletion
@@ -2,7 +2,7 @@
     "training_params": {
         "batch_size": 32,
         "epochs": 50,
-        "lr": 0.00001,
+        "lr": 0.001,
         "prob": 0.5
     },
     "architecture_params": {
@@ -13,5 +13,11 @@
         "decoder_activation": "sigmoid",
         "final_activation": "sigmoid",
         "bias": true
+    },
+    "dataset_params": {
+        "name": "mnist",
+        "hyperparams":{
+            "batch_size":32
+        }
     }
 }

Experiments/DenoisingAutoencoder/train.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 
 
 from models.dense_generator import Autoencoder, Encoder
-from utils.datasets.mnist_dataloaders import create_dataloaders_mnist, DropoutPixelsTransform
+from utils.datasets.mnist import DropoutPixelsTransform
 from utils.TorchUtils.training.StatsTracker import StatsTracker
 
 
@@ -77,11 +77,11 @@ def train(model, train_loader, val_loader, device, epochs, lr, batch_size, prob=
     return statsTracker.best_model
 
 
-def run_experiment(fp, training_params, architecture_params, resume):
+def run_experiment(fp, training_params, architecture_params, dataset_params, dataloader_func, resume):
     device = (torch.device('cuda') if torch.cuda.is_available()
               else torch.device('cpu'))
 
-    train_loader, val_loader = create_dataloaders_mnist(batch_size=training_params["batch_size"])
+    train_loader, val_loader = dataloader_func(**dataset_params["hyperparams"])
 
     autoencoder = Autoencoder(**(architecture_params)).to(device=device)

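Putting the pieces together, a sketch of how a driver script might read one of the updated configs and call the new run_experiment signature. The import paths, the fp value, and the assumption that the MNIST dataloader factory now lives in utils.datasets.mnist alongside DropoutPixelsTransform are guesses for illustration; this excerpt only shows the changed signatures:

import json

# Assumed import locations, not confirmed by this diff.
from Experiments.DenoisingAutoencoder.train import run_experiment
from utils.datasets.mnist import create_dataloaders_mnist

with open("Experiments/DenoisingAutoencoder/config.json") as f:
    config = json.load(f)

# Map the config's dataset name to a dataloader factory; only "mnist" is shown.
factories = {"mnist": create_dataloaders_mnist}
dataset_params = config["dataset_params"]
dataloader_func = factories[dataset_params["name"]]

run_experiment(
    fp="checkpoints/denoising_autoencoder",  # hypothetical checkpoint path
    training_params=config["training_params"],
    architecture_params=config["architecture_params"],
    dataset_params=dataset_params,
    dataloader_func=dataloader_func,
    resume=False,
)
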
Experiments/DenoisingAutoencoder/visualize.py

Lines changed: 4 additions & 3 deletions
@@ -2,13 +2,13 @@
 import torch
 from models.dense_generator import Autoencoder, Encoder
 from torch.nn import MSELoss
-from utils.datasets.mnist_dataloaders import create_dataloaders_mnist, DropoutPixelsTransform
+from utils.datasets.mnist import DropoutPixelsTransform
 import numpy as np
 
 import matplotlib.pyplot as plt
 
 
-def visualize(fp, architecture_params, resume):
+def visualize(fp, architecture_params, dataloader_params, dataloader_func, resume):
     device = (torch.device('cuda') if torch.cuda.is_available()
               else torch.device('cpu'))
 
@@ -21,7 +21,8 @@ def visualize(fp, architecture_params, resume):
     # Autoencoder architecture
     print(autoencoder)
 
-    train_loader, val_loader = create_dataloaders_mnist(batch_size=4)
+    train_loader, val_loader = dataloader_func(
+        **dataloader_params["hyperparams"])
 
     dropout_transform = DropoutPixelsTransform(0.5)
     # Sample random datapoint

Experiments/SparseAutoencoderReg/config.json

Lines changed: 6 additions & 0 deletions
@@ -13,5 +13,11 @@
         "decoder_activation": "sigmoid",
         "final_activation": "sigmoid",
         "bias": true
+    },
+    "dataset_params": {
+        "name": "mnist",
+        "hyperparams":{
+            "batch_size":32
+        }
     }
 }

Experiments/SparseAutoencoderReg/train.py

Lines changed: 2 additions & 6 deletions
@@ -14,7 +14,6 @@
 
 
 from models.dense_generator import Autoencoder
-from utils.datasets.mnist_dataloaders import create_dataloaders_mnist
 from utils.TorchUtils.training.StatsTracker import StatsTracker
 
 
@@ -83,14 +82,11 @@ def train(model, train_loader, val_loader, device, epochs, lr, batch_size, weight):
     return statsTracker.best_model
 
 
-def run_experiment(fp, training_params, architecture_params, resume):
-    batch_size = training_params["batch_size"]
-
-
+def run_experiment(fp, training_params, architecture_params, dataset_params, dataloader_func, resume):
     device = (torch.device('cuda') if torch.cuda.is_available()
               else torch.device('cpu'))
 
-    train_loader, val_loader = create_dataloaders_mnist(batch_size=batch_size)
+    train_loader, val_loader = dataloader_func(**dataset_params["hyperparams"])
 
     autoencoder = Autoencoder(**(architecture_params)).to(device=device)

Experiments/SparseAutoencoderReg/visualize.py

Lines changed: 4 additions & 4 deletions
@@ -2,13 +2,12 @@
 import torch
 from models.dense_generator import Autoencoder, Encoder
 from torch.nn import MSELoss
-from utils.datasets.mnist_dataloaders import create_dataloaders_mnist
 import numpy as np
 
 import matplotlib.pyplot as plt
 
 
-def visualize(fp, architecture_params, resume):
+def visualize(fp, architecture_params, dataloader_params, dataloader_func, resume):
     device = (torch.device('cuda') if torch.cuda.is_available()
               else torch.device('cpu'))
 
@@ -20,8 +19,9 @@ def visualize(fp, architecture_params, resume):
     # Autoencoder architecture
     print(autoencoder)
 
-    train_loader, val_loader = create_dataloaders_mnist(batch_size=4)
-
+    train_loader, val_loader = dataloader_func(
+        **dataloader_params["hyperparams"])
+
     # Sample random datapoint
     x, _ = next(iter(train_loader))
     x = x.to(device=device)

0 commit comments
