Commit 6ca017b

anime dataset and cifar training done

1 parent 9cc8f09

9 files changed: +80 −49 lines
+7 −8

@@ -1,22 +1,21 @@
 {
     "training_params": {
         "batch_size": 32,
-        "epochs": 100,
-        "lr": 0.001
+        "epochs": 60,
+        "lr": 0.0001
     },
     "architecture_params": {
-        "sizes": [3, 8, 32, 64, 128],
-        "h": 32,
-        "w": 32,
+        "sizes": [3, 64, 128, 256, 512],
+        "h": 88,
+        "w": 88,
         "num_dense_layers": 2,
         "fcnn": false
 
     },
     "dataset_params": {
-        "name": "cifar",
+        "name": "animefacedataset",
         "hyperparams": {
-            "batch_size" : 32,
-            "classes": ["frog"]
+            "batch_size" : 32
         }
     }
 }
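The new 88×88 h/w matches the Resize((88, 88)) transform in the anime-face dataset added below, and the wider sizes list roughly quadruples the channel width at every stage. For orientation, a minimal sketch of consuming such a config (the file path and variable names are assumptions; the commit doesn't show the loading code):

    import json

    with open("params.json") as f:   # assumed path; not shown in this diff
        cfg = json.load(f)

    # {"batch_size": 32, "epochs": 60, "lr": 0.0001}
    training = cfg["training_params"]
    # sizes/h/w feed the encoder-decoder construction in models/cnn_generator.py
    architecture = cfg["architecture_params"]
    # "animefacedataset" selects utils/datasets/animefacedataset.py
    dataset = cfg["dataset_params"]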

Experiments/CNNAutoencoder/train.py

+1 −1

@@ -39,7 +39,7 @@ def train(model, train_loader, val_loader, device, epochs, lr, batch_size):
 
     for epoch in range(1, epochs + 1):
 
-        model.train()
+        model.train()
         for x, _ in tqdm.tqdm(train_loader):
             x = x.to(device=device)
             photometric_loss = compute_forward_pass(
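The change here touches only whitespace on the model.train() line (the extracted page flattens leading indentation). Keeping model.train() at the top of each epoch matters once validation switches the network with model.eval(): otherwise every epoch after the first would train with batch norm and dropout frozen, and this commit adds BatchNorm2d to both block types. A minimal sketch of the pattern (compute_forward_pass and the validation body are stand-ins, not the repo's exact code):

    for epoch in range(1, epochs + 1):
        model.train()                    # re-enable batch-norm updates each epoch
        for x, _ in train_loader:
            x = x.to(device=device)
            loss = compute_forward_pass(model, x)   # stand-in helper
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()                     # validation flips the mode ...
        with torch.no_grad():
            for x, _ in val_loader:
                compute_forward_pass(model, x)
        # ... so the next iteration's model.train() call is what restores it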

Experiments/CNNAutoencoder/visualize.py

+2 −1

@@ -41,4 +41,5 @@ def visualize(fp, architecture_params, dataloader_params, dataloader_func, resum
         axarr[i, 3].imshow(torch.permute(
             torch.squeeze(autoencoder(torch.unsqueeze(x[2*i + 1], axis=0))[1]),
             (1, 2, 0)).detach().cpu().numpy())
-    plt.show()
+    plt.savefig('foo.png')
+
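plt.savefig replaces the interactive plt.show, the usual fix when training runs on a headless machine with no display. If the backend also needs forcing (pyplot can pick an interactive backend at import time), the standard pattern is as follows (generic matplotlib usage, not from this repo):

    import matplotlib
    matplotlib.use("Agg")            # non-interactive backend; set before importing pyplot
    import matplotlib.pyplot as plt

    # ... build the figure as in visualize() ...
    plt.savefig("foo.png", bbox_inches="tight")   # bbox_inches trims margins (optional)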
27.1 MB — Binary file not shown.

foo.png

43.2 KB — new image file (the output path written by visualize.py above)

models/cnn_generator.py

+13 −39

@@ -39,7 +39,7 @@ def forward(self, x):
 class CNNEncoder(nn.Module):
     def __init__(self, sizes):
         super().__init__()
-        self.out_seq = nn.Sequential(*[PoolingDownsampleBlock(size_in, size_out) for size_in, size_out
+        self.out_seq = nn.Sequential(*[DownsampleBlock(size_in, size_out) for size_in, size_out
                                        in zip(sizes[0:-1], sizes[1:])])
 
     def forward(self, x):
@@ -52,11 +52,11 @@ def __init__(self, sizes):
         super().__init__()
         sizes = list(reversed(sizes))
         sizes_minus_last = sizes[0:-1]
-        self.in_seq = nn.Sequential(*[UnPoolingUpsampleBlock(size_in, size_out, "relu") for size_in, size_out
+        self.in_seq = nn.Sequential(*[UpsampleBlock(size_in, size_out, "relu") for size_in, size_out
                                       in zip(sizes_minus_last[0:-1], sizes_minus_last[1:])])
 
-        self.last = UnPoolingUpsampleBlock(
-            sizes[-2], sizes[-1], activation="relu")
+        self.last = UpsampleBlock(
+            sizes[-2], sizes[-1], activation="sigmoid")
 
     def forward(self, x):
         x = self.in_seq(x)
@@ -69,11 +69,14 @@ def __init__(self, size_in, size_out):
         super().__init__()
         # Modify this to create new conv blocks
         # Eg: Throw in pooling, throw in residual connections ... whatever you want
-        self.conv_1 = nn.Conv2d(size_in, size_out, 3, padding="valid")
+        self.conv_1 = nn.Conv2d(
+            size_in, size_out, kernel_size=3, stride=2, padding=1)
+        self.bn_1 = nn.BatchNorm2d(size_out)
         self.act = nn.ReLU()
 
     def forward(self, x):
         x = self.conv_1(x)
+        x = self.bn_1(x)
         return self.act(x)
 
 
@@ -82,47 +85,18 @@ def __init__(self, size_in, size_out, activation):
         super().__init__()
         # Modify this to create new transpose conv blocks
         # Eg: Throw in dropout, throw in batchnorm ... whatever you want
-        self.up_conv_1 = nn.ConvTranspose2d(size_in, size_out, 3)
+        self.up_conv_1 = nn.ConvTranspose2d(
+            size_in, size_out, kernel_size=3, stride=2, padding=1, output_padding=1)
         activations = nn.ModuleDict([
             ["relu", nn.ReLU()],
             ["sigmoid", nn.Sigmoid()],
             ["tanh", nn.Tanh()]
         ])
+        self.bn_1 = nn.BatchNorm2d(size_out)
+
         self.act = activations[activation]
 
     def forward(self, x):
         x = self.up_conv_1(x)
+        x = self.bn_1(x)
         return self.act(x)
-
-class PoolingDownsampleBlock(nn.Module):
-    def __init__(self, size_in, size_out):
-        super().__init__()
-        # Modify this to create new conv blocks
-        # Eg: Throw in pooling, throw in residual connections ... whatever you want
-        self.conv_1 = nn.Conv2d(size_in, size_out, 3, padding="valid")
-        self.pool = nn.Conv2d(size_out, size_out, 3, padding="valid")
-        # self.pool = nn.MaxPool2d(3, 1)
-        self.act = nn.ReLU()
-    def forward(self, x):
-        x = self.conv_1(x)
-        x = self.pool(x)
-        return self.act(x)
-
-class UnPoolingUpsampleBlock(nn.Module):
-    def __init__(self, size_in, size_out, activation):
-        super().__init__()
-        # Modify this to create new transpose conv blocks
-        # Eg: Throw in dropout, throw in batchnorm ... whatever you want
-        self.up_conv_1 = nn.ConvTranspose2d(size_in, size_out, 3)
-        self.up_conv_2 = nn.ConvTranspose2d(size_out, size_out, 3)
-
-        activations = nn.ModuleDict([
-            ["relu", nn.ReLU()],
-            ["sigmoid", nn.Sigmoid()],
-            ["tanh", nn.Tanh()]
-        ])
-        self.act = activations[activation]
-    def forward(self, x):
-        x = self.up_conv_1(x)
-        x = self.up_conv_2(x)
-        return self.act(x)
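The strided Conv2d/ConvTranspose2d pairs replace the old valid-padding blocks, so downsampling is now halving (with rounding) and upsampling exact doubling. One consequence worth checking against the new 88×88 config: four stride-2 convs take 88 → 44 → 22 → 11 → 6, while four doubling transpose convs take 6 → 12 → 24 → 48 → 96, so the decoder emits 96×96 rather than 88×88. A quick standalone check (mirrors the block hyperparameters above; not the repo's own test):

    import torch
    import torch.nn as nn

    # Encoder path: out = floor((in + 2*1 - 3)/2) + 1, i.e. 88 -> 44 -> 22 -> 11 -> 6
    x = torch.randn(1, 3, 88, 88)
    for c_in, c_out in zip([3, 64, 128, 256], [64, 128, 256, 512]):
        x = nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)(x)
    print(x.shape)  # torch.Size([1, 512, 6, 6])

    # Decoder path: out = (in - 1)*2 - 2*1 + 3 + 1 = 2*in, i.e. 6 -> 12 -> 24 -> 48 -> 96
    for c_in, c_out in zip([512, 256, 128, 64], [256, 128, 64, 3]):
        x = nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                               padding=1, output_padding=1)(x)
    print(x.shape)  # torch.Size([1, 3, 96, 96]) -- not 88x88

If an exact round trip is needed, either pick h and w divisible by 16 (e.g. 96) or resize the decoder output back with torch.nn.functional.interpolate(x, size=(h, w)).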

sh_scripts/download_anime.sh

+5

@@ -0,0 +1,5 @@
+mkdir ../.kaggle
+mv kaggle.json ~/.kaggle
+kaggle datasets download -d splcher/animefacedataset
+unzip animefacedataset.zip -d data/animefacedataset
+rm -r -f animefacedataset.zip
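Two small things to watch in this script: the directory is created at ../.kaggle but kaggle.json is moved into ~/.kaggle, and those paths only coincide when the script runs one level below the home directory — since the Kaggle CLI reads credentials from ~/.kaggle/kaggle.json, `mkdir -p ~/.kaggle` is likely the intent. Also, `rm -r -f` works but plain `rm -f` suffices for a single zip file.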

utils/datasets/animefacedataset.py

+51

@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+import os
+from PIL import Image
+
+import torchvision
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+
+
+class AnimeFaceDataset(torch.utils.data.Dataset):
+    def __init__(self):
+
+        self.images = []
+        base_path = "data/animefacedataset/images"
+        self.images += [(os.path.join(base_path, pth), 0)
+                        for pth in os.listdir(os.path.join(base_path))]
+        self.transforms = transforms.Compose(
+            [transforms.ToTensor(), transforms.Resize((88, 88))])
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        im = Image.open(self.images[idx][0])
+        return self.transforms((im)), self.images[idx][1]
+
+
+def create_dataloaders(batch_size):
+
+    # insert logic for creating the dataloaders
+    train = torch.utils.data.DataLoader(
+        AnimeFaceDataset(),
+        batch_size=batch_size, shuffle=True)
+
+    test = torch.utils.data.DataLoader(
+        AnimeFaceDataset(),
+        batch_size=batch_size, shuffle=True)
+    return train, test
+
+
+if __name__ == "__main__":
+    a, b = create_dataloaders(**{"batch_size": 1})
+    min_a = (np.inf, np.inf)
+
+    for i in a:
+        assert ((i[0].shape[2] == 88) and (i[0].shape[3] == 88))
+    for i in b:
+        assert ((i[0].shape[2] == 88) and (i[0].shape[3] == 88))
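Two caveats with the new dataset as committed. First, both loaders wrap the same full dataset, so the "test" loader iterates exactly the training images. Second, Image.open can yield palette or RGBA images, which would produce tensors with a channel count other than 3. A hedged sketch of both fixes — random_split and convert("RGB") are standard PyTorch/PIL calls, but the val_fraction parameter is an assumption, not from the repo:

    import torch

    def create_dataloaders(batch_size, val_fraction=0.1):
        dataset = AnimeFaceDataset()
        n_val = int(len(dataset) * val_fraction)
        train_set, val_set = torch.utils.data.random_split(
            dataset, [len(dataset) - n_val, n_val],
            generator=torch.Generator().manual_seed(0))   # reproducible split
        train = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True)
        val = torch.utils.data.DataLoader(
            val_set, batch_size=batch_size, shuffle=False)  # validation needs no shuffling
        return train, val

    # and in __getitem__, force three channels before the transforms:
    # im = Image.open(self.images[idx][0]).convert("RGB")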

utils/datasets/cifar.py

+1

@@ -77,6 +77,7 @@ def create_dataloaders(batch_size, classes):
     if not os.path.exists(os.path.join("data", "cifar10")):
         trainloader, testloader = get_pytorch_dataloaders()
         save_dataset(trainloader, testloader)
+
     # insert logic for creating the dataloaders
     train = torch.utils.data.DataLoader(
         CIFARDataset("cifar_10_segmented_train",
