Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
heykeetae authored May 28, 2018
1 parent 7d30f66 commit e78f98c
Show file tree
Hide file tree
Showing 8 changed files with 690 additions and 0 deletions.
51 changes: 51 additions & 0 deletions data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
import torchvision.datasets as dsets
from torchvision import transforms


class Data_Loader():
def __init__(self, train, dataset, image_path, image_size, batch_size, shuf=True):
self.dataset = dataset
self.path = image_path
self.imsize = image_size
self.batch = batch_size
self.shuf = shuf
self.train = train

def transform(self, resize, totensor, normalize, centercrop):
options = []
if centercrop:
options.append(transforms.CenterCrop(160))
if resize:
options.append(transforms.Resize((self.imsize,self.imsize)))
if totensor:
options.append(transforms.ToTensor())
if normalize:
options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
transform = transforms.Compose(options)
return transform

def load_lsun(self, classes='church_outdoor_train'):
transforms = self.transform(True, True, True, False)
dataset = dsets.LSUN(self.path, classes=[classes], transform=transforms)
return dataset

def load_celeb(self):
transforms = self.transform(True, True, True, True)
dataset = dsets.ImageFolder(self.path+'/CelebA', transform=transforms)
return dataset


def loader(self):
if self.dataset == 'lsun':
dataset = self.load_lsun()
elif self.dataset == 'celeb':
dataset = self.load_celeb()

loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=self.batch,
shuffle=self.shuf,
num_workers=2,
drop_last=True)
return loader

26 changes: 26 additions & 0 deletions download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
FILE=$1

if [ $FILE == 'CelebA' ]
then
URL=https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0
ZIP_FILE=./data/CelebA.zip

elif [ $FILE == 'LSUN' ]
then
URL=https://www.dropbox.com/s/zt7d2hchrw7cp9p/church_outdoor_train_lmdb.zip?dl=0
ZIP_FILE=./data/church_outdoor_train_lmdb.zip
else
echo "Available datasets are: CelebA and LSUN"
exit 1
fi

mkdir -p ./data/
wget -N $URL -O $ZIP_FILE
unzip $ZIP_FILE -d ./data/

if [ $FILE == 'CelebA' ]
then
mv ./data/CelebA_nocrop ./data/CelebA
fi

rm $ZIP_FILE
40 changes: 40 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

from parameter import *
from trainer import Trainer
from qgan_trainer import Trainer as qgan_trainer
# from tester import Tester
from data_loader import Data_Loader
from torch.backends import cudnn
from utils import make_folder

def main(config):
# For fast training
cudnn.benchmark = True


# Data loader
data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
config.batch_size, shuf=config.train)

# Create directories if not exist
make_folder(config.model_save_path, config.version)
make_folder(config.result_path, config.version)
make_folder(config.sample_path, config.version)
make_folder(config.log_path, config.version)
make_folder(config.attn_path, config.version)


if config.train:
if config.model=='sagan':
trainer = Trainer(data_loader.loader(), config)
elif config.model == 'qgan':
trainer = qgan_trainer(data_loader.loader(), config)
trainer.train()
else:
tester = Tester(data_loader.loader(), config)
tester.test()

if __name__ == '__main__':
config = get_parameters()
print(config)
main(config)
78 changes: 78 additions & 0 deletions parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import argparse

def str2bool(v):
return v.lower() in ('true')

def get_parameters():

parser = argparse.ArgumentParser()

# Model hyper-parameters
parser.add_argument('--model', type=str, default='sagan', choices=['sagan', 'qgan'])
parser.add_argument('--adv_loss', type=str, default='wgan-gp', choices=['wgan-gp', 'hinge'])
parser.add_argument('--imsize', type=int, default=32)
parser.add_argument('--g_num', type=int, default=5)
parser.add_argument('--z_dim', type=int, default=128)
parser.add_argument('--g_conv_dim', type=int, default=64)
parser.add_argument('--d_conv_dim', type=int, default=64)
parser.add_argument('--lambda_gp', type=float, default=10)

# Training setting
parser.add_argument('--total_step', type=int, default=1000000, help='how many times to update the generator')
parser.add_argument('--d_iters', type=float, default=5)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--g_lr', type=float, default=0.0001)
parser.add_argument('--d_lr', type=float, default=0.0004)
parser.add_argument('--lr_decay', type=float, default=0.95)
parser.add_argument('--beta1', type=float, default=0.0)
parser.add_argument('--beta2', type=float, default=0.9)

# using pretrained
parser.add_argument('--pretrained_model', type=int, default=None)

# gating net
parser.add_argument('--gum_orig', type=float, default=1) # gum start temperature
parser.add_argument('--gum_temp', type=float, default=1)
parser.add_argument('--min_temp', type=float, default=0.01)
parser.add_argument('--gum_temp_decay', type=float, default=0.0001)
parser.add_argument('--step_anneal', type=int, default=1) # epoch to apply decaying
parser.add_argument('--start_anneal', type=int, default=0) # epoch to start annealing


# Test setting
parser.add_argument('--test_size', type=int, default=64)
parser.add_argument('--test_model', type=str, default='50000_G.pth')
parser.add_argument('--result_path', type=str, default='./results')
parser.add_argument('--version', type=str, default='Gum')
parser.add_argument('--nrow', type=int, default=8)
parser.add_argument('--ncol', type=int, default=8)

# Misc
parser.add_argument('--train', type=str2bool, default=True)
parser.add_argument('--parallel', type=str2bool, default=False)
parser.add_argument('--dataset', type=str, default='cifar', choices=['lsun', 'celeb', ])
parser.add_argument('--use_tensorboard', type=str2bool, default=False)

# Load balance
parser.add_argument('--load_balance_on', type=str2bool, default=False)
parser.add_argument('--load_weight', type=float, default=1.0) # for 2, for 5 1000, for 4500

# Path
parser.add_argument('--image_path', type=str, default='./data')
parser.add_argument('--log_path', type=str, default='./logs')
parser.add_argument('--model_save_path', type=str, default='./models')
parser.add_argument('--sample_path', type=str, default='./samples')
parser.add_argument('--attn_path', type=str, default='./attn')

# Step size
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=100)
parser.add_argument('--model_save_step', type=float, default=1.0)

# claculating quantitative measures
parser.add_argument('--score_epoch', type=int, default=3) # = 5 epochs
parser.add_argument('--score_start', type=int, default=3) # start at 5 (default)


return parser.parse_args()
163 changes: 163 additions & 0 deletions sagan_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from spectral import SpectralNorm
import numpy as np


class Self_Attn(nn.Module):
'''
input: batch_size x feature_depth x feature_size x feature_size
attn_score: batch_size x feature_size x feature_size
output: batch_size x feature_depth x feature_size x feature_size
'''
def __init__(self, b_size, imsize, in_dim, activation):
super(Self_Attn, self).__init__()
self.b_size = b_size
self.imsize = imsize
self.in_dim = in_dim
self.f_ = nn.Conv2d(in_dim, int(in_dim/8), 1)
self.g_ = nn.Conv2d(in_dim, int(in_dim/8), 1)
self.h_ = nn.Conv2d(in_dim, in_dim, 1)
if activation == 'relu':
self.activation = F.relu
elif activation == 'lrelu':
self.activation = F.leakyrelu
self.f__ = nn.Conv2d(int(in_dim/8), in_dim*(imsize**2), 1)
self.g__ = nn.Conv2d(int(in_dim/8), in_dim*(imsize**2), 1)

self.gamma = nn.Parameter(torch.zeros(self.b_size,1,1,1))

def forward(self, x):
b_size = x.size(0)
f_size = x.size(-1)

f_x = self.f__(self.activation(self.f_(x))) # batch x in_dim*f*f x f_size x f_size
g_x = self.g__(self.activation(self.g_(x))) # batch x in_dim*f*f x f_size x f_size
h_x = self.activation(self.h_(x)).unsqueeze(2).repeat(1,1,f_size**2,1,1).contiguous().view(b_size, -1, f_size**2, f_size**2) # batch x in_dim x f*f x f_size x f_size

f_ready = f_x.contiguous().view(b_size, -1, f_size**2, f_size, f_size).permute(0,1,2,4,3) # batch x in_dim*f*f x f_size2 x f_size1
g_ready = g_x.contiguous().view(b_size, -1, f_size**2, f_size, f_size) # batch x in_dim*f*f x f_size1 x f_size2

attn_dist = torch.mul(f_ready,g_ready).sum(dim=1).contiguous().view(-1,f_size**2) # batch*f*f x f_size1*f_size2
attn_soft = F.softmax(attn_dist,dim=1).contiguous().view(b_size, f_size**2, f_size**2) # batch x f*f x f*f
attn_score = attn_soft.unsqueeze(1) # batch x 1 x f*f x f*f

self_attn_map = torch.mul(h_x, attn_score).sum(dim=3).contiguous().view(b_size, -1, f_size, f_size) # batch x in_dim x f*f
self_attn_map = self.gamma*self_attn_map + x

return self_attn_map, attn_score

class Generator(nn.Module):
"""Generator."""

def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
super(Generator, self).__init__()
self.imsize = image_size
layer1 = []
layer2 = []
layer3 = []
last = []

repeat_num = int(np.log2(self.imsize)) - 3
mult = 2 ** repeat_num # 8
layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
layer1.append(nn.BatchNorm2d(conv_dim * mult))
layer1.append(nn.ReLU())

curr_dim = conv_dim * mult

layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
layer2.append(nn.ReLU())

curr_dim = int(curr_dim / 2)

layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
layer3.append(nn.ReLU())

if self.imsize == 64:
layer4 = []
curr_dim = int(curr_dim / 2)
layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
layer4.append(nn.ReLU())
self.l4 = nn.Sequential(*layer4)
curr_dim = int(curr_dim / 2)

self.l1 = nn.Sequential(*layer1)
self.l2 = nn.Sequential(*layer2)
self.l3 = nn.Sequential(*layer3)

last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
last.append(nn.Tanh())
self.last = nn.Sequential(*last)

self.attn1 = Self_Attn(batch_size, int(self.imsize/4), 128, 'relu')
self.attn2 = Self_Attn(batch_size, int(self.imsize/2), 64, 'relu')

def forward(self, z):
z = z.view(z.size(0), z.size(1), 1, 1)
out=self.l1(z)
out=self.l2(out)
out=self.l3(out)
out,p1 = self.attn1(out)
out=self.l4(out)
out,p2 = self.attn2(out)
out=self.last(out)

return out, p1, p2


class Discriminator(nn.Module):
"""Discriminator, Auxiliary Classifier."""

def __init__(self, batch_size=64, image_size=64, conv_dim=64):
super(Discriminator, self).__init__()
self.imsize = image_size
layer1 = []
layer2 = []
layer3 = []
last = []

layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
layer1.append(nn.LeakyReLU(0.1))

curr_dim = conv_dim

layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
layer2.append(nn.LeakyReLU(0.1))
curr_dim = curr_dim * 2

layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
layer3.append(nn.LeakyReLU(0.1))
curr_dim = curr_dim * 2

if self.imsize == 64:
layer4 = []
layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
layer4.append(nn.LeakyReLU(0.1))
self.l4 = nn.Sequential(*layer4)
curr_dim = curr_dim*2
self.l1 = nn.Sequential(*layer1)
self.l2 = nn.Sequential(*layer2)
self.l3 = nn.Sequential(*layer3)

last.append(nn.Conv2d(curr_dim, 1, 4))
self.last = nn.Sequential(*last)

self.attn1 = Self_Attn(batch_size, int(self.imsize/8), 256, 'relu')
self.attn2 = Self_Attn(batch_size, int(self.imsize/16), 512, 'relu')

def forward(self, x):
out = self.l1(x)
out = self.l2(out)
out = self.l3(out)
out,p1 = self.attn1(out)
out=self.l4(out)
out,p2 = self.attn2(out)
out=self.last(out)

return out.squeeze(), p1, p2
Loading

0 comments on commit e78f98c

Please sign in to comment.