Test #1

Open

wants to merge 7 commits into master

2 changes: 1 addition & 1 deletion .gitignore
@@ -127,6 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/
input/
# input/
models/
checkpoint/
6 changes: 5 additions & 1 deletion src/config.py
100644 → 100755
@@ -1,5 +1,6 @@
USE_TPU = False
MULTI_CORE = False
MIXED_PRECISION = False

import os
import torch
@@ -8,6 +9,7 @@
OUT_DIR = '../result/'
MODEL_DIR = '../models/'
CHECKPOINT_DIR = '../checkpoint/'
LOGS_DIR = '../logs/'

TRAIN_DIR = DATA_DIR+"train/" # UPDATE
TEST_DIR = DATA_DIR+"test/" # UPDATE
@@ -17,6 +19,7 @@
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)

# DATA INFORMATION
IMAGE_SIZE = 224
@@ -37,4 +40,5 @@
else:
DEVICE = 'cpu'


if DEVICE=='cpu' and MIXED_PRECISION:
raise ValueError('To use mixed precision you need GPU')
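
Throughout this PR, tensors are created with `torch.half if config.MIXED_PRECISION else torch.float`. A minimal sketch of that pattern in isolation, assuming the flag and device resolution shown above (the `TENSOR_DTYPE` name is illustrative and not defined by the PR):

```python
import torch

MIXED_PRECISION = False  # set True to train in half precision on a GPU

# Hypothetical helper: resolve the dtype once instead of repeating the ternary.
TENSOR_DTYPE = torch.half if MIXED_PRECISION else torch.float

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if DEVICE == 'cpu' and MIXED_PRECISION:
    raise ValueError('To use mixed precision you need GPU')

x = torch.zeros(2, 3, dtype=TENSOR_DTYPE, device=DEVICE)
```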
1 change: 1 addition & 0 deletions src/dataset.py
100644 → 100755
@@ -49,3 +49,4 @@ def generate_batch(self):
batch = np.asarray(batch)/255 # values between 0 and 1
labels = np.asarray(labels)/255 # values between 0 and 1
return batch, labels, filelist

25 changes: 12 additions & 13 deletions src/engine.py
100644 → 100755
@@ -14,42 +14,40 @@
import torch_xla.distributed.xla_multiprocessing as xmp



def train(train_loader, GAN_Model, netD, VGG_MODEL, optG, optD, device, losses):
batch = 0

def wgan_loss(prediction, real_or_not):
if real_or_not:
return -torch.mean(prediction.float())
return -torch.mean(prediction)
else:
return torch.mean(prediction.float())
return torch.mean(prediction)

def gp_loss(y_pred, averaged_samples, gradient_penalty_weight):

gradients = torch.autograd.grad(y_pred,averaged_samples,
grad_outputs=torch.ones(y_pred.size(), device=device),
grad_outputs=torch.ones(y_pred.size(), device=device, dtype=torch.half if config.MIXED_PRECISION else torch.float),
create_graph=True, retain_graph=True, only_inputs=True)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = (((gradients+1e-16).norm(2, dim=1) - 1) ** 2).mean() * gradient_penalty_weight
gradient_penalty = (((gradients).norm(2, dim=1) - 1) ** 2).mean() * gradient_penalty_weight
return gradient_penalty
for trainL, trainAB, _ in tqdm(iter(train_loader)):
batch += 1

trainL_3 = torch.tensor(np.tile(trainL.cpu(), [1,3,1,1]), device=device)
trainL_3 = torch.tensor(np.tile(trainL.cpu(), [1,3,1,1]), device=device, dtype=torch.half if config.MIXED_PRECISION else torch.float)

trainL = torch.tensor(trainL, device=device).double()
trainAB = torch.tensor(trainAB, device=device).double()
# trainL_3 = trainL_3.to(device).double()
trainL = torch.tensor(trainL, device=device, dtype=torch.half if config.MIXED_PRECISION else torch.float)
trainAB = torch.tensor(trainAB, device=device, dtype=torch.half if config.MIXED_PRECISION else torch.float)

predictVGG = F.softmax(VGG_MODEL(trainL_3))

############ GAN MODEL ( Training Generator) ###################
optG.zero_grad()
predAB, classVector, discpred = GAN_Model(trainL, trainL_3)
D_G_z1 = discpred.mean().item()
Loss_KLD = nn.KLDivLoss(size_average='False')(classVector.log().float(), predictVGG.detach().float()) * 0.003
Loss_MSE = nn.MSELoss()(predAB.float(), trainAB.float())
Loss_WL = wgan_loss(discpred.float(), True) * 0.1
Loss_KLD = nn.KLDivLoss(size_average='False')(classVector.log(), predictVGG.detach()) * 0.003
Loss_MSE = nn.MSELoss()(predAB, trainAB)
Loss_WL = wgan_loss(discpred, True) * 0.1
Loss_G = Loss_KLD + Loss_MSE + Loss_WL
Loss_G.backward()

@@ -81,7 +79,7 @@ def gp_loss(y_pred, averaged_samples, gradient_penalty_weight):
discreal = netD(realLAB)
D_x = discreal.mean().item()

weights = torch.randn((trainAB.size(0),1,1,1), device=device)
weights = torch.randn((trainAB.size(0),1,1,1), device=device, dtype=torch.half if config.MIXED_PRECISION else torch.float)
averaged_samples = (weights * trainAB ) + ((1 - weights) * predAB.detach())
averaged_samples = torch.autograd.Variable(averaged_samples, requires_grad=True)
avg_img = torch.cat([trainL, averaged_samples], dim=1)
@@ -103,6 +101,7 @@ def gp_loss(y_pred, averaged_samples, gradient_penalty_weight):

losses['D_losses'].append(Loss_D.item())
losses['EPOCH_D_losses'].append(Loss_D.item())

# Output training stats
if batch % 100 == 0:
print('Loss_D: %.8f | Loss_G: %.8f | D(x): %.8f | D(G(z)): %.8f / %.8f | MSE: %.8f | KLD: %.8f | WGAN_F(G): %.8f | WGAN_F(D): %.8f | WGAN_R(D): %.8f | WGAN_A(D): %.8f'
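
For reference, `gp_loss` above is the WGAN-GP gradient penalty, evaluated on samples interpolated between the real AB channels and the generator's (detached) predictions. A self-contained sketch of the standard formulation, with illustrative names that are not taken from the PR:

```python
import torch

def gradient_penalty(critic, real, fake, weight=10.0):
    """Standard WGAN-GP term: push the critic's gradient norm on interpolates toward 1."""
    # One uniform mixing coefficient per sample, broadcast over C, H, W.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device, dtype=real.dtype)
    interpolates = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    scores = critic(interpolates)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True, only_inputs=True)[0]
    grads = grads.view(grads.size(0), -1)
    return weight * ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

The version in `engine.py` computes the same quantity, with `grad_outputs` created in half precision when `config.MIXED_PRECISION` is set.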
5 changes: 2 additions & 3 deletions src/model.py
100644 → 100755
@@ -36,7 +36,6 @@ def __init__(self):

self.VGG_model = torchvision.models.vgg16(pretrained=True)
self.VGG_model = nn.Sequential(*list(self.VGG_model.features.children())[:-8]) #[None, 512, 28, 28]
self.VGG_model = self.VGG_model.double()
self.relu = nn.ReLU()
self.lrelu = nn.LeakyReLU(0.3)
self.global_features_conv1 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(2,2), bias=bias) #[None, 512, 14, 14]
@@ -83,7 +82,7 @@ def forward(self,input_img):

# VGG Without Top Layers

vgg_out = self.VGG_model(torch.tensor(input_img).double())
vgg_out = self.VGG_model(torch.tensor(input_img))

#Global Features

@@ -111,7 +110,7 @@ def forward(self,input_img):
global_featureClass = self.softmax(self.global_featuresClass_dense3(global_featureClass))#[None, 1000]

# Mid Level Features
midlevel_features = self.midlevel_conv1(vgg_out.double()) #[None, 512, 28, 28]
midlevel_features = self.midlevel_conv1(vgg_out) #[None, 512, 28, 28]
midlevel_features = self.midlevel_bn1(midlevel_features) #[None, 512, 28, 28]
midlevel_features = self.midlevel_conv2(midlevel_features) #[None, 256, 28, 28]
midlevel_features = self.midlevel_bn2(midlevel_features) #[None, 256, 28, 28]
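
The backbone above is `vgg16` with the last eight `features` modules removed, which stops after the conv4 block. A quick sketch verifying the commented output shape for a 224x224 input (purely illustrative, run in default float32 as the PR now does):

```python
import torch
import torch.nn as nn
import torchvision

vgg = torchvision.models.vgg16(pretrained=True)
backbone = nn.Sequential(*list(vgg.features.children())[:-8])

with torch.no_grad():
    out = backbone(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 512, 28, 28])
```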
72 changes: 72 additions & 0 deletions src/optim.py
@@ -0,0 +1,72 @@
#
# Source: https://gist.github.com/ajbrock/075c0ca4036dc4d8581990
# Adam Optimizer that supports Mixed Precision

import math
from torch.optim.optimizer import Optimizer

class Adam16(Optimizer):

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-4,
weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
params = list(params)
super(Adam16, self).__init__(params, defaults)
# for group in self.param_groups:
# for p in group['params']:

self.fp32_param_groups = [p.data.float().cuda() for p in params]
if not isinstance(self.fp32_param_groups[0], dict):
self.fp32_param_groups = [{'params': self.fp32_param_groups}]

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group,fp32_group in zip(self.param_groups,self.fp32_param_groups):
for p,fp32_p in zip(group['params'],fp32_group['params']):
if p.grad is None:
continue

grad = p.grad.data.float()
state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = grad.new().resize_as_(grad).zero_()
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']

state['step'] += 1

if group['weight_decay'] != 0:
grad = grad.add(group['weight_decay'], fp32_p)

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

denom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

# print(type(fp32_p))
fp32_p.addcdiv_(-step_size, exp_avg, denom)
p.data = fp32_p.half()

return loss
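
`Adam16` keeps an fp32 master copy of every parameter, runs the Adam update in fp32, and writes the result back to the fp16 parameters (`p.data = fp32_p.half()`). A minimal usage sketch, assuming a CUDA device; the toy linear model is purely illustrative:

```python
import torch
import torch.nn as nn
from optim import Adam16

# Toy half-precision model on the GPU (Adam16 moves its fp32 copies to CUDA).
model = nn.Linear(8, 2).half().cuda()
opt = Adam16(model.parameters(), lr=2e-4, betas=(0.5, 0.999))

x = torch.randn(4, 8, device='cuda', dtype=torch.half)
y = torch.randn(4, 2, device='cuda', dtype=torch.half)

loss = nn.MSELoss()(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()  # gradients are upcast to fp32 internally; weights come back as fp16
```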

20 changes: 14 additions & 6 deletions src/train.py
100644 → 100755
@@ -19,9 +19,11 @@
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

if config.MIXED_PRECISION:
from optim import Adam16

def map_fn(index=None, flags=None):
torch.set_default_tensor_type('torch.FloatTensor')
# torch.set_default_tensor_type('torch.FloatTensor')
torch.manual_seed(1234)

train_data = dataset.DATA(config.TRAIN_DIR)
@@ -49,19 +51,25 @@ def map_fn(index=None, flags=None):
DEVICE = config.DEVICE


netG = model.colorization_model().double()
netD = model.discriminator_model().double()
netG = model.colorization_model()
netD = model.discriminator_model()

VGG_modelF = torchvision.models.vgg16(pretrained=True).double()
VGG_modelF = torchvision.models.vgg16(pretrained=True)
VGG_modelF.requires_grad_(False)

netG = netG.to(DEVICE)
netD = netD.to(DEVICE)

if config.MIXED_PRECISION:
VGG_modelF = VGG_modelF.half()
VGG_modelF = VGG_modelF.to(DEVICE)

optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
if config.MIXED_PRECISION:
optD = Adam16(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optG = Adam16(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
else:
optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))

## Trains
train_start = time.time()
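
The selection above amounts to: use `Adam16` (fp32 master weights) when the networks run in fp16, and plain `torch.optim.Adam` otherwise. A condensed sketch of that wiring; the `build_optimizer` helper is illustrative and not part of the PR:

```python
import torch
import config
from optim import Adam16

def build_optimizer(params, lr=2e-4, betas=(0.5, 0.999)):
    # Hypothetical helper: Adam16 keeps fp32 master copies for fp16 parameters.
    if config.MIXED_PRECISION:
        return Adam16(params, lr=lr, betas=betas)
    return torch.optim.Adam(params, lr=lr, betas=betas)
```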
18 changes: 9 additions & 9 deletions src/utils.py
100644 → 100755
@@ -68,16 +68,15 @@ def plot_some(test_data, colorization_model, device, epoch):
filepath = config.TRAIN_DIR+filename
batchL = batchL.reshape(1,1,224,224)
realAB = realAB.reshape(1,2,224,224)
batchL_3 = torch.tensor(np.tile(batchL, [1, 3, 1, 1]))
batchL_3 = batchL_3.to(device)
batchL = torch.tensor(batchL).to(device).double()
realAB = torch.tensor(realAB).to(device).double()
batchL_3 = torch.tensor(np.tile(batchL, [1, 3, 1, 1]), dtype=torch.half if config.MIXED_PRECISION else torch.float, device=device)
batchL = torch.tensor(batchL, dtype=torch.half if config.MIXED_PRECISION else torch.float, device=device)
realAB = torch.tensor(realAB, dtype=torch.half if config.MIXED_PRECISION else torch.float, device=device)

colorization_model.eval()
batch_predAB, _ = colorization_model(batchL_3)
img = cv2.imread(filepath)
batch_predAB = batch_predAB.cpu().numpy().reshape((224,224,2))
batchL = batchL.cpu().numpy().reshape((224,224,1))
batch_predAB = batch_predAB.cpu().float().numpy().reshape((224,224,2))
batchL = batchL.cpu().float().numpy().reshape((224,224,1))
realAB = realAB.cpu().numpy().reshape((224,224,2))
orig = cv2.imread(filepath)
orig = cv2.resize(cv2.cvtColor(orig, cv2.COLOR_BGR2RGB), (224,224))
@@ -130,17 +129,18 @@ def load_checkpoint(checkpoint_directory, netG, optG, netD, optD, device):
netD.to(device)

optD.load_state_dict(checkpoint['discriminator_optimizer'])

print('Loaded States !!!')
print(f'It looks like the this states belong to epoch {epoch_checkpoint-1}.')
print(f'It looks like these states belong to epoch {epoch_checkpoint-1}.')
print(f'so the model will train for {config.NUM_EPOCHS - (epoch_checkpoint-1)} more epochs.')
print(f'If you want to train for more epochs, change the "NUM_EPOCHS" in config.py !!')


return netG, optG, netD, optD, epoch_checkpoint
return netG.half() if config.MIXED_PRECISION else netG, optG, netD.half() if config.MIXED_PRECISION else netD, optD, epoch_checkpoint
else:
print('There are no checkpoints in the mentioned directoy, the Model will train from scratch.')
epoch_checkpoint = 1
return netG, optG, netD, optD, epoch_checkpoint
return netG.half() if config.MIXED_PRECISION else netG, optG, netD.half() if config.MIXED_PRECISION else netD, optD, epoch_checkpoint


