@@ -0,0 +1,67 @@
# Code for GANorCON

This is the code for the contrastive few-shot part segmentation method proposed in "GANorCON: Are Generative Models Useful for Few-shot Segmentation?". We provide the pretrained MoCoV2 model for convenience.

## Preparation

```
pip3 install -r requirements.txt
```

-> Download MoCoV2_512_CelebA.pth from https://drive.google.com/file/d/1n0iyFuZ20s_DsIAorvmtVLIdHIG_65N0/ and place it inside this folder
-> Place the [DatasetGAN released data](https://drive.google.com/drive/folders/1PSS0uOusN3dV84YLT9Gds1ZSugjpMz7E) at ./DatasetGAN_data inside this folder
-> Download CelebAMask from [here](https://drive.google.com/open?id=1badu11NqxGf6qM3PTTooQDJvQbejgbTv) and place it inside this folder
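
After these steps, the repository root should look roughly like the sketch below; the mask folder name is an assumption and depends on how the downloaded archive unpacks:

```
GANorCON/
├── MoCoV2_512_CelebA.pth
├── DatasetGAN_data/
└── CelebAMask-HQ/
```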

## Training - Few Shot Segmentation

```bash
python3 eval_face_seg.py --model resnet50 --segmodel fcn --layer 4 --trained_model_path MoCoV2_512_CelebA.pth --learning_rate 0.001 --weight_decay 0.0005 --adam --epochs 800 --cosine --batch_size 1 --log_path ./log.txt --model_name face_segmentor --model_path ./512_faces_celeba --image_size 512 --use_hypercol
```

The option --segmodel can be set to either "fcn" or "UNet" for the two variants described in the paper, and --model_path can be set to the desired location for saving checkpoints.
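
For example, the UNet variant is trained with the same flags, changing only --segmodel (the checkpoint directory below is illustrative, not prescribed by the repo):

```bash
python3 eval_face_seg.py --model resnet50 --segmodel UNet --layer 4 --trained_model_path MoCoV2_512_CelebA.pth --learning_rate 0.001 --weight_decay 0.0005 --adam --epochs 800 --cosine --batch_size 1 --log_path ./log.txt --model_name face_segmentor --model_path ./512_faces_celeba_unet --image_size 512 --use_hypercol
```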

## Generate data for distillation

```bash
python3 eval_face_seg.py --model resnet50 --segmodel fcn --layer 4 --trained_model_path MoCoV2_512_CelebA.pth --image_size 512 --use_hypercol --generate --gen_path ./labels_fordeeplab/ --resume ./512_faces_celeba/face_segmentor/resnet50.pth
```

Pass the path to the trained model (resnet50.pth) in --resume. The option --gen_path sets where the labels predicted with the --resume checkpoint will be stored.

## Distillation

```bash
python3 train_deeplab_contrast.py --data_path ./labels_fordeeplab/ --model_path ./512_faces_celeba_distilled --image_size 512 --num_classes 34
```

Specify the path to the labels generated in the previous step in --data_path and the path for saving the model in --model_path. The value of --num_classes matches the face part classes of the DatasetGAN annotations.

## Testing

For the model from few-shot segmentation training:

```bash
python3 gen_score_seg.py --resume ./512_faces_celeba/Nvidia_segmentor/ --model fcn
```

--model can be changed to UNet if a UNet-based segmentor was used for training.
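
For example, for a UNet-based segmentor (the --resume path is illustrative; point it at the folder holding the UNet checkpoints):

```bash
python3 gen_score_seg.py --resume ./512_faces_celeba_unet/Nvidia_segmentor/ --model UNet
```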

For the model from distillation:

```bash
python3 gen_score_seg.py --resume ./512_faces_celeba_distilled/deeplab_class_34_checkpoint/ --distill
```

In both cases, pass the folder where all checkpoints are stored to --resume.

@@ -0,0 +1,126 @@
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import os
import numpy as np
import random


class CelebAMaskHQ():
    def __init__(self, img_path, label_path, transform_img, transform_label, mode):
        self.img_path = img_path
        self.label_path = label_path
        self.transform_img = transform_img
        self.transform_label = transform_label
        self.train_dataset = []
        self.test_dataset = []
        self.mode = mode
        self.preprocess_nvidia()

        if mode == True:
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess_nvidia(self):
        # Images and masks come in pairs, so the sample count is half the
        # number of files in img_path.
        num_files = len([name for name in os.listdir(self.img_path)
                         if os.path.isfile(os.path.join(self.img_path, name))])
        if self.mode == True:
            for i in range(num_files // 2):
                img_path = os.path.join(self.img_path, 'image_' + str(i) + '.jpg')
                label_path = os.path.join(self.label_path, 'image_mask' + str(i) + '.npy')
                self.train_dataset.append([img_path, label_path])
        else:
            for i in range(num_files // 2):
                img_path = os.path.join(self.img_path, 'face_' + str(i) + '.png')
                label_path = os.path.join(self.label_path, 'mask_' + str(i) + '.npy')
                self.test_dataset.append([img_path, label_path])

        print('Finished preprocessing the Nvidia dataset...')

    def __getitem__(self, index):
        dataset = self.train_dataset if self.mode == True else self.test_dataset
        img_path, label_path = dataset[index]
        image = Image.open(img_path)
        label = np.load(label_path)

        label = Image.fromarray(label)
        image = image.resize((512, 512))
        label = label.resize((512, 512), resample=Image.NEAREST)

        # Training-time augmentations: random resized crop, color jitter, and
        # horizontal flip, each applied independently with probability 0.5.
        crop = random.random() < 0.5
        if crop and self.mode == True:
            i, j, h, w = transforms.RandomResizedCrop.get_params(
                image, scale=(0.6, 1.0), ratio=(0.7, 1.3))

            image = TF.crop(image, i, j, h, w)
            label = TF.crop(label, i, j, h, w)

            image = image.resize((512, 512))
            label = label.resize((512, 512), resample=Image.NEAREST)

        jitter = random.random() < 0.5
        if jitter and self.mode == True:
            image = transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                           saturation=0.2, hue=0)(image)

        hflip = random.random() < 0.5
        if hflip and self.mode == True:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)
        # np.long is a deprecated alias; int64 is the dtype torch expects for labels.
        label = np.array(label, dtype=np.int64)

        return self.transform_img(image), self.transform_label(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images


class Data_Loader():
    def __init__(self, img_path, label_path, image_size, batch_size, mode):
        self.img_path = img_path
        self.label_path = label_path
        self.imsize = image_size
        self.batch = batch_size
        self.mode = mode

    def transform_img(self, resize, totensor, normalize, centercrop):
        options = []
        if centercrop:
            options.append(transforms.CenterCrop(160))
        if resize:
            options.append(transforms.Resize((self.imsize, self.imsize)))
        if totensor:
            options.append(transforms.ToTensor())
        if normalize:
            # ImageNet statistics, matching the ResNet backbone.
            options.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                std=[0.229, 0.224, 0.225]))
        return transforms.Compose(options)

    def transform_label(self, resize, totensor, normalize, centercrop):
        options = []
        if centercrop:
            options.append(transforms.CenterCrop(160))
        if resize:
            options.append(transforms.Resize((self.imsize, self.imsize)))
        if totensor:
            options.append(transforms.ToTensor())
        if normalize:
            # Unused in practice: loader() passes normalize=False, and a zero
            # std would divide by zero if this branch were ever taken.
            options.append(transforms.Normalize((0, 0, 0), (0, 0, 0)))
        return transforms.Compose(options)

    def loader(self):
        transform_img = self.transform_img(True, True, True, False)
        transform_label = self.transform_label(False, True, False, False)
        dataset = CelebAMaskHQ(self.img_path, self.label_path,
                               transform_img, transform_label, self.mode)

        print(len(dataset))

        loader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=self.batch,
                                             shuffle=False,  # originally self.mode == True
                                             num_workers=0,
                                             drop_last=False,
                                             pin_memory=True)
        return loader
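
# A minimal usage sketch, assuming hypothetical directories that follow the
# image_<i>.jpg / image_mask<i>.npy naming scheme preprocess_nvidia() expects.
if __name__ == '__main__':
    data = Data_Loader(img_path='./DatasetGAN_data/images',    # hypothetical path
                       label_path='./DatasetGAN_data/images',  # hypothetical path
                       image_size=512, batch_size=1, mode=True)
    train_loader = data.loader()
    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)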

@@ -0,0 +1,30 @@
from PIL import Image
import torchvision
from torch.utils.data import Dataset

# ImageNet normalization, matching the ResNet backbone.
resnet_transform = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])


class ImageLabelDataset(Dataset):
    def __init__(self, img_path_list, img_size=(128, 128)):
        self.img_path_list = img_path_list
        self.img_size = img_size

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, index):
        im_path = self.img_path_list[index]
        im = Image.open(im_path)
        im = self.transform(im)
        # The path is returned alongside the tensor, which is useful for
        # matching predictions back to their source files.
        return im, im_path

    def transform(self, img):
        img = img.resize((self.img_size[0], self.img_size[1]))
        img = torchvision.transforms.ToTensor()(img)
        img = resnet_transform(img)
        return img
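
# A minimal usage sketch with hypothetical image paths; any RGB images on
# disk would do, since the dataset only resizes, tensorizes, and normalizes.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    paths = ['./imgs/face_0.png', './imgs/face_1.png']  # hypothetical files
    dataset = ImageLabelDataset(paths, img_size=(512, 512))
    ims, im_paths = next(iter(DataLoader(dataset, batch_size=2)))
    print(ims.shape, list(im_paths))  # expect [2, 3, 512, 512] for RGB inputs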