Skip to content

Commit 9cc8f09

Browse files
committed
CNN stuff
1 parent 7660c5e commit 9cc8f09

File tree

18 files changed

+295
-65
lines changed

18 files changed

+295
-65
lines changed

Experiments/CNNAutoencoder/__init__.py

Whitespace-only changes.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"training_params": {
3+
"batch_size": 32,
4+
"epochs": 100,
5+
"lr": 0.001
6+
},
7+
"architecture_params": {
8+
"sizes": [3, 8, 32, 64, 128],
9+
"h": 32,
10+
"w": 32,
11+
"num_dense_layers": 2,
12+
"fcnn": false
13+
14+
},
15+
"dataset_params": {
16+
"name": "cifar",
17+
"hyperparams": {
18+
"batch_size" : 32,
19+
"classes": ["frog"]
20+
}
21+
}
22+
}

Experiments/CNNAutoencoder/train.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
This experiment is a simple vanilla autoencoder
3+
"""
4+
5+
import os
6+
import torch
7+
import tqdm
8+
import torch.nn.functional as F
9+
from torch.optim import Adam
10+
from torch.optim.lr_scheduler import ReduceLROnPlateau
11+
from torch.nn import MSELoss
12+
13+
14+
from models.cnn_generator import CNNAutoencoder
15+
from utils.TorchUtils.training.StatsTracker import StatsTracker
16+
17+
18+
def compute_forward_pass(model, x, optimizer, criterion, update):
    """Run one forward pass of the autoencoder and optionally one optimizer step.

    Parameters
    ----------
    model : callable returning ``(latent, reconstruction)`` for input ``x``.
    x : input batch tensor; also serves as the reconstruction target.
    optimizer : optimizer over ``model``'s parameters; stepped iff ``update``.
    criterion : loss callable (e.g. ``MSELoss``) applied to (reconstruction, x).
    update : bool; when True, backpropagate and take an optimizer step.

    Returns
    -------
    The photometric (reconstruction) loss tensor.
    """
    latent, reconstruction = model(x)

    photometric_loss = criterion(reconstruction, x)
    if update:
        # Zero grads via the optimizer so the cleared parameter set is exactly
        # the set step() updates (model.zero_grad() only covers the module's
        # own parameters, which can silently diverge from the optimizer's).
        optimizer.zero_grad()
        photometric_loss.backward()
        optimizer.step()
    return photometric_loss
27+
28+
29+
def train(model, train_loader, val_loader, device, epochs, lr, batch_size):
    """Train ``model`` as an autoencoder and return the tracker's best model.

    Performs ``epochs`` passes: a training sweep over ``train_loader`` with
    Adam at learning rate ``lr``, then a no-grad validation sweep over
    ``val_loader``. Per-epoch mean losses are accumulated in a StatsTracker,
    the LR scheduler steps on the validation loss, and the tracker is reset
    between epochs.
    """
    opt = Adam(params=model.parameters(), lr=lr)
    lr_scheduler = ReduceLROnPlateau(
        opt, 'min', factor=0.1, patience=3, min_lr=0.00001, verbose=True)

    # Tracker normalizes accumulated loss sums by these nominal sample counts.
    tracker = StatsTracker(
        batch_size * len(train_loader), batch_size * len(val_loader))
    loss_fn = MSELoss(reduction="sum")

    for epoch in range(1, epochs + 1):

        # ---- training sweep ----
        model.train()
        for batch, _ in tqdm.tqdm(train_loader):
            batch = batch.to(device=device)
            batch_loss = compute_forward_pass(
                model, batch, opt, loss_fn, update=True)
            tracker.update_curr_losses(batch_loss.item(), None)

        # ---- validation sweep (no gradients tracked) ----
        with torch.no_grad():
            model.eval()
            for batch, _ in tqdm.tqdm(val_loader):
                batch = batch.to(device=device)
                batch_val_loss = compute_forward_pass(
                    model, batch, opt, loss_fn, update=False)
                tracker.update_curr_losses(None, batch_val_loss.item())

        train_mean, val_mean = tracker.compute_means()
        # Sanity-check the tracker's means against a direct recomputation.
        assert((tracker.train_loss_curr /
               (batch_size * len(train_loader))) == train_mean)
        assert((tracker.val_loss_curr /
               (batch_size * len(val_loader))) == val_mean)

        tracker.update_histories(train_mean, None)
        # Passing the model lets the tracker snapshot it on a new best val loss.
        tracker.update_histories(None, val_mean, model)

        lr_scheduler.step(val_mean)
        print('Student_network, Epoch {}, Train Loss {}, Val Loss {}'.format(
            epoch, round(train_mean, 6), round(val_mean, 6)))

        tracker.reset()

    return tracker.best_model
76+
77+
78+
def run_experiment(fp, training_params, architecture_params, dataset_params, dataloader_func, resume):
    """Build a CNNAutoencoder, train it, and save the best checkpoint.

    Parameters
    ----------
    fp : experiment root directory; the checkpoint lives at ``weights/cnn_ae.pt``.
    training_params : kwargs forwarded to :func:`train` (epochs, lr, batch_size).
    architecture_params : kwargs forwarded to ``CNNAutoencoder``.
    dataset_params : dict whose ``"hyperparams"`` entry is forwarded to
        ``dataloader_func`` to build the (train, val) loaders.
    dataloader_func : callable returning ``(train_loader, val_loader)``.
    resume : bool; when True, load existing weights before training.
    """
    device = (torch.device('cuda') if torch.cuda.is_available()
              else torch.device('cpu'))

    train_loader, val_loader = dataloader_func(**dataset_params["hyperparams"])

    autoencoder = CNNAutoencoder(**(architecture_params)).to(device=device)

    if resume:
        autoencoder.load_state_dict(torch.load(
            os.path.join(fp, "weights/cnn_ae.pt")))

    print(autoencoder)
    best_model = train(autoencoder, train_loader, val_loader,
                       device, **(training_params))
    # Ensure the checkpoint directory exists before saving (first run would
    # otherwise fail with FileNotFoundError).
    os.makedirs(os.path.join(fp, "weights"), exist_ok=True)
    # NOTE(review): resume loads this file via load_state_dict, so
    # ``best_model`` is presumably a state dict -- confirm the StatsTracker
    # stores ``model.state_dict()`` rather than the module object itself.
    torch.save(best_model, os.path.join(fp, "weights/cnn_ae.pt"))
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
import torch
3+
from models.cnn_generator import CNNAutoencoder
4+
from models.dense_generator import DenseAutoEncoder
5+
import numpy as np
6+
7+
import matplotlib.pyplot as plt
8+
9+
10+
def visualize(fp, architecture_params, dataloader_params, dataloader_func, resume):
    """Plot originals next to their autoencoder reconstructions.

    Renders a 2x4 grid: columns 0 and 2 hold two originals per row, columns
    1 and 3 the corresponding reconstructions from the CNNAutoencoder.
    """
    device = (torch.device('cuda') if torch.cuda.is_available()
              else torch.device('cpu'))

    # Create encoder
    autoencoder = CNNAutoencoder(**architecture_params).to(device=device)
    if resume:
        autoencoder.load_state_dict(torch.load(
            os.path.join(fp, "weights/cnn_ae.pt")))

    # Autoencoder architecture
    print(autoencoder)

    train_loader, val_loader = dataloader_func(
        **dataloader_params["hyperparams"])

    # Sample random datapoint
    x, _ = next(iter(train_loader))
    x = x.to(device=device)

    # subplot(r,c) provide the no. of rows and columns
    f, axarr = plt.subplots(2, 4)

    # Inference only: eval mode + no_grad gives deterministic behavior for any
    # norm/dropout layers and avoids building a needless autograd graph.
    autoencoder.eval()
    with torch.no_grad():
        for i in range(2):
            axarr[i, 0].imshow(
                torch.permute(x[2*i], (1, 2, 0)).detach().cpu().numpy())
            # model(x) returns (latent, reconstruction); index [1] selects the
            # reconstruction, squeezed back out of its singleton batch dim.
            axarr[i, 1].imshow(torch.permute(
                torch.squeeze(autoencoder(torch.unsqueeze(x[2*i], axis=0))[1]),
                (1, 2, 0)).detach().cpu().numpy())

            axarr[i, 2].imshow(
                torch.permute(x[2*i + 1], (1, 2, 0)).detach().cpu().numpy())
            axarr[i, 3].imshow(torch.permute(
                torch.squeeze(autoencoder(torch.unsqueeze(x[2*i + 1], axis=0))[1]),
                (1, 2, 0)).detach().cpu().numpy())
    plt.show()
746 KB
Binary file not shown.

Experiments/ContractiveAutoencoder/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch.nn import MSELoss
1414

1515

16-
from models.dense_generator import Autoencoder
16+
from models.dense_generator import DenseAutoEncoder
1717
from utils.TorchUtils.training.StatsTracker import StatsTracker
1818

1919

@@ -107,7 +107,7 @@ def run_experiment(fp, training_params, architecture_params, dataset_params, dat
107107

108108
train_loader, val_loader = dataloader_func(**dataset_params["hyperparams"])
109109

110-
autoencoder = Autoencoder(**(architecture_params)).to(device=device)
110+
autoencoder = DenseAutoEncoder(**(architecture_params)).to(device=device)
111111

112112
if resume:
113113
autoencoder.load_state_dict(torch.load(

Experiments/ContractiveAutoencoder/visualize.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import torch
3-
from models.dense_generator import Autoencoder, Encoder
3+
from models.dense_generator import DenseAutoEncoder
44
from torch.nn import MSELoss
55
import numpy as np
66

@@ -12,16 +12,17 @@ def visualize(fp, architecture_params, dataloader_params, dataloader_func, resum
1212
else torch.device('cpu'))
1313

1414
# Create encoder
15-
autoencoder = Autoencoder(**architecture_params).to(device=device)
15+
autoencoder = DenseAutoEncoder(**architecture_params).to(device=device)
1616
if resume:
17-
autoencoder.load_state_dict(torch.load(os.path.join(fp, "weights/CAE_weights.pt")))
17+
autoencoder.load_state_dict(torch.load(
18+
os.path.join(fp, "weights/CAE_weights.pt")))
1819

1920
# Autoencoder architecture
2021
print(autoencoder)
2122

2223
train_loader, val_loader = dataloader_func(
2324
**dataloader_params["hyperparams"])
24-
25+
2526
# Sample random datapoint
2627
x, _ = next(iter(train_loader))
2728
x = x.to(device=device)

Experiments/DenoisingAutoencoder/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torch.nn import MSELoss
1313

1414

15-
from models.dense_generator import Autoencoder, Encoder
15+
from models.dense_generator import DenseAutoEncoder, DenseEncoder
1616
from utils.datasets.mnist import DropoutPixelsTransform
1717
from utils.TorchUtils.training.StatsTracker import StatsTracker
1818

@@ -83,7 +83,7 @@ def run_experiment(fp, training_params, architecture_params, dataset_params, dat
8383

8484
train_loader, val_loader = dataloader_func(**dataset_params["hyperparams"])
8585

86-
autoencoder = Autoencoder(**(architecture_params)).to(device=device)
86+
autoencoder = DenseAutoEncoder(**(architecture_params)).to(device=device)
8787

8888
if resume:
8989
autoencoder.load_state_dict(torch.load(

Experiments/DenoisingAutoencoder/visualize.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import torch
3-
from models.dense_generator import Autoencoder, Encoder
3+
from models.dense_generator import DenseAutoEncoder, DenseEncoder
44
from torch.nn import MSELoss
55
from utils.datasets.mnist import DropoutPixelsTransform
66
import numpy as np
@@ -13,7 +13,7 @@ def visualize(fp, architecture_params, dataloader_params, dataloader_func, resum
1313
else torch.device('cpu'))
1414

1515
# Create encoder
16-
autoencoder = Autoencoder(**architecture_params).to(device=device)
16+
autoencoder = DenseAutoEncoder(**architecture_params).to(device=device)
1717
if resume:
1818
autoencoder.load_state_dict(torch.load(
1919
os.path.join(fp, "weights/denoisingae.pt")))
@@ -31,8 +31,7 @@ def visualize(fp, architecture_params, dataloader_params, dataloader_func, resum
3131
x = dropout_transform(target)
3232

3333
# subplot(r,c) provide the no. of rows and columns
34-
f, axarr = plt.subplots(2, 6, constrained_layout=True, figsize = [8,2])
35-
34+
f, axarr = plt.subplots(2, 6, constrained_layout=True, figsize=[8, 2])
3635

3736
for i in range(2):
3837
(axarr[i, 0]).title.set_text("Original")
@@ -56,7 +55,6 @@ def visualize(fp, architecture_params, dataloader_params, dataloader_func, resum
5655
axarr[i, 5].imshow(torch.reshape(autoencoder(
5756
x[2*i + 1])[1], torch.Size([28, 28, 1])).detach().cpu().numpy())
5857

59-
6058
for i in range(2):
6159
for j in range(6):
6260
(axarr[i, j]).set_xticks([])

Experiments/SparseAutoencoderReg/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch.nn import MSELoss
1414

1515

16-
from models.dense_generator import Autoencoder
16+
from models.dense_generator import DenseAutoEncoder
1717
from utils.TorchUtils.training.StatsTracker import StatsTracker
1818

1919

@@ -88,7 +88,7 @@ def run_experiment(fp, training_params, architecture_params, dataset_params, dat
8888

8989
train_loader, val_loader = dataloader_func(**dataset_params["hyperparams"])
9090

91-
autoencoder = Autoencoder(**(architecture_params)).to(device=device)
91+
autoencoder = DenseAutoEncoder(**(architecture_params)).to(device=device)
9292

9393
if resume:
9494
autoencoder.load_state_dict(torch.load(

0 commit comments

Comments
 (0)