forked from heykeetae/Self-Attention-GAN
Showing 8 changed files with 690 additions and 0 deletions.
data_loader.py
@@ -0,0 +1,51 @@
import torch
import torchvision.datasets as dsets
from torchvision import transforms


class Data_Loader():
    def __init__(self, train, dataset, image_path, image_size, batch_size, shuf=True):
        self.dataset = dataset
        self.path = image_path
        self.imsize = image_size
        self.batch = batch_size
        self.shuf = shuf
        self.train = train

    def transform(self, resize, totensor, normalize, centercrop):
        options = []
        if centercrop:
            options.append(transforms.CenterCrop(160))
        if resize:
            options.append(transforms.Resize((self.imsize, self.imsize)))
        if totensor:
            options.append(transforms.ToTensor())
        if normalize:
            # Map pixels from [0, 1] to [-1, 1] to match the generator's tanh output.
            options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        return transforms.Compose(options)

    def load_lsun(self, classes='church_outdoor_train'):
        # Named tfm rather than transforms to avoid shadowing the imported module.
        tfm = self.transform(True, True, True, False)
        return dsets.LSUN(self.path, classes=[classes], transform=tfm)

    def load_celeb(self):
        tfm = self.transform(True, True, True, True)
        return dsets.ImageFolder(self.path + '/CelebA', transform=tfm)

    def loader(self):
        if self.dataset == 'lsun':
            dataset = self.load_lsun()
        elif self.dataset == 'celeb':
            dataset = self.load_celeb()
        else:
            # Fail loudly rather than hitting an UnboundLocalError below.
            raise ValueError("unknown dataset: %r (expected 'lsun' or 'celeb')" % self.dataset)

        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=self.batch,
                                           shuffle=self.shuf,
                                           num_workers=2,
                                           drop_last=True)
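A minimal usage sketch (illustrative, not part of the commit), assuming CelebA images have been unpacked to ./data/CelebA as the download script below arranges:

# Hypothetical usage; the argument values mirror the defaults in parameter.py.
loader = Data_Loader(train=True, dataset='celeb', image_path='./data',
                     image_size=64, batch_size=64).loader()
images, _ = next(iter(loader))  # 64 x 3 x 64 x 64, values in [-1, 1]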
download.sh
@@ -0,0 +1,26 @@
FILE=$1

if [ "$FILE" == 'CelebA' ]
then
    URL=https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0
    ZIP_FILE=./data/CelebA.zip
elif [ "$FILE" == 'LSUN' ]
then
    URL=https://www.dropbox.com/s/zt7d2hchrw7cp9p/church_outdoor_train_lmdb.zip?dl=0
    ZIP_FILE=./data/church_outdoor_train_lmdb.zip
else
    echo "Available datasets are: CelebA and LSUN"
    exit 1
fi

mkdir -p ./data/
wget -N "$URL" -O "$ZIP_FILE"
unzip "$ZIP_FILE" -d ./data/

if [ "$FILE" == 'CelebA' ]
then
    mv ./data/CelebA_nocrop ./data/CelebA
fi

rm "$ZIP_FILE"
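The script takes the dataset name as its only argument, e.g. bash download.sh CelebA or bash download.sh LSUN, fetching the archive into ./data/ and deleting the zip after extraction. Note that the Dropbox URLs end in ?dl=0, which may serve an HTML preview page to wget rather than the archive; if the unzip step fails, changing the suffix to ?dl=1 is worth trying.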
main.py
@@ -0,0 +1,40 @@
from parameter import *
from trainer import Trainer
from qgan_trainer import Trainer as qgan_trainer
# from tester import Tester
from data_loader import Data_Loader
from torch.backends import cudnn
from utils import make_folder


def main(config):
    # For fast training
    cudnn.benchmark = True

    # Data loader
    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=config.train)

    # Create directories if they do not exist
    make_folder(config.model_save_path, config.version)
    make_folder(config.result_path, config.version)
    make_folder(config.sample_path, config.version)
    make_folder(config.log_path, config.version)
    make_folder(config.attn_path, config.version)

    if config.train:
        if config.model == 'sagan':
            trainer = Trainer(data_loader.loader(), config)
        elif config.model == 'qgan':
            trainer = qgan_trainer(data_loader.loader(), config)
        trainer.train()
    else:
        # Requires restoring the commented-out `from tester import Tester` above,
        # otherwise this branch raises a NameError.
        tester = Tester(data_loader.loader(), config)
        tester.test()


if __name__ == '__main__':
    config = get_parameters()
    print(config)
    main(config)
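A hypothetical training invocation, given the defaults in parameter.py below (imsize is raised to 64 because the model file hard-codes attention channel widths for 64x64 images): python main.py --model sagan --dataset celeb --imsize 64. Running with --train false exercises the Tester branch, which first needs the commented-out import restored.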
parameter.py
@@ -0,0 +1,78 @@
import argparse


def str2bool(v):
    return v.lower() in ('true',)  # note the comma: ('true') is a plain string, not a tuple


def get_parameters():

    parser = argparse.ArgumentParser()

    # Model hyper-parameters
    parser.add_argument('--model', type=str, default='sagan', choices=['sagan', 'qgan'])
    parser.add_argument('--adv_loss', type=str, default='wgan-gp', choices=['wgan-gp', 'hinge'])
    parser.add_argument('--imsize', type=int, default=32)
    parser.add_argument('--g_num', type=int, default=5)
    parser.add_argument('--z_dim', type=int, default=128)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    parser.add_argument('--lambda_gp', type=float, default=10)

    # Training settings
    parser.add_argument('--total_step', type=int, default=1000000, help='how many times to update the generator')
    parser.add_argument('--d_iters', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--g_lr', type=float, default=0.0001)
    parser.add_argument('--d_lr', type=float, default=0.0004)
    parser.add_argument('--lr_decay', type=float, default=0.95)
    parser.add_argument('--beta1', type=float, default=0.0)
    parser.add_argument('--beta2', type=float, default=0.9)

    # Using a pretrained model
    parser.add_argument('--pretrained_model', type=int, default=None)

    # Gating net
    parser.add_argument('--gum_orig', type=float, default=1)  # Gumbel start temperature
    parser.add_argument('--gum_temp', type=float, default=1)
    parser.add_argument('--min_temp', type=float, default=0.01)
    parser.add_argument('--gum_temp_decay', type=float, default=0.0001)
    parser.add_argument('--step_anneal', type=int, default=1)  # epoch at which to apply decay
    parser.add_argument('--start_anneal', type=int, default=0)  # epoch at which to start annealing

    # Test settings
    parser.add_argument('--test_size', type=int, default=64)
    parser.add_argument('--test_model', type=str, default='50000_G.pth')
    parser.add_argument('--result_path', type=str, default='./results')
    parser.add_argument('--version', type=str, default='Gum')
    parser.add_argument('--nrow', type=int, default=8)
    parser.add_argument('--ncol', type=int, default=8)

    # Misc
    parser.add_argument('--train', type=str2bool, default=True)
    parser.add_argument('--parallel', type=str2bool, default=False)
    # The default must be one of the choices; 'cifar' is not handled by Data_Loader.
    parser.add_argument('--dataset', type=str, default='celeb', choices=['lsun', 'celeb'])
    parser.add_argument('--use_tensorboard', type=str2bool, default=False)

    # Load balance
    parser.add_argument('--load_balance_on', type=str2bool, default=False)
    parser.add_argument('--load_weight', type=float, default=1.0)  # for 2, for 5 1000, for 4500

    # Paths
    parser.add_argument('--image_path', type=str, default='./data')
    parser.add_argument('--log_path', type=str, default='./logs')
    parser.add_argument('--model_save_path', type=str, default='./models')
    parser.add_argument('--sample_path', type=str, default='./samples')
    parser.add_argument('--attn_path', type=str, default='./attn')

    # Step sizes
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=100)
    parser.add_argument('--model_save_step', type=float, default=1.0)

    # Calculating quantitative measures
    parser.add_argument('--score_epoch', type=int, default=3)  # compute scores every score_epoch epochs
    parser.add_argument('--score_start', type=int, default=3)  # epoch at which scoring starts

    return parser.parse_args()
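One subtlety worth noting: str2bool treats only the literal string 'true' (case-insensitively) as True, so --train false behaves as expected but --train 1 silently parses as False. A quick illustrative check, not part of the commit:

print(str2bool('TRUE'))   # True
print(str2bool('false'))  # False
print(str2bool('1'))      # False: only the literal string 'true' counts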
sagan_models.py
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from spectral import SpectralNorm
import numpy as np


class Self_Attn(nn.Module):
    '''
    input:      batch_size x feature_depth x feature_size x feature_size
    attn_score: batch_size x 1 x feature_size**2 x feature_size**2
    output:     batch_size x feature_depth x feature_size x feature_size
    '''
    def __init__(self, b_size, imsize, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.b_size = b_size
        self.imsize = imsize
        self.in_dim = in_dim
        self.f_ = nn.Conv2d(in_dim, int(in_dim/8), 1)
        self.g_ = nn.Conv2d(in_dim, int(in_dim/8), 1)
        self.h_ = nn.Conv2d(in_dim, in_dim, 1)
        if activation == 'relu':
            self.activation = F.relu
        elif activation == 'lrelu':
            self.activation = F.leaky_relu  # F.leakyrelu does not exist
        # Position-dependent key/query maps: one in_dim-channel map per spatial location.
        self.f__ = nn.Conv2d(int(in_dim/8), in_dim*(imsize**2), 1)
        self.g__ = nn.Conv2d(int(in_dim/8), in_dim*(imsize**2), 1)

        # Per-sample residual gate; this ties the module to a fixed batch size
        # (the data loader uses drop_last=True to guarantee it).
        self.gamma = nn.Parameter(torch.zeros(self.b_size, 1, 1, 1))

    def forward(self, x):
        b_size = x.size(0)
        f_size = x.size(-1)

        f_x = self.f__(self.activation(self.f_(x)))  # batch x in_dim*f*f x f_size x f_size
        g_x = self.g__(self.activation(self.g_(x)))  # batch x in_dim*f*f x f_size x f_size
        h_x = self.activation(self.h_(x)).unsqueeze(2).repeat(1, 1, f_size**2, 1, 1).contiguous().view(b_size, -1, f_size**2, f_size**2)  # batch x in_dim x f*f x f*f

        f_ready = f_x.contiguous().view(b_size, -1, f_size**2, f_size, f_size).permute(0, 1, 2, 4, 3)  # batch x in_dim x f*f x f_size2 x f_size1
        g_ready = g_x.contiguous().view(b_size, -1, f_size**2, f_size, f_size)  # batch x in_dim x f*f x f_size1 x f_size2

        attn_dist = torch.mul(f_ready, g_ready).sum(dim=1).contiguous().view(-1, f_size**2)  # batch*f*f x f_size1*f_size2
        attn_soft = F.softmax(attn_dist, dim=1).contiguous().view(b_size, f_size**2, f_size**2)  # batch x f*f x f*f
        attn_score = attn_soft.unsqueeze(1)  # batch x 1 x f*f x f*f

        self_attn_map = torch.mul(h_x, attn_score).sum(dim=3).contiguous().view(b_size, -1, f_size, f_size)  # batch x in_dim x f_size x f_size
        self_attn_map = self.gamma * self_attn_map + x

        return self_attn_map, attn_score
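Unlike upstream SAGAN's attention, this variant predicts a separate key and query map per spatial position (f__ and g__ each output in_dim*imsize**2 channels), so memory grows with the fourth power of the feature size, and the per-sample gamma fixes the batch size at construction time. A small illustrative shape check (sizes are hypothetical, chosen to keep the tensors tiny; not part of the commit):

attn = Self_Attn(b_size=4, imsize=8, in_dim=64, activation='relu')
x = torch.randn(4, 64, 8, 8)  # batch x in_dim x f_size x f_size
out, score = attn(x)
print(out.shape)    # torch.Size([4, 64, 8, 8])
print(score.shape)  # torch.Size([4, 1, 64, 64])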
class Generator(nn.Module):
    """Generator."""

    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3
        mult = 2 ** repeat_num  # 8 for imsize 64
        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())

        curr_dim = conv_dim * mult

        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())

        curr_dim = int(curr_dim / 2)

        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2)
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        # The hard-coded in_dim values (128, 64) match the layer widths only
        # when image_size == 64 and conv_dim == 64.
        self.attn1 = Self_Attn(batch_size, int(self.imsize / 4), 128, 'relu')
        self.attn2 = Self_Attn(batch_size, int(self.imsize / 2), 64, 'relu')

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.l1(z)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        if self.imsize == 64:
            out = self.l4(out)  # l4 is only defined for 64x64 images
        out, p2 = self.attn2(out)
        out = self.last(out)

        return out, p1, p2


class Discriminator(nn.Module):
    """Discriminator, Auxiliary Classifier."""

    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))

        curr_dim = conv_dim

        layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
            layer4.append(nn.LeakyReLU(0.1))
            self.l4 = nn.Sequential(*layer4)
            curr_dim = curr_dim * 2
        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.Conv2d(curr_dim, 1, 4))
        self.last = nn.Sequential(*last)

        # As in the generator, the in_dim values (256, 512) assume
        # image_size == 64 and conv_dim == 64.
        self.attn1 = Self_Attn(batch_size, int(self.imsize / 8), 256, 'relu')
        self.attn2 = Self_Attn(batch_size, int(self.imsize / 16), 512, 'relu')

    def forward(self, x):
        out = self.l1(x)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        if self.imsize == 64:
            out = self.l4(out)  # l4 is only defined for 64x64 images
        out, p2 = self.attn2(out)
        out = self.last(out)

        return out.squeeze(), p1, p2
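A minimal smoke test for the two networks, as a sketch under stated assumptions: the constructor arguments follow the parameter.py defaults except batch size, set to 1 here to keep the position-wise attention tensors small (illustrative only, not part of the commit):

G = Generator(batch_size=1, image_size=64, z_dim=128, conv_dim=64)
D = Discriminator(batch_size=1, image_size=64, conv_dim=64)
with torch.no_grad():
    fake, g_p1, g_p2 = G(torch.randn(1, 128))  # fake: 1 x 3 x 64 x 64
    score, d_p1, d_p2 = D(fake)                # 0-dim after squeeze() at batch size 1
print(fake.shape, score.shape)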