diff --git a/ReadME.md b/ReadME.md
new file mode 100644
index 0000000..3e14c30
--- /dev/null
+++ b/ReadME.md
@@ -0,0 +1,67 @@
+# Code for GANorCON
+
+This is the code for the contrastive few-shot part segmentation method proposed in 'GANorCON: Are Generative Models Useful for Few-shot Segmentation?'. We provide the pretrained MoCoV2 model for convenience.
+
+
+## Preparation
+```
+pip3 install -r requirements.txt
+```
+-> Download MoCoV2_512_CelebA.pth from https://drive.google.com/file/d/1n0iyFuZ20s_DsIAorvmtVLIdHIG_65N0/ and place it inside this folder
+-> Place the [DatasetGAN released data](https://drive.google.com/drive/folders/1PSS0uOusN3dV84YLT9Gds1ZSugjpMz7E) at ./DatasetGAN_data inside this folder
+-> Download CelebAMask from [here](https://drive.google.com/open?id=1badu11NqxGf6qM3PTTooQDJvQbejgbTv) and place it inside this folder
+
+## Training - Few Shot Segmentation
+
+```bash
+
+python3 eval_face_seg.py --model resnet50 --segmodel fcn --layer 4 --trained_model_path MoCoV2_512_CelebA.pth --learning_rate 0.001 --weight_decay 0.0005 --adam --epochs 800 --cosine --batch_size 1 --log_path ./log.txt --model_name face_segmentor --model_path ./512_faces_celeba --image_size 512 --use_hypercol
+
+```
+The --segmodel option can be set to either "fcn" or "UNet", selecting the corresponding variant described in the paper. Set --model_path to the desired location for saving checkpoints.
+
+## Generate data for distillation
+
+```bash
+
+python3 eval_face_seg.py --model resnet50 --segmodel fcn --layer 4 --trained_model_path MoCoV2_512_CelebA.pth --image_size 512 --use_hypercol --generate --gen_path ./labels_fordeeplab/ --resume ./512_faces_celeba/face_segmentor/resnet50.pth
+
+```
+Pass the path to the trained model resnet50.pth via --resume. The labels predicted with that checkpoint are written to the directory given by --gen_path.
+
+## Distillation
+
+```bash
+
+python3 train_deeplab_contrast.py --data_path ./labels_fordeeplab/ --model_path ./512_faces_celeba_distilled --image_size 512 --num_classes 34
+
+```
+Specify the path to the labels generated in the previous step via --data_path and the directory for saving checkpoints via --model_path.
+
+## Testing
+
+For the model from Few Shot Segmentation training:
+```bash
+
+python3 gen_score_seg.py --resume ./512_faces_celeba/face_segmentor/ --model fcn
+
+```
+--model can also be set to UNet if the UNet-based segmentor was used for training.
+
+
+For the model from distillation:
+```bash
+
+python3 gen_score_seg.py --resume ./512_faces_celeba_distilled/deeplab_class_34_checkpoint/ --distill
+
+```
+
+In both cases, --resume should point to the folder containing all checkpoints.
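+
+## Inference example
+
+The snippet below is a minimal sketch of single-image inference with a checkpoint produced by the few-shot fcn training above. It is not one of the released scripts: the input image face.jpg and the checkpoint path are placeholders, and the dimensions follow the --model resnet50 --layer 4 --use_hypercol --image_size 512 configuration.
+
+```python
+import torch
+from PIL import Image
+from torchvision import transforms
+from models.resnet import InsResNet50
+from models.segmentor import fcn
+
+# Backbone + segmentor as built in eval_face_seg.py for resnet50,
+# layer=4 hypercolumns at 512x512 input:
+model = InsResNet50(pool_size=512 // 2**5)      # pool_size = 16
+feat_dim = 2048 + 1024 + 512 + 256              # 4-layer hypercolumn width
+segmentor = fcn(feat_dim, n_classes=34)
+
+ckpt = torch.load('./512_faces_celeba/face_segmentor/resnet50.pth', map_location='cpu')
+model.load_state_dict(ckpt['model'], strict=False)
+segmentor.load_state_dict(ckpt['segmentor'])
+model.cuda().eval()
+segmentor.cuda().eval()
+
+# face.jpg is a placeholder; preprocessing mirrors the data loaders.
+img = Image.open('face.jpg').convert('RGB').resize((512, 512))
+x = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                         std=[0.229, 0.224, 0.225])(transforms.ToTensor()(img))
+x = x.unsqueeze(0).cuda()
+
+with torch.no_grad():
+    feat = model(x, 4, True, (512, 512))                     # hypercolumn features
+    label = segmentor(feat).argmax(dim=1)[0].cpu().numpy()   # (512, 512) class map
+```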
+ + + + + + + + diff --git a/data_loader/data_loader_celebamask.py b/data_loader/data_loader_celebamask.py new file mode 100644 index 0000000..f5ba498 --- /dev/null +++ b/data_loader/data_loader_celebamask.py @@ -0,0 +1,126 @@ +import torch +import torchvision.datasets as dsets +from torchvision import transforms +import torchvision.transforms.functional as TF +from PIL import Image +import os +import numpy as np +import random + +class CelebAMaskHQ(): + def __init__(self, img_path, label_path, transform_img, transform_label, mode): + self.img_path = img_path + self.label_path = label_path + self.transform_img = transform_img + self.transform_label = transform_label + self.train_dataset = [] + self.test_dataset = [] + self.mode = mode + self.preprocess_nvidia() + + if mode == True: + self.num_images = len(self.train_dataset) + else: + self.num_images = len(self.test_dataset) + + def preprocess_nvidia(self): + if self.mode==True: + for i in range(int(len([name for name in os.listdir(self.img_path) if os.path.isfile(os.path.join(self.img_path, name))])/2)): + img_path = os.path.join(self.img_path, 'image_'+str(i)+'.jpg') + label_path = os.path.join(self.label_path, 'image_mask'+str(i)+'.npy') + self.train_dataset.append([img_path, label_path]) + else: + for i in range(int(len([name for name in os.listdir(self.img_path) if os.path.isfile(os.path.join(self.img_path, name))])/2)): + img_path = os.path.join(self.img_path, 'face_'+str(i)+'.png') + label_path = os.path.join(self.label_path, 'mask_'+str(i)+'.npy') + self.test_dataset.append([img_path, label_path]) + + print('Finished preprocessing the Nvidia dataset...') + + def __getitem__(self, index): + + dataset = self.train_dataset if self.mode == True else self.test_dataset + img_path, label_path = dataset[index] + image = Image.open(img_path) + label = np.load(label_path) + + label = Image.fromarray(label) + image = image.resize((512,512)) + label = label.resize((512, 512), resample=Image.NEAREST) + + crop = random.random() < 0.5 + if crop and self.mode==True: + i, j, h, w = transforms.RandomResizedCrop.get_params( + image, scale=(0.6,1.0), ratio=(0.7,1.3)) + + image = TF.crop(image, i, j, h, w) + label = TF.crop(label, i, j, h, w) + + image = image.resize((512,512)) + label = label.resize((512, 512), resample=Image.NEAREST) + + jitter = random.random() < 0.5 + if jitter and self.mode==True: + image = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0)(image) + + hflip = random.random() < 0.5 + if hflip and self.mode==True: + image = image.transpose(Image.FLIP_LEFT_RIGHT) + label = label.transpose(Image.FLIP_LEFT_RIGHT) + label = np.array(label, dtype=np.long) + + return self.transform_img(image), self.transform_label(label) + + def __len__(self): + """Return the number of images.""" + return self.num_images + +class Data_Loader(): + def __init__(self, img_path, label_path, image_size, batch_size, mode): + self.img_path = img_path + self.label_path = label_path + self.imsize = image_size + self.batch = batch_size + self.mode = mode + + def transform_img(self, resize, totensor, normalize, centercrop): + options = [] + if centercrop: + options.append(transforms.CenterCrop(160)) + if resize: + options.append(transforms.Resize((self.imsize,self.imsize))) + if totensor: + options.append(transforms.ToTensor()) + if normalize: + options.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])) + transform = transforms.Compose(options) + return transform + + def transform_label(self, resize, 
totensor, normalize, centercrop): + options = [] + if centercrop: + options.append(transforms.CenterCrop(160)) + if resize: + options.append(transforms.Resize((self.imsize,self.imsize))) + if totensor: + options.append(transforms.ToTensor()) + if normalize: + options.append(transforms.Normalize((0, 0, 0), (0, 0, 0))) + transform = transforms.Compose(options) + return transform + + def loader(self): + transform_img = self.transform_img(True, True, True, False) + transform_label = self.transform_label(False, True, False, False) + dataset = CelebAMaskHQ(self.img_path, self.label_path, transform_img, transform_label, self.mode) + + print(len(dataset)) + + loader = torch.utils.data.DataLoader(dataset=dataset, + batch_size=self.batch, + shuffle=False,#self.mode==True, + num_workers=0, + drop_last=False, + pin_memory=True) + return loader diff --git a/data_loader/data_loader_forgen.py b/data_loader/data_loader_forgen.py new file mode 100644 index 0000000..d6fd509 --- /dev/null +++ b/data_loader/data_loader_forgen.py @@ -0,0 +1,30 @@ +from PIL import Image +import torchvision +from torch.utils.data import Dataset + +resnet_transform = torchvision.transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) +class ImageLabelDataset(Dataset): + def __init__( + self, + img_path_list, + img_size=(128, 128), + ): + self.img_path_list = img_path_list + self.img_size = img_size + + def __len__(self): + return len(self.img_path_list) + + def __getitem__(self, index): + im_path = self.img_path_list[index] + im = Image.open(im_path) + im = self.transform(im) + return im, im_path + + def transform(self, img): + img = img.resize((self.img_size[0], self.img_size[1])) + img = torchvision.transforms.ToTensor()(img) + img = resnet_transform(img) + return img \ No newline at end of file diff --git a/eval_face_seg.py b/eval_face_seg.py new file mode 100644 index 0000000..2c17327 --- /dev/null +++ b/eval_face_seg.py @@ -0,0 +1,512 @@ +from __future__ import print_function + +import os +import sys +import time +import torch +import torch.optim as optim +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +import argparse +import socket +import torch.multiprocessing as mp +import torch.distributed as dist + +import tensorboard_logger as tb_logger + +from torchvision import transforms, datasets +from torch.utils.data import Dataset, DataLoader + +from utils.util import adjust_learning_rate, AverageMeter, Tee + +from models.resnet import InsResNet50,InsResNet18,InsResNet34,InsResNet101,InsResNet152 +from models.segmentor import fcn, UNet +from models.loss import cross_entropy2d + +from data_loader.data_loader_celebamask import Data_Loader +from data_loader.data_loader_forgen import ImageLabelDataset + +import matplotlib.pyplot as plt + +import numpy as np +import random +import math +import cv2 + +torch.manual_seed(0) +torch.cuda.manual_seed(0) +torch.cuda.manual_seed_all(0) +random.seed(0) +np.random.seed(0) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + + +def parse_option(): + + hostname = socket.gethostname() + + parser = argparse.ArgumentParser('argument for training') + + parser.add_argument('--print_freq', type=int, default=10, help='print frequency') + parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') + parser.add_argument('--save_freq', type=int, default=20, help='save frequency') + parser.add_argument('--batch_size', type=int, default=2, help='batch_size') + parser.add_argument('--num_workers', type=int, 
default=16, help='num of workers to use') + parser.add_argument('--epochs', type=int, default=60, help='number of training epochs') + + # optimization + parser.add_argument('--learning_rate', type=float, default=0.1, help='learning rate') + parser.add_argument('--lr_decay_epochs', type=str, default='30,40,50', help='where to decay lr, can be a list') + parser.add_argument('--lr_decay_rate', type=float, default=0.2, help='decay rate for learning rate') + parser.add_argument('--momentum', type=float, default=0.9, help='momentum') + parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') + parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam') + parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam') + + # model definition + parser.add_argument('--model', type=str, default='resnet50', + choices=['resnet50', 'resnet50x2', 'resnet50x4', 'hourglass', + 'resnet18', 'resnet34', 'resnet101', 'resnet152']) + parser.add_argument('--segmodel', type=str, default='fcn', + choices=['fcn', 'UNet']) + + parser.add_argument('--trained_model_path', type=str, default=None, help='pretrained backbone') + parser.add_argument('--layer', type=int, default=3, help='resnet layers') + + + # model path and name + parser.add_argument('--model_name', type=str, default="face_model") # moco_version, network, input_size, crop_size + parser.add_argument('--model_path', type=str, default="./512_faces_celeba") # path to store the models + + # resume + parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') + + parser.add_argument('--image_size', type=int, default=512, help='image size') # image preprocessing + parser.add_argument('--generate', action='store_true', help='generate dataset for deeplab') + parser.add_argument('--gen_path', type=str, default=None) + + # add BN + parser.add_argument('--bn', action='store_true', help='use parameter-free BN') + parser.add_argument('--cosine', action='store_true', help='use cosine annealing') + parser.add_argument('--multistep', action='store_true', help='use multistep LR') + parser.add_argument('--adam', action='store_true', help='use adam optimizer') + parser.add_argument('--amsgrad', action='store_true', help='use amsgrad for adam') + + + # GPU setting + parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') + + # log_path + parser.add_argument('--log_path', default='log_tmp', type=str, metavar='PATH', help='path to the log file') + + # use hypercolumn or single layer output + parser.add_argument('--use_hypercol', action='store_true', help='use hypercolumn as representations') + + opt = parser.parse_args() + + + + + opt.save_path = opt.model_path + opt.tb_path = '%s_tensorboard' % opt.model_path + + Tee(opt.log_path, 'a') + + iterations = opt.lr_decay_epochs.split(',') + opt.lr_decay_epochs = list([]) + for it in iterations: + opt.lr_decay_epochs.append(int(it)) + + opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) + if not os.path.isdir(opt.tb_folder): + os.makedirs(opt.tb_folder) + + opt.save_folder = os.path.join(opt.save_path, opt.model_name) + if not os.path.isdir(opt.save_folder): + os.makedirs(opt.save_folder) + + return opt + + +def main(): + + global best_error + best_error = np.Inf + + args = parse_option() + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + train_loader_fn = Data_Loader(img_path='./DatasetGAN_data/annotation/training_data/face_processed/', + 
label_path='./DatasetGAN_data/annotation/training_data/face_processed/', + image_size=args.image_size, + batch_size=args.batch_size, + mode=True) + val_loader_fn = Data_Loader(img_path='./DatasetGAN_data/annotation/testing_data/face_34_class/', + label_path='./DatasetGAN_data/annotation/testing_data/face_34_class/', + image_size=args.image_size, + batch_size=args.batch_size, + mode=False) + + + train_sampler = None + + train_loader = train_loader_fn.loader() + val_loader = val_loader_fn.loader() + + # create model and optimizer + input_size = args.image_size + pool_size = int(input_size / 2**5) + + if args.model == 'resnet50': + model = InsResNet50(pool_size=pool_size)#, pretrained=True) + desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048} + elif args.model == 'resnet50x2': + model = InsResNet50(width=2, pool_size=pool_size) + desc_dim = {1:128, 2:512, 3:1024, 4:2048, 5:4096} + elif args.model == 'resnet50x4': + model = InsResNet50(width=4, pool_size=pool_size) + desc_dim = {1:512, 2:1024, 3:2048, 4:4096, 5:8192} + elif args.model == 'resnet18': + model = InsResNet18(width=1, pool_size=pool_size) + desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512} + elif args.model == 'resnet34': + model = InsResNet34(width=1, pool_size=pool_size) + desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512} + elif args.model == 'resnet101': + model = InsResNet101(width=1, pool_size=pool_size) + desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048} + elif args.model == 'resnet152': + model = InsResNet152(width=1, pool_size=pool_size) + desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048} + elif args.model == 'hourglass': + model = HourglassNet() + else: + raise NotImplementedError('model not supported {}'.format(args.model)) + + + if args.model == 'hourglass': + feat_dim = 64 + else: + if args.use_hypercol: + feat_dim = 0 + for i in range(args.layer): + feat_dim += desc_dim[5-i] + else: + feat_dim = desc_dim[args.layer] + + if args.segmodel=='fcn': + segmentor = fcn(feat_dim, n_classes=34) + else: + segmentor = UNet(feat_dim, n_classes=34) + + + print('==> loading pre-trained model') + ckpt = torch.load(args.trained_model_path, map_location='cpu') + state_dict = ckpt['model'] + + for key in list(state_dict.keys()): + state_dict[key.replace('module.encoder', 'encoder.module')] = state_dict.pop(key) + + model.load_state_dict(state_dict, strict=False) + print('==> done') + + segmentor.init_weights() + + model = model.cuda() + segmentor = segmentor.cuda() + + if args.generate==True: + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint['model'], strict=False) + segmentor.load_state_dict(checkpoint['segmentor']) + + images_togen = [] + img_path_base = './CelebAMask-HQ/train_img/' + + for i in range(len([name for name in os.listdir(img_path_base) if os.path.isfile(os.path.join(img_path_base, name))])): + img_path = os.path.join(img_path_base, str(i)+'.jpg') + images_togen.append(img_path) + if i==10000: + break + gen_data = ImageLabelDataset(img_path_list=images_togen, + img_size=(args.image_size, args.image_size)) + if not os.path.isdir(args.gen_path): + os.mkdir(args.gen_path) + model.eval() + segmentor.eval() + gen_data = DataLoader(gen_data, batch_size=1, shuffle=False, num_workers=16) + with torch.no_grad(): + for idx, (input, im_path) in enumerate(gen_data): + input = input.cuda() + input = input.float() + # compute output + feat = model(input, args.layer, args.use_hypercol, (512,512)) + feat = feat.detach() + output = segmentor(feat) + output = output.detach() + label_out = 
torch.nn.functional.log_softmax(output,dim=1) + label_out = label_out.view(1, 34, 512, 512) + label = label_out[0] + label = label.data.max(0)[1].cpu().numpy() + cv2.imwrite(os.path.join(args.gen_path, str(idx) +'.png'), label) + if idx%100==0: + print('Processed '+str(idx)+'/'+str(10000)) + return + + criterion = cross_entropy2d + + if not args.adam: + optimizer = torch.optim.SGD(segmentor.parameters(), + lr=args.learning_rate, + momentum=args.momentum, + weight_decay=args.weight_decay) + else: + optimizer = torch.optim.Adam(segmentor.parameters(), + lr=args.learning_rate, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + eps=1e-8, + amsgrad=args.amsgrad) + model.eval() + cudnn.benchmark = True + + # optionally resume from a checkpoint + args.start_epoch = 1 + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume, map_location='cpu') + # checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint['epoch'] + 1 + segmentor.load_state_dict(checkpoint['segmentor']) + optimizer.load_state_dict(checkpoint['optimizer']) + best_error = checkpoint['best_error'] + # best_error = best_error.cuda() + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + if 'opt' in checkpoint.keys(): + # resume optimization hyper-parameters + print('=> resume hyper parameters') + if 'bn' in vars(checkpoint['opt']): + print('using bn: ', checkpoint['opt'].bn) + if 'adam' in vars(checkpoint['opt']): + print('using adam: ', checkpoint['opt'].adam) + if 'cosine' in vars(checkpoint['opt']): + print('using cosine: ', checkpoint['opt'].cosine) + args.learning_rate = checkpoint['opt'].learning_rate + # args.lr_decay_epochs = checkpoint['opt'].lr_decay_epochs + args.lr_decay_rate = checkpoint['opt'].lr_decay_rate + args.momentum = checkpoint['opt'].momentum + args.weight_decay = checkpoint['opt'].weight_decay + args.beta1 = checkpoint['opt'].beta1 + args.beta2 = checkpoint['opt'].beta2 + del checkpoint + torch.cuda.empty_cache() + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + # set cosine annealing scheduler + if args.cosine: + + eta_min = args.learning_rate * (args.lr_decay_rate ** 3) * 0.1 + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min, -1) + # dummy loop to catch up with current epoch + for i in range(1, args.start_epoch): + scheduler.step() + elif args.multistep: + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 250], gamma=0.1) + # dummy loop to catch up with current epoch + for i in range(1, args.start_epoch): + scheduler.step() + + # tensorboard + logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2) + train_loss_list = [] + test_loss_list = [] + + + # routine + for epoch in range(args.start_epoch, args.epochs + 1): + + if args.cosine or args.multistep: + scheduler.step() + else: + adjust_learning_rate(epoch, args, optimizer) + print("==> training...") + + time1 = time.time() + train_loss = train(epoch, train_loader, model, segmentor, criterion, optimizer, args) + time2 = time.time() + print('train epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) + + # logger.log_value('InterOcularError', InterOcularError, epoch) + train_loss_list.append(train_loss) + logger.log_value('train_loss', train_loss, epoch) + logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) + + print("==> testing...") + test_loss = validate(val_loader, model, segmentor, 
criterion, args) + + test_loss_list.append(test_loss) + + # logger.log_value('Test_InterOcularError', test_InterOcularError, epoch) + logger.log_value('test_loss', test_loss, epoch) + + # save the best model + if test_loss < best_error: + best_error = test_loss + state = { + 'opt': args, + 'epoch': epoch, + 'model': model.state_dict(), + 'segmentor': segmentor.state_dict(), + 'best_error': best_error, + 'optimizer': optimizer.state_dict(), + } + save_name = '{}.pth'.format(args.model) + save_name = os.path.join(args.save_folder, save_name) + print('saving best model!') + torch.save(state, save_name) + + # save model + if epoch % args.save_freq == 0: + print('==> Saving...') + state = { + 'opt': args, + 'epoch': epoch, + 'segmentor': segmentor.state_dict(), + 'best_error': test_loss, + 'optimizer': optimizer.state_dict(), + } + save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch) + save_name = os.path.join(args.save_folder, save_name) + print('saving regular model!') + torch.save(state, save_name) + + # tensorboard logger + pass + + x=range(len(train_loss_list)) + + plt.plot(x, train_loss_list, label = "train loss") + plt.plot(x, test_loss_list, label = "test loss") + plt.xlabel('epochs') + plt.ylabel('loss') + + plt.savefig(os.path.join(args.save_folder,'loss_curve.png')) + + +def set_lr(optimizer, lr): + """ + set the learning rate + """ + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def train(epoch, train_loader, model, segmentor, criterion, optimizer, opt): + """ + one epoch training + """ + + model.eval() + segmentor.train() + + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + # InterOcularError = AverageMeter() + + end = time.time() + for idx, (input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + input = input.cuda(opt.gpu, non_blocking=True) + input = input.float() + target = target.cuda(opt.gpu, non_blocking=True) + + # ===================forward===================== + with torch.no_grad(): + feat = model(input, opt.layer, opt.use_hypercol, (512,512)) + feat = feat.detach() + + output = segmentor(feat) + loss = criterion(output, target) + + if idx == 0: + print('Layer:{0}, shape of input:{1}, feat:{2}, output:{3}'.format(opt.layer, + input.size(), feat.size(), output.size())) + + losses.update(loss.item(), input.size(0)) + + # ===================backward===================== + + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + # ===================meters===================== + batch_time.update(time.time() - end) + end = time.time() + + # print info + if idx % opt.print_freq == 0: + print('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( + epoch, idx, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses))#, InterOcularError=InterOcularError)) + sys.stdout.flush() + + return losses.avg + + +def validate(val_loader, model, segmentor, criterion, opt): + batch_time = AverageMeter() + losses = AverageMeter() + + # switch to evaluate mode + model.eval() + segmentor.eval() + + with torch.no_grad(): + end = time.time() + for idx, (input, target) in enumerate(val_loader): + if opt.gpu is not None: + input = input.cuda(opt.gpu, non_blocking=True) + input = input.float() + target = target.cuda(opt.gpu, non_blocking=True) + + # compute output + feat = model(input, opt.layer, opt.use_hypercol, 
(512,512)) + feat = feat.detach() + + output = segmentor(feat) + loss = criterion(output, target) + + losses.update(loss.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if idx % opt.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( + idx, len(val_loader), batch_time=batch_time, loss=losses)) + + return losses.avg + + +if __name__ == '__main__': + best_error = np.Inf + main() diff --git a/gen_score_seg.py b/gen_score_seg.py new file mode 100644 index 0000000..6603586 --- /dev/null +++ b/gen_score_seg.py @@ -0,0 +1,304 @@ +import sys +sys.path.append('../') +import os +import argparse +import gc +import os +import torch +import torchvision +from torch.utils.data import Dataset, DataLoader +import glob +import json +import numpy as np +from data_loader.data_loader_celebamask import Data_Loader +from models.resnet import InsResNet50,InsResNet18,InsResNet34,InsResNet101,InsResNet152 +from models.segmentor import fcn, UNet +from torchvision import transforms + +import torch.nn.functional as F +import cv2 + + +from PIL import Image + +import scipy.misc + + +def process_image(images): + drange = [-1, 1] + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + + images = images.astype(int) + images[images > 255] = 255 + images[images < 0] = 0 + + return images.astype(int) + +class ImageLabelDataset(Dataset): + def __init__( + self, + img_path_list, + label_path_list, + img_size=(128, 128), + ): + self.img_path_list = img_path_list + self.label_path_list = label_path_list + self.img_size = img_size + + def __len__(self): + return len(self.img_path_list) + + def __getitem__(self, index): + im_path = self.img_path_list[index] + lbl_path = self.label_path_list[index] + im = Image.open(im_path) + try: + lbl = np.load(lbl_path) + except: + lbl = np.array(Image.open(lbl_path)) + if len(lbl.shape) == 3: + lbl = lbl[:, :, 0] + + lbl = Image.fromarray(lbl.astype('uint8')) + im, lbl = self.transform(im, lbl) + + return im, lbl, im_path + + def transform(self, img, lbl): + img = img.resize((512, 512)) + lbl = lbl.resize((self.img_size[0], self.img_size[1]), resample=Image.NEAREST) + lbl = torch.from_numpy(np.array(lbl)).long() + img = transforms.ToTensor()(img) + return img, lbl + + +def cross_validate(args): + + ignore_index = -1 + cp_path = args.resume + base_path = os.path.join(cp_path, "cross_validation") + if not os.path.exists(base_path): + os.mkdir(base_path) + + cps_all = glob.glob(cp_path + "/*") + cp_list = [data for data in cps_all if '.pth' in data and 'resnet' not in data] + cp_list.sort() + + ids = range(34) + + data_all = glob.glob('./DatasetGAN_data/annotation/testing_data/face_34_class/' + "/*") + images = [path for path in data_all if 'npy' not in path] + labels = [path for path in data_all if 'npy' in path] + images.sort() + labels.sort() + + + fold_num =int( len(images) / 5) + resnet_transform = torchvision.transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + input_size = 512 + pool_size = int(input_size / 2**5) + model = InsResNet50(pool_size=pool_size, pretrained=True) + desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048} + + feat_dim = 0 + for i in range(4): + feat_dim += desc_dim[5-i] + + if args.distill == True: + segmentor = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False, progress=False, num_classes=34, aux_loss=None) + elif 
args.model == "UNet": + segmentor = UNet(feat_dim, n_classes=34) + elif args.model == "fcn": + segmentor = fcn(feat_dim, n_classes=34) + + cross_mIOU = [] + + for i in range(5): + val_image = images[fold_num * i: fold_num *i + fold_num] + val_label = labels[fold_num * i: fold_num *i + fold_num] + test_image = [img for img in images if img not in val_image] + test_label =[label for label in labels if label not in val_label] + print("Val Data length,", str(len(val_image))) + print("Testing Data length,", str(len(test_image))) + + val_data = ImageLabelDataset(img_path_list=val_image, + label_path_list=val_label, + img_size=(512, 512)) + val_data = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=0) + + test_data = ImageLabelDataset(img_path_list=test_image, + label_path_list=test_label, + img_size=(512, 512)) + test_data = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=0) + + best_miou = 0 + best_val_miou = 0 + + for i_ckpt, ckpt_name in enumerate(cp_list): + if i_ckpt==0 and args.distill==False: + ckpt = torch.load(os.path.join(cp_path, 'resnet50.pth')) + state_dict = ckpt['model'] + model.load_state_dict(state_dict, strict=False) + # import pdb;pdb.set_trace() + print(args) + if args.distill==False: + model.cuda() + model.eval() + ckpt_seg = torch.load(ckpt_name) + segmentor.load_state_dict(ckpt_seg['segmentor']) + else: + ckpt_seg = torch.load(ckpt_name) + segmentor.load_state_dict(ckpt_seg['model_state_dict']) + + segmentor.cuda() + segmentor.eval() + + unions = {} + intersections = {} + for target_num in ids: + unions[target_num] = 0 + intersections[target_num] = 0 + + with torch.no_grad(): + for idxx, da, in enumerate(val_data): + + img, mask = da[0], da[1] + + if img.size(1) == 4: + img = img[:, :-1, :, :] + + img = img.cuda() + mask = mask.cuda() + input_img_tensor = [] + for b in range(img.size(0)): + input_img_tensor.append(resnet_transform(img[b])) + input_img_tensor = torch.stack(input_img_tensor) + + if args.distill == False: + feat = model(input_img_tensor, 4, True, (512,512)) + feat = feat.detach() + y_pred = segmentor(feat) + else: + y_pred = segmentor(input_img_tensor)['out'] + + y_pred = torch.log_softmax(y_pred, dim=1) + _, y_pred = torch.max(y_pred, dim=1) + y_pred = y_pred.cpu().detach().numpy() + + mask = mask.cpu().detach().numpy() + bs = y_pred.shape[0] + + curr_iou = [] + if ignore_index > 0: + y_pred = y_pred * (mask != ignore_index) + for target_num in ids: + y_pred_tmp = (y_pred == target_num).astype(int) + mask_tmp = (mask == target_num).astype(int) + + intersection = (y_pred_tmp & mask_tmp).sum() + union = (y_pred_tmp | mask_tmp).sum() + + unions[target_num] += union + intersections[target_num] += intersection + + if not union == 0: + curr_iou.append(intersection / union) + mean_ious = [] + + for target_num in ids: + mean_ious.append(intersections[target_num] / (1e-8 + unions[target_num])) + mean_iou_val = np.array(mean_ious).mean() + + if mean_iou_val > best_val_miou: + best_val_miou = mean_iou_val + unions = {} + intersections = {} + for target_num in ids: + unions[target_num] = 0 + intersections[target_num] = 0 + + with torch.no_grad(): + testing_vis = [] + for idxx, da, in enumerate(test_data): + + img, mask = da[0], da[1] + + if img.size(1) == 4: + img = img[:, :-1, :, :] + + img = img.cuda() + mask = mask.cuda() + input_img_tensor = [] + for b in range(img.size(0)): + input_img_tensor.append(resnet_transform(img[b])) + input_img_tensor = torch.stack(input_img_tensor) + + if args.distill == False: + feat = model(input_img_tensor, 
4, True, (512,512)) + feat = feat.detach() + y_pred = segmentor(feat) + else: + y_pred = segmentor(input_img_tensor)['out'] + + y_pred = torch.log_softmax(y_pred, dim=1) + _, y_pred = torch.max(y_pred, dim=1) + y_pred = y_pred.cpu().detach().numpy() + + + mask = mask.cpu().detach().numpy() + + curr_iou = [] + if ignore_index > 0: + y_pred = y_pred * (mask != ignore_index) + for target_num in ids: + y_pred_tmp = (y_pred == target_num).astype(int) + mask_tmp = (mask == target_num).astype(int) + + intersection = (y_pred_tmp & mask_tmp).sum() + union = (y_pred_tmp | mask_tmp).sum() + + unions[target_num] += union + intersections[target_num] += intersection + + if not union == 0: + curr_iou.append(intersection / union) + + + img = img.cpu().numpy() + img = img * 255. + img = np.transpose(img, (0, 2, 3, 1)).astype(np.uint8) + + test_mean_ious = [] + + for target_num in ids: + test_mean_ious.append(intersections[target_num] / (1e-8 + unions[target_num])) + + best_test_miou = np.array(test_mean_ious).mean() + + + print("Best IOU ,", str(best_test_miou), "CP: ", ckpt_name) + + cross_mIOU.append(best_test_miou) + + print(cross_mIOU) + print(" cross validation mean:" , np.mean(cross_mIOU) ) + print(" cross validation std:", np.std(cross_mIOU)) + result = {"Cross validation mean": np.mean(cross_mIOU), "Cross validation std": np.std(cross_mIOU), "Cross validation":cross_mIOU } + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--resume', type=str, default="") + parser.add_argument('--distill', action='store_true') + parser.add_argument('--model', type=str, default="UNet") + + args = parser.parse_args() + # import pdb;pdb.set_trace() + print(args) + + cross_validate(args) diff --git a/models/loss.py b/models/loss.py new file mode 100644 index 0000000..13a6c21 --- /dev/null +++ b/models/loss.py @@ -0,0 +1,40 @@ +import torch.nn.functional as F +import time +import torch + + +def regression_loss(prediction_normalized, kp_normalized, alpha=10., **kwargs): + kp = kp_normalized.to(prediction_normalized.device) + B, nA, _ = prediction_normalized.shape + return F.smooth_l1_loss(prediction_normalized * alpha, kp * alpha) + + +def selected_regression_loss(prediction_normalized, kp_normalized, visible, alpha=10., **kwargs): + kp = kp_normalized.to(prediction_normalized.device) + B, nA, _ = prediction_normalized.shape + for i in range(B): + vis = visible[i] + invis = [not v for v in vis] + kp[i][invis] = 0. + prediction_normalized[i][invis] = 0. 
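+    # The loop above zeroes invisible keypoints in both prediction and target,
+    # so they contribute nothing to the smooth-L1 loss returned below.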
+
+    return F.smooth_l1_loss(prediction_normalized * alpha, kp * alpha)
+
+def cross_entropy2d(input, target, weight=None, size_average=True):
+    n, c, h, w = input.size()
+    target = target.squeeze(1).long()
+    nt, ht, wt = target.size()
+
+    # Handle inconsistent size between input and target
+    if h != ht or w != wt:
+        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
+
+    input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
+    target = target.view(-1)
+    loss = F.cross_entropy(
+        input, target, weight=weight, ignore_index=250,
+        reduction='mean' if size_average else 'sum',
+    )
+    return loss
diff --git a/models/resnet.py b/models/resnet.py
new file mode 100644
index 0000000..a5dd010
--- /dev/null
+++ b/models/resnet.py
@@ -0,0 +1,376 @@
+import torch
+import torch.nn as nn
+import math
+import numpy as np
+import torch.utils.model_zoo as model_zoo
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class Normalize(nn.Module):
+
+    def __init__(self, power=2):
+        super(Normalize, self).__init__()
+        self.power = power
+
+    def forward(self, x):
+        norm = x.pow(self.power).sum(1, keepdim=True).pow(1.
/ self.power) + out = x.div(norm) + return out + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, low_dim=128, in_channel=3, width=1, pool_size=7): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=1, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + + self.base = int(64 * width) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, self.base, layers[0]) + self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) + self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) + self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(pool_size, stride=1) + # self.fc = nn.Linear(self.base * 8 * block.expansion, low_dim) + self.l2norm = Normalize(2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, layer=7, use_hypercol=False, output_shape=(48,48), concat=True): + if layer <= 0: + return x + x1 = self.conv1(x) + x2 = self.bn1(x1) + x3 = self.relu(x2) + x4 = self.maxpool(x3) + x5 = self.layer1(x4) + x6 = self.layer2(x5) + x7 = self.layer3(x6) + x8 = self.layer4(x7) + x9 = self.avgpool(x8) + x10 = x9.view(x9.size(0), -1) + # x11 = self.fc(x10) + # x12 = self.l2norm(x11) + + feat = {1:x4, 2:x5, 3:x6, 4:x7, 5:x8, 6:x10}#, 7:x12} + + if use_hypercol: + # hypercols = [x8, x7, x6, x5, x1] + hypercols = [x8, x7, x6, x5, x4] + hypercols = hypercols[:layer] + if concat==False: + return hypercols + for index, feat in enumerate(hypercols): + hypercols[index] = nn.functional.interpolate(feat, output_shape, mode='bilinear') + return torch.cat(hypercols, dim=1) + else: + return feat[layer] + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model + + +class InsResNet50(nn.Module): + """Encoder for instance discrimination and MoCo""" + def __init__(self, width=1, pool_size=7, pretrained=False): + super(InsResNet50, self).__init__() + self.encoder = resnet50(width=width, pool_size=pool_size, pretrained=pretrained) + self.encoder = nn.DataParallel(self.encoder) + + def forward(self, x, layer=7, use_hypercol=False, output_shape=(48,48)): + return self.encoder(x, layer, use_hypercol, output_shape) + + +class InsResNet18(nn.Module): + """Encoder for instance discrimination and MoCo""" + def __init__(self, width=1, pool_size=7): + super(InsResNet18, self).__init__() + self.encoder = resnet18(width=width, pool_size=pool_size) + self.encoder = nn.DataParallel(self.encoder) + + def forward(self, x, layer=7, use_hypercol=False, output_shape=(48,48)): + return self.encoder(x, layer, use_hypercol, output_shape) + + +class InsResNet34(nn.Module): + """Encoder for instance discrimination and MoCo""" + def __init__(self, width=1, pool_size=7): + super(InsResNet34, self).__init__() + self.encoder = resnet34(width=width, pool_size=pool_size) + self.encoder = nn.DataParallel(self.encoder) + + def forward(self, x, layer=7, use_hypercol=False, output_shape=(48,48)): + return self.encoder(x, layer, use_hypercol, output_shape) + + +class InsResNet101(nn.Module): + """Encoder for instance discrimination and MoCo""" + def __init__(self, width=1, pool_size=7): + super(InsResNet101, self).__init__() + self.encoder = resnet101(width=width, pool_size=pool_size) + self.encoder = nn.DataParallel(self.encoder) + + def forward(self, x, layer=7, use_hypercol=False, output_shape=(48,48)): + return self.encoder(x, layer, use_hypercol, output_shape) + + +class InsResNet152(nn.Module): + """Encoder for instance discrimination and MoCo""" + def __init__(self, width=1, pool_size=7): + super(InsResNet152, self).__init__() + self.encoder = resnet152(width=width, pool_size=pool_size) + self.encoder = nn.DataParallel(self.encoder) + + def forward(self, x, layer=7, use_hypercol=False, output_shape=(48,48)): + return self.encoder(x, layer, use_hypercol, output_shape) + + +class ResNetV1(nn.Module): + def __init__(self, name='resnet50'): + super(ResNetV1, self).__init__() + if name == 'resnet50': + self.l_to_ab = resnet50(in_channel=1, width=0.5) + self.ab_to_l = resnet50(in_channel=2, width=0.5) + elif name == 'resnet18': + self.l_to_ab = resnet18(in_channel=1, width=0.5) + self.ab_to_l = resnet18(in_channel=2, width=0.5) + elif name == 'resnet101': + self.l_to_ab = resnet101(in_channel=1, width=0.5) + self.ab_to_l = resnet101(in_channel=2, width=0.5) + else: + raise NotImplementedError('model {} is not implemented'.format(name)) + + def forward(self, x, layer=7): + l, ab = torch.split(x, [1, 2], dim=1) + feat_l = self.l_to_ab(l, layer) + feat_ab = self.ab_to_l(ab, layer) + return feat_l, feat_ab + + +class ResNetV2(nn.Module): + def __init__(self, name='resnet50'): + super(ResNetV2, self).__init__() + if name == 'resnet50': + self.l_to_ab = resnet50(in_channel=1, width=1) + self.ab_to_l = resnet50(in_channel=2, width=1) + elif name == 'resnet18': + self.l_to_ab = resnet18(in_channel=1, width=1) + self.ab_to_l = resnet18(in_channel=2, width=1) + elif name == 'resnet101': + self.l_to_ab = resnet101(in_channel=1, width=1) + self.ab_to_l = 
resnet101(in_channel=2, width=1) + else: + raise NotImplementedError('model {} is not implemented'.format(name)) + + def forward(self, x, layer=7): + l, ab = torch.split(x, [1, 2], dim=1) + feat_l = self.l_to_ab(l, layer) + feat_ab = self.ab_to_l(ab, layer) + return feat_l, feat_ab + + +class ResNetV3(nn.Module): + def __init__(self, name='resnet50'): + super(ResNetV3, self).__init__() + if name == 'resnet50': + self.l_to_ab = resnet50(in_channel=1, width=2) + self.ab_to_l = resnet50(in_channel=2, width=2) + elif name == 'resnet18': + self.l_to_ab = resnet18(in_channel=1, width=2) + self.ab_to_l = resnet18(in_channel=2, width=2) + elif name == 'resnet101': + self.l_to_ab = resnet101(in_channel=1, width=2) + self.ab_to_l = resnet101(in_channel=2, width=2) + else: + raise NotImplementedError('model {} is not implemented'.format(name)) + + def forward(self, x, layer=7): + l, ab = torch.split(x, [1, 2], dim=1) + feat_l = self.l_to_ab(l, layer) + feat_ab = self.ab_to_l(ab, layer) + return feat_l, feat_ab + + +class MyResNetsCMC(nn.Module): + def __init__(self, name='resnet50v1'): + super(MyResNetsCMC, self).__init__() + if name.endswith('v1'): + self.encoder = ResNetV1(name[:-2]) + elif name.endswith('v2'): + self.encoder = ResNetV2(name[:-2]) + elif name.endswith('v3'): + self.encoder = ResNetV3(name[:-2]) + else: + raise NotImplementedError('model not support: {}'.format(name)) + + self.encoder = nn.DataParallel(self.encoder) + + def forward(self, x, layer=7): + return self.encoder(x, layer) diff --git a/models/segmentor.py b/models/segmentor.py new file mode 100644 index 0000000..6939413 --- /dev/null +++ b/models/segmentor.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .unet_parts import * + + +class UpBlock(nn.Module): + def __init__(self, inplanes, planes, upsample=False): + super(UpBlock, self).__init__() + self.conv = nn.Conv2d(inplanes, planes, 1, 1) + self.bn = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.will_ups = upsample + + def forward(self, x): + if self.will_ups: + x = nn.functional.upsample(x, scale_factor=2, mode='bilinear', align_corners=True) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + + +class UNet(nn.Module): + def __init__(self, n_channels, n_classes=34, bilinear=True): + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + + self.inc = DoubleConv(n_channels, int(1024)) + self.down1 = Down(int(1024), int(256)) + self.down2 = Down(int(256), int(256)) + self.down3 = Down(int(256), int(512/2)) + self.down4 = Down(int(512/2), int(512)) + factor = 2 if bilinear else 1 + self.down5 = Down(int(512), int(1024) // factor) + self.up0 = Up(int(1024), int(512) // factor, bilinear) + self.up1 = Up(int(512), int(512) // factor, bilinear) + self.up2 = Up(int(512), int(256) // factor, bilinear) + self.up3 = Up(int(384), int(512) // factor, bilinear) + self.up4 = Up(int(1280), int(256), bilinear) + self.outc = OutConv(int(256), n_classes) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x6 = self.down5(x5) + x = self.up0(x6, x5) + + x = self.up1(x, x4) + + x = self.up2(x, x3) + x = self.up3(x, x2) + + x = self.up4(x, x1) + logits = self.outc(x) + return logits + + def init_weights(self, init_type='kaiming', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + 
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + elif classname.find('BatchNorm2d') != -1: + nn.init.normal_(m.weight.data, 1.0, gain) + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + + +class fcn(nn.Module): + def __init__(self, descriptor_dimension, n_classes=34): + super().__init__() + self.decoder = nn.Sequential(nn.Conv2d( + in_channels=descriptor_dimension, + out_channels=1024, + kernel_size=1, + padding=0, + bias=False), + nn.BatchNorm2d(1024), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=1024, + out_channels=256, + kernel_size=1, + padding=0, + bias=False), + nn.BatchNorm2d(256), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=256, + out_channels=n_classes, + kernel_size=1, + padding=0, + bias=False)) + + + def forward(self, input): + + out = self.decoder(input) + return out + + def init_weights(self, init_type='kaiming', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + elif classname.find('BatchNorm2d') != -1: + nn.init.normal_(m.weight.data, 1.0, gain) + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + diff --git a/models/unet_parts.py b/models/unet_parts.py new file mode 100644 index 0000000..c9a3921 --- /dev/null +++ b/models/unet_parts.py @@ -0,0 +1,80 @@ +""" Parts of the U-Net model """ + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + # nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + # nn.BatchNorm2d(out_channels), + # nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels) + 
) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=False): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + # self.pixconv = DoubleConv(in_channels, in_channels*2) + # self.up = nn.PixelShuffle(3) + self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + # x1 = self.pixconv(x1) + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..891b1c6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +matplotlib==3.3.4 +numpy==1.18.1 +opencv-python==4.5.3.56 +pandas==1.1.5 +Pillow==8.3.2 +pyparsing==2.4.7 +python-dateutil==2.8.2 +scikit-image==0.17.2 +scipy==1.1.0 +tensorboard-logger==0.1.0 +torch==1.7.0 +torchvision==0.8.0 \ No newline at end of file diff --git a/train_deeplab_contrast.py b/train_deeplab_contrast.py new file mode 100644 index 0000000..4274209 --- /dev/null +++ b/train_deeplab_contrast.py @@ -0,0 +1,185 @@ +""" +Copyright (C) 2021 NVIDIA Corporation. All rights reserved. +Licensed under The MIT License (MIT) + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+""" + +import sys +sys.path.append('../') +import os +os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 +import argparse +import gc +import os +import torch +import torchvision +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader +import glob +from torchvision import transforms +from PIL import Image +import numpy as np +# from utils.data_util import * +import json +import pickle + +class ImageLabelDataset(Dataset): + def __init__( + self, + img_path_list, + label_path_list, + img_size=(128, 128), + ): + self.img_path_list = img_path_list + self.label_path_list = label_path_list + self.img_size = img_size + + def __len__(self): + return len(self.img_path_list) + + def __getitem__(self, index): + im_path = self.img_path_list[index] + lbl_path = self.label_path_list[index] + im = Image.open(im_path) + try: + lbl = np.load(lbl_path) + except: + lbl = np.array(Image.open(lbl_path)) + if len(lbl.shape) == 3: + lbl = lbl[:, :, 0] + + lbl = Image.fromarray(lbl.astype('uint8')) + im, lbl = self.transform(im, lbl) + + return im, lbl, im_path + + def transform(self, img, lbl): + img = img.resize((self.img_size[0], self.img_size[1])) + lbl = lbl.resize((self.img_size[0], self.img_size[1]), resample=Image.NEAREST) + lbl = torch.from_numpy(np.array(lbl)).long() + img = transforms.ToTensor()(img) + return img, lbl + + +def main(data_path, exp_dir, image_size, resume, num_classes): + + base_path = os.path.join(exp_dir, "deeplab_class_%d_checkpoint" %(num_classes)) + if not os.path.exists(base_path): + os.mkdir(base_path) + print("Model dir,", base_path) + + stylegan_images = [] + stylegan_labels = [] + + img_path_base = './CelebAMask-HQ/train_img/' + lbl_path_base = data_path + + for i in range(len([name for name in os.listdir(img_path_base) if os.path.isfile(os.path.join(img_path_base, name))])): + img_path = os.path.join(img_path_base, str(i)+'.jpg') + label_path = os.path.join(lbl_path_base, str(i)+'_label.png') + stylegan_images.append(img_path) + stylegan_labels.append(label_path) + if i==10000: + break + + assert len(stylegan_images) == len(stylegan_labels) + print( "Train data length,", str(len(stylegan_labels))) + + train_data = ImageLabelDataset(img_path_list=stylegan_images, + label_path_list=stylegan_labels, + img_size=(image_size, image_size)) + + train_data = DataLoader(train_data, batch_size=8, shuffle=True, num_workers=16) + classifier = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False, progress=False, + num_classes=num_classes, aux_loss=None) + + if resume != "": + checkpoint = torch.load(resume) + classifier.load_state_dict(checkpoint['model_state_dict']) + + classifier.cuda() + classifier.train() + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(classifier.parameters(), lr=0.001) + + resnet_transform = torchvision.transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + + for epoch in range(5): + for i, da, in enumerate(train_data): + if da[0].shape[0] != 8: + continue + if i % 10 == 0: + gc.collect() + + classifier.train() + + optimizer.zero_grad() + img, mask = da[0], da[1] + + img = img.cuda() + mask = mask.cuda() + + input_img_tensor = [] + for b in range(img.size(0)): + if img.size(1) == 4: + input_img_tensor.append(resnet_transform(img[b][:-1,:,:])) + else: + input_img_tensor.append(resnet_transform(img[b])) + + input_img_tensor = torch.stack(input_img_tensor) + + y_pred = classifier(input_img_tensor)['out'] + loss = criterion(y_pred, mask) + loss.backward() + 
+            optimizer.step()
+
+            if i % 10 == 0:
+                print('epoch', epoch, 'iteration', i, 'loss', loss.item())
+
+        model_path = os.path.join(base_path, 'deeplab_epoch_' + str(epoch) + '.pth')
+
+        print('Save to:', model_path)
+        torch.save({'model_state_dict': classifier.state_dict()},
+                   model_path)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_path', type=str)
+    parser.add_argument('--model_path', type=str, default="./512_faces_celeba_distilled")  # path to store the models
+    parser.add_argument('--image_size', type=int, default=512, help='image size')  # image preprocessing
+    parser.add_argument('--resume', type=str, default="")
+    parser.add_argument('--num_classes', type=int, default=34)
+
+    args = parser.parse_args()
+
+    path = args.model_path
+    if not os.path.exists(path):
+        os.makedirs(path)
+        print('Experiment folder created at: %s' % (path))
+
+    main(args.data_path, args.model_path, args.image_size, args.resume, args.num_classes)
diff --git a/utils/util.py b/utils/util.py
new file mode 100644
index 0000000..bab79c1
--- /dev/null
+++ b/utils/util.py
@@ -0,0 +1,256 @@
+import os
+import torch
+import json
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import OrderedDict
+import math
+
+import sys
+from pathlib import Path
+import matplotlib
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt  # NOQA
+
+#####################################################
+# From DVE: https://github.com/jamt9000/DVE
+#####################################################
+
+def label_colormap(x):
+    colors = np.array([
+        [0, 0, 0],
+        [128, 0, 0],
+        [0, 128, 0],
+        [128, 128, 0],
+        [0, 0, 128],
+        [128, 0, 128],
+        [0, 128, 128],
+        [128, 128, 128],
+        [64, 0, 0],
+        [192, 0, 0],
+        [64, 128, 0],
+        [192, 128, 0],
+    ])
+    ndim = len(x.shape)
+    num_classes = 11
+    if isinstance(x, np.ndarray):
+        x = torch.from_numpy(x)
+
+    r = x.clone().float()
+    g = x.clone().float()
+    b = x.clone().float()
+    if ndim == 2:
+        rgb = torch.zeros((x.shape[0], x.shape[1], 3))
+    else:
+        rgb = torch.zeros((x.shape[0], 3, x.shape[2], x.shape[3]))
+    colors = torch.from_numpy(colors)
+    label_colours = dict(zip(range(num_classes), colors))
+
+    for l in range(0, num_classes):
+        r[x == l] = label_colours[l][0]
+        g[x == l] = label_colours[l][1]
+        b[x == l] = label_colours[l][2]
+    if ndim == 2:
+        rgb[:, :, 0] = r / 255.0
+        rgb[:, :, 1] = g / 255.0
+        rgb[:, :, 2] = b / 255.0
+    elif ndim == 4:
+        rgb[:, 0, None] = r / 255.0
+        rgb[:, 1, None] = g / 255.0
+        rgb[:, 2, None] = b / 255.0
+    else:
+        raise ValueError("label_colormap expects a 2D or 4D input, got %dD" % ndim)
+    return rgb
+
+
+def ensure_dir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def clean_state_dict(state_dict):
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k[:7] == 'module.':
+            k = k[7:]  # remove `module.`
+        new_state_dict[k] = v
+    return new_state_dict
+
+
+def get_instance(module, name, config, *args, **kwargs):
+    return getattr(module, config[name]['type'])(*args, **config[name]['args'],
+                                                 **kwargs)
+
+
+def coll(batch):
+    b = torch.utils.data.dataloader.default_collate(batch)
+    # Flatten to be 4D
+    return [
+        bi.reshape((-1,) + bi.shape[-3:]) if isinstance(bi, torch.Tensor) else bi
+        for bi in b
+    ]
+
+
+def dict_coll(batch):
+    cb = torch.utils.data.dataloader.default_collate(batch)
+    cb["data"] = cb["data"].reshape((-1,) + cb["data"].shape[-3:])  # Flatten to be 4D
+    return cb
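+
+
+# Illustrative sketch (not called by the training scripts; shapes below are an
+# assumed example): `coll` and `dict_coll` fold a leading copy/crop axis into
+# the batch axis. For a dataset item of {"data": torch.rand(2, 3, 64, 64)} and
+# a DataLoader with batch_size=4 and collate_fn=dict_coll, default_collate
+# produces cb["data"] of shape (4, 2, 3, 64, 64), and the reshape above
+# flattens it to (8, 3, 64, 64).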
+
+
+class NoGradWrapper(nn.Module):
+    def __init__(self, wrapped):
+        super(NoGradWrapper, self).__init__()
+        self.wrapped_module = wrapped
+
+    def forward(self, *args, **kwargs):
+        with torch.no_grad():
+            return self.wrapped_module.forward(*args, **kwargs)
+
+
+class Up(nn.Module):
+    def forward(self, x):
+        with torch.no_grad():
+            return [F.interpolate(x[0], scale_factor=2, mode='bilinear', align_corners=False)]
+
+
+def read_json(fname):
+    with fname.open('rt') as handle:
+        return json.load(handle, object_hook=OrderedDict)
+
+
+def write_json(content, fname):
+    with fname.open('wt') as handle:
+        json.dump(content, handle, indent=4, sort_keys=False)
+
+
+def pad_and_crop(im, rr):
+    """Return im[rr[0]:rr[1], rr[2]:rr[3]].
+
+    Pads if necessary to allow out-of-bounds indexing.
+    """
+
+    meanval = np.array(np.dstack((0, 0, 0)), dtype=im.dtype)
+
+    if rr[0] < 0:
+        top = -rr[0]
+        P = np.tile(meanval, [top, im.shape[1], 1])
+        im = np.vstack([P, im])
+        rr[0] = rr[0] + top
+        rr[1] = rr[1] + top
+
+    if rr[2] < 0:
+        left = -rr[2]
+        P = np.tile(meanval, [im.shape[0], left, 1])
+        im = np.hstack([P, im])
+        rr[2] = rr[2] + left
+        rr[3] = rr[3] + left
+
+    if rr[1] > im.shape[0]:
+        bottom = rr[1] - im.shape[0]
+        P = np.tile(meanval, [bottom, im.shape[1], 1])
+        im = np.vstack([im, P])
+
+    if rr[3] > im.shape[1]:
+        right = rr[3] - im.shape[1]
+        P = np.tile(meanval, [im.shape[0], right, 1])
+        im = np.hstack([im, P])
+
+    im = im[rr[0]:rr[1], rr[2]:rr[3]]
+
+    return im
+
+#####################################################
+# From CMC: https://github.com/HobbitLong/CMC
+#####################################################
+
+def adjust_learning_rate(epoch, opt, optimizer):
+    """Anneal the learning rate: cosine schedule if opt.cosine is set,
+    otherwise step decay by opt.lr_decay_rate at each epoch in opt.lr_decay_epochs."""
+    if opt.cosine:
+        new_lr = opt.learning_rate * 0.5 * (1. + math.cos(math.pi * epoch / opt.epochs))
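+        # standard cosine annealing: lr falls from opt.learning_rate at epoch 0
+        # to 0 at epoch == opt.epochs, following 0.5 * (1 + cos(pi * epoch / opt.epochs))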
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = new_lr
+    else:
+        steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
+        if steps > 0:
+            new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = new_lr
+
+
+def reset_learning_rate(optimizer, learning_rate):
+    """Reset the learning rate of every parameter group to the given value."""
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = learning_rate
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+class Tee(object):
+    def __init__(self, name, mode):
+        self.file = open(name, mode)
+        self.stdout = sys.stdout
+        sys.stdout = self
+
+    def __del__(self):
+        sys.stdout = self.stdout
+        self.file.close()
+
+    def write(self, data):
+        self.file.write(data)
+        self.stdout.write(data)
+
+    def flush(self):
+        self.file.flush()
+
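+
+# Usage sketch (illustrative only; model and val_loader are hypothetical names,
+# not objects defined in this repo): AverageMeter and accuracy are typically
+# combined in an evaluation loop like
+#
+#   top1 = AverageMeter()
+#   for images, target in val_loader:
+#       output = model(images)
+#       acc1, = accuracy(output, target, topk=(1,))
+#       top1.update(acc1.item(), images.size(0))
+#   print('val top-1 acc: %.2f' % top1.avg)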