diff --git a/README.md b/README.md
index 1830d48..71861f8 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,111 @@
-# Graphonomy-Universal-Human-Parsing-via-Graph-Transfer-Learning
-coming soon.
+# Graphonomy: Universal Human Parsing via Graph Transfer Learning
+
+This repository contains the code for the paper:
+
+[**Graphonomy: Universal Human Parsing via Graph Transfer Learning**](https://arxiv.org/abs/1904.04536)
+by Ke Gong, Yiming Gao, Xiaodan Liang, Xiaohui Shen, Meng Wang, and Liang Lin.
+
+# Environment and installation
++ PyTorch = 0.4.0
++ torchvision
++ scipy
++ tensorboardX
++ numpy
++ opencv-python
++ matplotlib
++ networkx
+
+ You can install the packages above with `pip install -r requirements.txt`.
+
+# Getting Started
+### Data Preparation
++ Download the human parsing datasets, prepare the images, and store them in `/data/datasets/dataset_name/`.
+We recommend symlinking the dataset paths into `data/datasets/` as follows:
+
+```
+# symlink the Pascal-Person-Part dataset, for example
+ln -s /path_to_Pascal_Person_Part/* data/datasets/pascal/
+```
++ The file structure should look like:
+```
+/Graphonomy
+  /data
+    /datasets
+      /pascal
+        /JPEGImages
+        /list
+        /SegmentationPart
+      /CIHP_4w
+        /Images
+        /lists
+      ...
+```
+
+### Inference
+We provide a simple script that produces visualization results on the CIHP dataset using the [trained](https://drive.google.com/file/d/1O9YD4kHgs3w2DUcWxtHiEFyWjCBeS_Vc/view?usp=sharing) model, as follows:
+```shell
+# Example of inference
+python exp/inference/inference.py \
+--loadmodel /path_to_inference_model \
+--img_path ./img/messi.jpg \
+--output_path ./img/ \
+--output_name /output_file_name
+```
+
+### Training
+#### Transfer learning
+1. Download the Pascal pretrained model (available soon).
+2. Run `sh train_transfer_cihp.sh`.
+3. The results and models are saved in `exp/transfer/run/`.
+4. The evaluation and visualization script is `eval_cihp.sh`; you only need to change the `--loadmodel` argument before running it.
+
+#### Universal training
+1. Download the [pretrained](https://drive.google.com/file/d/18WiffKnxaJo50sCC9zroNyHjcnTxGCbk/view?usp=sharing) model and store it in `/data/pretrained_model/`.
+2. Run `sh train_universal.sh`.
+3. The results and models are saved in `exp/universal/run/`.
+
+### Testing
+To evaluate a pre-trained model on the PASCAL-Person-Part or CIHP val/test set,
+simply run the corresponding script (`sh eval_pascal.sh` or `sh eval_cihp.sh`) and specify the model to load via `--loadmodel`.
+We also provide the final models, which you can download and store in `/data/pretrained_model/`.
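+
+If you already have predictions saved on disk, you can also compute the mIoU directly with the `eval_` helper in `exp/test/test_from_disk.py`. Below is a minimal, hedged sketch for CIHP; it assumes you run it from the repository root and that predictions were written by the CIHP eval script with its default `--output_path ./results/` (adjust the paths to your setup):
+```python
+# Hedged usage sketch of exp/test/test_from_disk.eval_ for CIHP (20 classes).
+import sys
+sys.path.append('./exp/test')   # make test_from_disk importable from the repo root
+from test_from_disk import eval_
+
+eval_(pred_path='./results/cihp_output/',
+      gt_path='./data/datasets/CIHP_4w/Category_ids/',
+      classes=20,
+      txt_file='./data/datasets/CIHP_4w/lists/test_id.txt')
+```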
+
+### Models
+**Pascal-Person-Part trained model**
+
+|Model|Google Cloud|Baidu Yun|
+|--------|--------------|-----------|
+|Graphonomy(CIHP)| [Download](https://drive.google.com/file/d/1cwEhlYEzC7jIShENNLnbmcBR0SNlZDE6/view?usp=sharing)| Available soon|
+
+**CIHP trained model**
+
+|Model|Google Cloud|Baidu Yun|
+|--------|--------------|-----------|
+|Graphonomy(PASCAL)| [Download](https://drive.google.com/file/d/1O9YD4kHgs3w2DUcWxtHiEFyWjCBeS_Vc/view?usp=sharing)| Available soon|
+
+**Universal trained model**
+
+|Model|Google Cloud|Baidu Yun|
+|--------|--------------|-----------|
+|Universal|Available soon|Available soon|
+
+### Todo:
+- [ ] release pretrained and trained models
+- [ ] update universal eval code & script
+
+# Citation
+
+```
+@inproceedings{Gong2019Graphonomy,
+author = {Ke Gong and Yiming Gao and Xiaodan Liang and Xiaohui Shen and Meng Wang and Liang Lin},
+title = {Graphonomy: Universal Human Parsing via Graph Transfer Learning},
+booktitle = {CVPR},
+year = {2019},
+}
+```
+
+# Contact
+If you have any questions about this repo, please feel free to contact
+[gaoym9@mail2.sysu.edu.cn](mailto:gaoym9@mail2.sysu.edu.cn).
+
diff --git a/dataloaders/__init__.py b/dataloaders/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dataloaders/atr.py b/dataloaders/atr.py
new file mode 100644
index 0000000..525ac08
--- /dev/null
+++ b/dataloaders/atr.py
@@ -0,0 +1,109 @@
+from __future__ import print_function, division
+import os
+from PIL import Image
+from torch.utils.data import Dataset
+from .mypath_atr import Path
+import random
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+class VOCSegmentation(Dataset):
+    """
+    ATR dataset
+    """
+
+    def __init__(self,
+                 base_dir=Path.db_root_dir('atr'),
+                 split='train',
+                 transform=None,
+                 flip=False,
+                 ):
+        """
+        :param base_dir: path to ATR dataset directory
+        :param split: train/val
+        :param transform: transform to apply
+        """
+        super(VOCSegmentation).__init__()
+        self._flip_flag = flip
+
+        self._base_dir = base_dir
+        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
+        self._cat_dir = os.path.join(self._base_dir, 'SegmentationClassAug')
+        # pre-flipped annotations (left/right labels swapped), used for flip augmentation
+        self._flip_dir = os.path.join(self._base_dir, 'SegmentationClassAug_rev')
+
+        if isinstance(split, str):
+            self.split = [split]
+        else:
+            split.sort()
+            self.split = split
+
+        self.transform = transform
+
+        _splits_dir = os.path.join(self._base_dir, 'list')
+
+        self.im_ids = []
+        self.images = []
+        self.categories = []
+        self.flip_categories = []
+
+        for splt in self.split:
+            with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
+                lines = f.read().splitlines()
+
+            for ii, line in enumerate(lines):
+                _image = os.path.join(self._image_dir, line + '.jpg')
+                _cat = os.path.join(self._cat_dir, line + '.png')
+                _flip = os.path.join(self._flip_dir, line + '.png')
+                # print(self._image_dir,_image)
+                assert os.path.isfile(_image)
+                # print(_cat)
+                assert os.path.isfile(_cat)
+                assert os.path.isfile(_flip)
+                self.im_ids.append(line)
+                self.images.append(_image)
+                self.categories.append(_cat)
+                self.flip_categories.append(_flip)
+
+        assert (len(self.images) == len(self.categories))
+        assert len(self.flip_categories) == len(self.categories)
+
+        # Display stats
+        print('Number of images in {}: {:d}'.format(split, len(self.images)))
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        _img, _target = self._make_img_gt_point_pair(index)
+        sample = {'image': _img, 'label': _target}
+
+        if self.transform
is not None: + sample = self.transform(sample) + + return sample + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32) + # _target = np.array(Image.open(self.categories[index])).astype(np.float32) + + _img = Image.open(self.images[index]).convert('RGB') # return is RGB pic + if self._flip_flag: + if random.random() < 0.5: + _target = Image.open(self.flip_categories[index]) + _img = _img.transpose(Image.FLIP_LEFT_RIGHT) + else: + _target = Image.open(self.categories[index]) + else: + _target = Image.open(self.categories[index]) + + return _img, _target + + def __str__(self): + return 'ATR(split=' + str(self.split) + ')' + + + diff --git a/dataloaders/cihp.py b/dataloaders/cihp.py new file mode 100644 index 0000000..cc1722f --- /dev/null +++ b/dataloaders/cihp.py @@ -0,0 +1,107 @@ +from __future__ import print_function, division +import os +from PIL import Image +from torch.utils.data import Dataset +from .mypath_cihp import Path +import random + +class VOCSegmentation(Dataset): + """ + CIHP dataset + """ + + def __init__(self, + base_dir=Path.db_root_dir('cihp'), + split='train', + transform=None, + flip=False, + ): + """ + :param base_dir: path to CIHP dataset directory + :param split: train/val/test + :param transform: transform to apply + """ + super(VOCSegmentation).__init__() + self._flip_flag = flip + + self._base_dir = base_dir + self._image_dir = os.path.join(self._base_dir, 'Images') + self._cat_dir = os.path.join(self._base_dir, 'Category_ids') + self._flip_dir = os.path.join(self._base_dir,'Category_rev_ids') + + if isinstance(split, str): + self.split = [split] + else: + split.sort() + self.split = split + + self.transform = transform + + _splits_dir = os.path.join(self._base_dir, 'lists') + + self.im_ids = [] + self.images = [] + self.categories = [] + self.flip_categories = [] + + for splt in self.split: + with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f: + lines = f.read().splitlines() + + for ii, line in enumerate(lines): + + _image = os.path.join(self._image_dir, line+'.jpg' ) + _cat = os.path.join(self._cat_dir, line +'.png') + _flip = os.path.join(self._flip_dir,line + '.png') + # print(self._image_dir,_image) + assert os.path.isfile(_image) + # print(_cat) + assert os.path.isfile(_cat) + assert os.path.isfile(_flip) + self.im_ids.append(line) + self.images.append(_image) + self.categories.append(_cat) + self.flip_categories.append(_flip) + + + assert (len(self.images) == len(self.categories)) + assert len(self.flip_categories) == len(self.categories) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + + def __getitem__(self, index): + _img, _target= self._make_img_gt_point_pair(index) + sample = {'image': _img, 'label': _target} + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32) + # _target = np.array(Image.open(self.categories[index])).astype(np.float32) + + _img = Image.open(self.images[index]).convert('RGB') # return is RGB pic + if self._flip_flag: + if random.random() < 0.5: + _target = Image.open(self.flip_categories[index]) + _img = _img.transpose(Image.FLIP_LEFT_RIGHT) + else: + _target = Image.open(self.categories[index]) + else: + _target = 
Image.open(self.categories[index]) + + return _img, _target + + def __str__(self): + return 'CIHP(split=' + str(self.split) + ')' + + + diff --git a/dataloaders/cihp_pascal_atr.py b/dataloaders/cihp_pascal_atr.py new file mode 100644 index 0000000..bf35b74 --- /dev/null +++ b/dataloaders/cihp_pascal_atr.py @@ -0,0 +1,219 @@ +from __future__ import print_function, division +import os +from PIL import Image +import numpy as np +from torch.utils.data import Dataset +from .mypath_cihp import Path +from .mypath_pascal import Path as PP +from .mypath_atr import Path as PA +import random +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +class VOCSegmentation(Dataset): + """ + Pascal dataset + """ + + def __init__(self, + cihp_dir=Path.db_root_dir('cihp'), + split='train', + transform=None, + flip=False, + pascal_dir = PP.db_root_dir('pascal'), + atr_dir = PA.db_root_dir('atr'), + ): + """ + :param cihp_dir: path to CIHP dataset directory + :param pascal_dir: path to PASCAL dataset directory + :param atr_dir: path to ATR dataset directory + :param split: train/val + :param transform: transform to apply + """ + super(VOCSegmentation).__init__() + ## for cihp + self._flip_flag = flip + self._base_dir = cihp_dir + self._image_dir = os.path.join(self._base_dir, 'Images') + self._cat_dir = os.path.join(self._base_dir, 'Category_ids') + self._flip_dir = os.path.join(self._base_dir,'Category_rev_ids') + ## for Pascal + self._base_dir_pascal = pascal_dir + self._image_dir_pascal = os.path.join(self._base_dir_pascal, 'JPEGImages') + self._cat_dir_pascal = os.path.join(self._base_dir_pascal, 'SegmentationPart') + # self._flip_dir_pascal = os.path.join(self._base_dir_pascal, 'Category_rev_ids') + ## for atr + self._base_dir_atr = atr_dir + self._image_dir_atr = os.path.join(self._base_dir_atr, 'JPEGImages') + self._cat_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug') + self._flip_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug_rev') + + if isinstance(split, str): + self.split = [split] + else: + split.sort() + self.split = split + + self.transform = transform + + _splits_dir = os.path.join(self._base_dir, 'lists') + _splits_dir_pascal = os.path.join(self._base_dir_pascal, 'list') + _splits_dir_atr = os.path.join(self._base_dir_atr, 'list') + + self.im_ids = [] + self.images = [] + self.categories = [] + self.flip_categories = [] + self.datasets_lbl = [] + + # num + self.num_cihp = 0 + self.num_pascal = 0 + self.num_atr = 0 + # for cihp is 0 + for splt in self.split: + with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f: + lines = f.read().splitlines() + self.num_cihp += len(lines) + for ii, line in enumerate(lines): + + _image = os.path.join(self._image_dir, line+'.jpg' ) + _cat = os.path.join(self._cat_dir, line +'.png') + _flip = os.path.join(self._flip_dir,line + '.png') + # print(self._image_dir,_image) + assert os.path.isfile(_image) + # print(_cat) + assert os.path.isfile(_cat) + assert os.path.isfile(_flip) + self.im_ids.append(line) + self.images.append(_image) + self.categories.append(_cat) + self.flip_categories.append(_flip) + self.datasets_lbl.append(0) + + # for pascal is 1 + for splt in self.split: + if splt == 'test': + splt='val' + with open(os.path.join(os.path.join(_splits_dir_pascal, splt + '_id.txt')), "r") as f: + lines = f.read().splitlines() + self.num_pascal += len(lines) + for ii, line in enumerate(lines): + + _image = os.path.join(self._image_dir_pascal, line+'.jpg' ) + _cat = 
os.path.join(self._cat_dir_pascal, line +'.png') + # _flip = os.path.join(self._flip_dir,line + '.png') + # print(self._image_dir,_image) + assert os.path.isfile(_image) + # print(_cat) + assert os.path.isfile(_cat) + # assert os.path.isfile(_flip) + self.im_ids.append(line) + self.images.append(_image) + self.categories.append(_cat) + self.flip_categories.append([]) + self.datasets_lbl.append(1) + + # for atr is 2 + for splt in self.split: + with open(os.path.join(os.path.join(_splits_dir_atr, splt + '_id.txt')), "r") as f: + lines = f.read().splitlines() + self.num_atr += len(lines) + for ii, line in enumerate(lines): + _image = os.path.join(self._image_dir_atr, line + '.jpg') + _cat = os.path.join(self._cat_dir_atr, line + '.png') + _flip = os.path.join(self._flip_dir_atr, line + '.png') + # print(self._image_dir,_image) + assert os.path.isfile(_image) + # print(_cat) + assert os.path.isfile(_cat) + assert os.path.isfile(_flip) + self.im_ids.append(line) + self.images.append(_image) + self.categories.append(_cat) + self.flip_categories.append(_flip) + self.datasets_lbl.append(2) + + assert (len(self.images) == len(self.categories)) + # assert len(self.flip_categories) == len(self.categories) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + def get_class_num(self): + return self.num_cihp,self.num_pascal,self.num_atr + + + + def __getitem__(self, index): + _img, _target,_lbl= self._make_img_gt_point_pair(index) + sample = {'image': _img, 'label': _target,} + + if self.transform is not None: + sample = self.transform(sample) + sample['pascal'] = _lbl + return sample + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32) + # _target = np.array(Image.open(self.categories[index])).astype(np.float32) + + _img = Image.open(self.images[index]).convert('RGB') # return is RGB pic + type_lbl = self.datasets_lbl[index] + if self._flip_flag: + if random.random() < 0.5 : + # _target = Image.open(self.flip_categories[index]) + _img = _img.transpose(Image.FLIP_LEFT_RIGHT) + if type_lbl == 0 or type_lbl == 2: + _target = Image.open(self.flip_categories[index]) + else: + _target = Image.open(self.categories[index]) + _target = _target.transpose(Image.FLIP_LEFT_RIGHT) + else: + _target = Image.open(self.categories[index]) + else: + _target = Image.open(self.categories[index]) + + return _img, _target,type_lbl + + def __str__(self): + return 'datasets(split=' + str(self.split) + ')' + + + + + + + + + + + + +if __name__ == '__main__': + from dataloaders import custom_transforms as tr + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + from torchvision import transforms + import matplotlib.pyplot as plt + + composed_transforms_tr = transforms.Compose([ + # tr.RandomHorizontalFlip(), + tr.RandomSized_new(512), + tr.RandomRotate(15), + tr.ToTensor_()]) + + + + voc_train = VOCSegmentation(split='train', + transform=composed_transforms_tr) + + dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=1) + + for ii, sample in enumerate(dataloader): + if ii >10: + break \ No newline at end of file diff --git a/dataloaders/custom_transforms.py b/dataloaders/custom_transforms.py new file mode 100644 index 0000000..1556a6f --- /dev/null +++ b/dataloaders/custom_transforms.py @@ -0,0 +1,491 @@ +import torch +import math +import numbers +import random +import numpy as np 
+ +from PIL import Image, ImageOps +from torchvision import transforms + +class RandomCrop(object): + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size # h, w + self.padding = padding + + def __call__(self, sample): + img, mask = sample['image'], sample['label'] + + if self.padding > 0: + img = ImageOps.expand(img, border=self.padding, fill=0) + mask = ImageOps.expand(mask, border=self.padding, fill=0) + + assert img.size == mask.size + w, h = img.size + th, tw = self.size # target size + if w == tw and h == th: + return {'image': img, + 'label': mask} + if w < tw or h < th: + img = img.resize((tw, th), Image.BILINEAR) + mask = mask.resize((tw, th), Image.NEAREST) + return {'image': img, + 'label': mask} + + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + img = img.crop((x1, y1, x1 + tw, y1 + th)) + mask = mask.crop((x1, y1, x1 + tw, y1 + th)) + + return {'image': img, + 'label': mask} + +class RandomCrop_new(object): + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size # h, w + self.padding = padding + + def __call__(self, sample): + img, mask = sample['image'], sample['label'] + + if self.padding > 0: + img = ImageOps.expand(img, border=self.padding, fill=0) + mask = ImageOps.expand(mask, border=self.padding, fill=0) + + assert img.size == mask.size + w, h = img.size + th, tw = self.size # target size + if w == tw and h == th: + return {'image': img, + 'label': mask} + + new_img = Image.new('RGB',(tw,th),'black') # size is w x h; and 'white' is 255 + new_mask = Image.new('L',(tw,th),'white') # same above + + # if w > tw or h > th + x1 = y1 = 0 + if w > tw: + x1 = random.randint(0,w - tw) + if h > th: + y1 = random.randint(0,h - th) + # crop + img = img.crop((x1,y1, x1 + tw, y1 + th)) + mask = mask.crop((x1,y1, x1 + tw, y1 + th)) + new_img.paste(img,(0,0)) + new_mask.paste(mask,(0,0)) + + # x1 = random.randint(0, w - tw) + # y1 = random.randint(0, h - th) + # img = img.crop((x1, y1, x1 + tw, y1 + th)) + # mask = mask.crop((x1, y1, x1 + tw, y1 + th)) + + return {'image': new_img, + 'label': new_mask} + +class Paste(object): + def __init__(self, size,): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size # h, w + + def __call__(self, sample): + img, mask = sample['image'], sample['label'] + + assert img.size == mask.size + w, h = img.size + th, tw = self.size # target size + assert (w <=tw) and (h <= th) + if w == tw and h == th: + return {'image': img, + 'label': mask} + + new_img = Image.new('RGB',(tw,th),'black') # size is w x h; and 'white' is 255 + new_mask = Image.new('L',(tw,th),'white') # same above + + new_img.paste(img,(0,0)) + new_mask.paste(mask,(0,0)) + + return {'image': new_img, + 'label': new_mask} + +class CenterCrop(object): + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + assert img.size == mask.size + w, h = img.size + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + img = img.crop((x1, y1, x1 + tw, y1 + th)) + mask = mask.crop((x1, y1, x1 + tw, y1 + th)) + + return {'image': img, + 'label': mask} + +class RandomHorizontalFlip(object): + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + if 
random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + + return {'image': img, + 'label': mask} + +class HorizontalFlip(object): + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + + return {'image': img, + 'label': mask} + +class HorizontalFlip_only_img(object): + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + img = img.transpose(Image.FLIP_LEFT_RIGHT) + # mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + + return {'image': img, + 'label': mask} + +class RandomHorizontalFlip_cihp(object): + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + # mask = Image.open() + + return {'image': img, + 'label': mask} + +class Normalize(object): + """Normalize a tensor image with mean and standard deviation. + Args: + mean (tuple): means for each channel. + std (tuple): standard deviations for each channel. + """ + def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): + self.mean = mean + self.std = std + + def __call__(self, sample): + img = np.array(sample['image']).astype(np.float32) + mask = np.array(sample['label']).astype(np.float32) + img /= 255.0 + img -= self.mean + img /= self.std + + return {'image': img, + 'label': mask} + +class Normalize_255(object): + """Normalize a tensor image with mean and standard deviation. tf use 255. + Args: + mean (tuple): means for each channel. + std (tuple): standard deviations for each channel. + """ + def __init__(self, mean=(123.15, 115.90, 103.06), std=(1., 1., 1.)): + self.mean = mean + self.std = std + + def __call__(self, sample): + img = np.array(sample['image']).astype(np.float32) + mask = np.array(sample['label']).astype(np.float32) + # img = 255.0 + img -= self.mean + img /= self.std + img = img + img = img[[0,3,2,1],...] + return {'image': img, + 'label': mask} + +class Normalize_xception_tf(object): + # def __init__(self): + # self.rgb2bgr = + + def __call__(self, sample): + img = np.array(sample['image']).astype(np.float32) + mask = np.array(sample['label']).astype(np.float32) + img = (img*2.0)/255.0 - 1 + # print(img.shape) + # img = img[[0,3,2,1],...] + return {'image': img, + 'label': mask} + +class Normalize_xception_tf_only_img(object): + # def __init__(self): + # self.rgb2bgr = + + def __call__(self, sample): + img = np.array(sample['image']).astype(np.float32) + # mask = np.array(sample['label']).astype(np.float32) + img = (img*2.0)/255.0 - 1 + # print(img.shape) + # img = img[[0,3,2,1],...] + return {'image': img, + 'label': sample['label']} + +class Normalize_cityscapes(object): + """Normalize a tensor image with mean and standard deviation. + Args: + mean (tuple): means for each channel. + std (tuple): standard deviations for each channel. 
+ """ + def __init__(self, mean=(0., 0., 0.)): + self.mean = mean + + def __call__(self, sample): + img = np.array(sample['image']).astype(np.float32) + mask = np.array(sample['label']).astype(np.float32) + img -= self.mean + img /= 255.0 + + return {'image': img, + 'label': mask} + +class ToTensor_(object): + """Convert ndarrays in sample to Tensors.""" + def __init__(self): + self.rgb2bgr = transforms.Lambda(lambda x:x[[2,1,0],...]) + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1)) + mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1)) + # mask[mask == 255] = 0 + + img = torch.from_numpy(img).float() + img = self.rgb2bgr(img) + mask = torch.from_numpy(mask).float() + + + return {'image': img, + 'label': mask} + +class ToTensor_only_img(object): + """Convert ndarrays in sample to Tensors.""" + def __init__(self): + self.rgb2bgr = transforms.Lambda(lambda x:x[[2,1,0],...]) + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1)) + # mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1)) + # mask[mask == 255] = 0 + + img = torch.from_numpy(img).float() + img = self.rgb2bgr(img) + # mask = torch.from_numpy(mask).float() + + + return {'image': img, + 'label': sample['label']} + +class FixedResize(object): + def __init__(self, size): + self.size = tuple(reversed(size)) # size: (h, w) + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + + assert img.size == mask.size + + img = img.resize(self.size, Image.BILINEAR) + mask = mask.resize(self.size, Image.NEAREST) + + return {'image': img, + 'label': mask} + +class Keep_origin_size_Resize(object): + def __init__(self, max_size, scale=1.0): + self.size = tuple(reversed(max_size)) # size: (h, w) + self.scale = scale + self.paste = Paste(int(max_size[0]*scale)) + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + + assert img.size == mask.size + h, w = self.size + h = int(h*self.scale) + w = int(w*self.scale) + img = img.resize((h, w), Image.BILINEAR) + mask = mask.resize((h, w), Image.NEAREST) + + return self.paste({'image': img, + 'label': mask}) + +class Scale(object): + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + assert img.size == mask.size + w, h = img.size + + if (w >= h and w == self.size[1]) or (h >= w and h == self.size[0]): + return {'image': img, + 'label': mask} + oh, ow = self.size + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + + return {'image': img, + 'label': mask} + +class Scale_(object): + def __init__(self, scale): + self.scale = scale + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + assert img.size == mask.size + w, h = img.size + ow = int(w*self.scale) + oh = int(h*self.scale) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + + return {'image': img, + 'label': mask} + +class Scale_only_img(object): + def __init__(self, scale): + self.scale = scale + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + # assert 
img.size == mask.size + w, h = img.size + ow = int(w*self.scale) + oh = int(h*self.scale) + img = img.resize((ow, oh), Image.BILINEAR) + # mask = mask.resize((ow, oh), Image.NEAREST) + + return {'image': img, + 'label': mask} + +class RandomSizedCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + assert img.size == mask.size + for attempt in range(10): + area = img.size[0] * img.size[1] + target_area = random.uniform(0.45, 1.0) * area + aspect_ratio = random.uniform(0.5, 2) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if random.random() < 0.5: + w, h = h, w + + if w <= img.size[0] and h <= img.size[1]: + x1 = random.randint(0, img.size[0] - w) + y1 = random.randint(0, img.size[1] - h) + + img = img.crop((x1, y1, x1 + w, y1 + h)) + mask = mask.crop((x1, y1, x1 + w, y1 + h)) + assert (img.size == (w, h)) + + img = img.resize((self.size, self.size), Image.BILINEAR) + mask = mask.resize((self.size, self.size), Image.NEAREST) + + return {'image': img, + 'label': mask} + + # Fallback + scale = Scale(self.size) + crop = CenterCrop(self.size) + sample = crop(scale(sample)) + return sample + +class RandomRotate(object): + def __init__(self, degree): + self.degree = degree + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + rotate_degree = random.random() * 2 * self.degree - self.degree + img = img.rotate(rotate_degree, Image.BILINEAR) + mask = mask.rotate(rotate_degree, Image.NEAREST) + + return {'image': img, + 'label': mask} + +class RandomSized_new(object): + '''what we use is this class to aug''' + def __init__(self, size,scale1=0.5,scale2=2): + self.size = size + # self.scale = Scale(self.size) + self.crop = RandomCrop_new(self.size) + self.small_scale = scale1 + self.big_scale = scale2 + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + assert img.size == mask.size + + w = int(random.uniform(self.small_scale, self.big_scale) * img.size[0]) + h = int(random.uniform(self.small_scale, self.big_scale) * img.size[1]) + + img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) + sample = {'image': img, 'label': mask} + # finish resize + return self.crop(sample) +# class Random + +class RandomScale(object): + def __init__(self, limit): + self.limit = limit + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + assert img.size == mask.size + + scale = random.uniform(self.limit[0], self.limit[1]) + w = int(scale * img.size[0]) + h = int(scale * img.size[1]) + + img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) + + return {'image': img, 'label': mask} \ No newline at end of file diff --git a/dataloaders/mypath_atr.py b/dataloaders/mypath_atr.py new file mode 100644 index 0000000..1701e5b --- /dev/null +++ b/dataloaders/mypath_atr.py @@ -0,0 +1,8 @@ +class Path(object): + @staticmethod + def db_root_dir(database): + if database == 'atr': + return './data/datasets/ATR/' # folder that contains atr/. 
+ else: + print('Database {} not available.'.format(database)) + raise NotImplementedError diff --git a/dataloaders/mypath_cihp.py b/dataloaders/mypath_cihp.py new file mode 100644 index 0000000..02760eb --- /dev/null +++ b/dataloaders/mypath_cihp.py @@ -0,0 +1,8 @@ +class Path(object): + @staticmethod + def db_root_dir(database): + if database == 'cihp': + return './data/datasets/CIHP_4w/' + else: + print('Database {} not available.'.format(database)) + raise NotImplementedError diff --git a/dataloaders/mypath_pascal.py b/dataloaders/mypath_pascal.py new file mode 100644 index 0000000..aec4735 --- /dev/null +++ b/dataloaders/mypath_pascal.py @@ -0,0 +1,8 @@ +class Path(object): + @staticmethod + def db_root_dir(database): + if database == 'pascal': + return './data/datasets/pascal/' # folder that contains pascal/. + else: + print('Database {} not available.'.format(database)) + raise NotImplementedError diff --git a/dataloaders/pascal.py b/dataloaders/pascal.py new file mode 100644 index 0000000..bffd1ca --- /dev/null +++ b/dataloaders/pascal.py @@ -0,0 +1,106 @@ +from __future__ import print_function, division +import os +from PIL import Image +from torch.utils.data import Dataset +from .mypath_pascal import Path + +class VOCSegmentation(Dataset): + """ + Pascal dataset + """ + + def __init__(self, + base_dir=Path.db_root_dir('pascal'), + split='train', + transform=None + ): + """ + :param base_dir: path to PASCAL dataset directory + :param split: train/val + :param transform: transform to apply + """ + super(VOCSegmentation).__init__() + self._base_dir = base_dir + self._image_dir = os.path.join(self._base_dir, 'JPEGImages') + self._cat_dir = os.path.join(self._base_dir, 'SegmentationPart') + + if isinstance(split, str): + self.split = [split] + else: + split.sort() + self.split = split + + self.transform = transform + + _splits_dir = os.path.join(self._base_dir, 'list') + + self.im_ids = [] + self.images = [] + self.categories = [] + + for splt in self.split: + with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f: + lines = f.read().splitlines() + + for ii, line in enumerate(lines): + + _image = os.path.join(self._image_dir, line+'.jpg' ) + _cat = os.path.join(self._cat_dir, line +'.png') + # print(self._image_dir,_image) + assert os.path.isfile(_image) + # print(_cat) + assert os.path.isfile(_cat) + self.im_ids.append(line) + self.images.append(_image) + self.categories.append(_cat) + + assert (len(self.images) == len(self.categories)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + + def __getitem__(self, index): + _img, _target= self._make_img_gt_point_pair(index) + sample = {'image': _img, 'label': _target} + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32) + # _target = np.array(Image.open(self.categories[index])).astype(np.float32) + + _img = Image.open(self.images[index]).convert('RGB') # return is RGB pic + _target = Image.open(self.categories[index]) + + return _img, _target + + def __str__(self): + return 'PASCAL(split=' + str(self.split) + ')' + +class test_segmentation(VOCSegmentation): + def __init__(self,base_dir=Path.db_root_dir('pascal'), + split='train', + transform=None, + flip=True): + super(test_segmentation, 
self).__init__(base_dir=base_dir,split=split,transform=transform) + self._flip_flag = flip + + def __getitem__(self, index): + _img, _target= self._make_img_gt_point_pair(index) + sample = {'image': _img, 'label': _target} + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + + diff --git a/eval_cihp.sh b/eval_cihp.sh new file mode 100644 index 0000000..4f72122 --- /dev/null +++ b/eval_cihp.sh @@ -0,0 +1,6 @@ +python ./exp/test/eval_show_pascal2cihp.py \ + --batch 1 --gpus 1 --classes 20 \ + --gt_path './data/datasets/CIHP_4w/Category_ids/' \ + --txt_file './data/datasets/CIHP_4w/lists/test_id.txt' \ + --loadmodel './data/pretrained_model/inference.pth' + diff --git a/eval_pascal.sh b/eval_pascal.sh new file mode 100644 index 0000000..dc694d4 --- /dev/null +++ b/eval_pascal.sh @@ -0,0 +1,6 @@ +python ./exp/test/eval_show_cihp2pascal.py \ + --batch 1 --gpus 1 --classes 20 \ + --gt_path './data/datasets/CIHP_4w/Category_ids/' \ + --txt_file './data/datasets/CIHP_4w/lists/test_id.txt' \ + --loadmodel './data/pretrained_model/cihp2pascal.pth' + diff --git a/exp/inference/inference.py b/exp/inference/inference.py new file mode 100644 index 0000000..adf4113 --- /dev/null +++ b/exp/inference/inference.py @@ -0,0 +1,203 @@ +import socket +import timeit +import numpy as np +from PIL import Image +from datetime import datetime +import os +import sys +from collections import OrderedDict +sys.path.append('../../') +# PyTorch includes +import torch +from torch.autograd import Variable +from torchvision import transforms +import cv2 + + +# Custom includes +from networks import deeplab_xception_transfer, graph +from dataloaders import custom_transforms as tr + +# +import argparse +import torch.nn.functional as F + +label_colours = [(0,0,0) + , (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0) + , (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)] + + +def flip(x, dim): + indices = [slice(None)] * x.dim() + indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, + dtype=torch.long, device=x.device) + return x[tuple(indices)] + +def flip_cihp(tail_list): + ''' + + :param tail_list: tail_list size is 1 x n_class x h x w + :return: + ''' + # tail_list = tail_list[0] + tail_list_rev = [None] * 20 + for xx in range(14): + tail_list_rev[xx] = tail_list[xx].unsqueeze(0) + tail_list_rev[14] = tail_list[15].unsqueeze(0) + tail_list_rev[15] = tail_list[14].unsqueeze(0) + tail_list_rev[16] = tail_list[17].unsqueeze(0) + tail_list_rev[17] = tail_list[16].unsqueeze(0) + tail_list_rev[18] = tail_list[19].unsqueeze(0) + tail_list_rev[19] = tail_list[18].unsqueeze(0) + return torch.cat(tail_list_rev,dim=0) + + +def decode_labels(mask, num_images=1, num_classes=20): + """Decode batch of segmentation masks. + + Args: + mask: result of inference after taking argmax. + num_images: number of images to decode from the batch. + num_classes: number of classes to predict (including background). + + Returns: + A batch with num_images RGB images of the same size as the input. + """ + n, h, w = mask.shape + assert (n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' 
% ( + n, num_images) + outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) + for i in range(num_images): + img = Image.new('RGB', (len(mask[i, 0]), len(mask[i]))) + pixels = img.load() + for j_, j in enumerate(mask[i, :, :]): + for k_, k in enumerate(j): + if k < num_classes: + pixels[k_, j_] = label_colours[k] + outputs[i] = np.array(img) + return outputs + +def read_img(img_path): + _img = Image.open(img_path).convert('RGB') # return is RGB pic + return _img + +def img_transform(img, transform=None): + sample = {'image': img, 'label': 0} + + sample = transform(sample) + return sample + +def inference(net, img_path='', output_path='./', output_name='f', use_gpu=True): + ''' + + :param net: + :param img_path: + :param output_path: + :return: + ''' + # adj + adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float() + adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3) + + adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float()) + adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda() + + cihp_adj = graph.preprocess_adj(graph.cihp_graph) + adj3_ = Variable(torch.from_numpy(cihp_adj).float()) + adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda() + + # multi-scale + scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75] + img = read_img(img_path) + testloader_list = [] + testloader_flip_list = [] + for pv in scale_list: + composed_transforms_ts = transforms.Compose([ + tr.Scale_only_img(pv), + tr.Normalize_xception_tf_only_img(), + tr.ToTensor_only_img()]) + + composed_transforms_ts_flip = transforms.Compose([ + tr.Scale_only_img(pv), + tr.HorizontalFlip_only_img(), + tr.Normalize_xception_tf_only_img(), + tr.ToTensor_only_img()]) + + testloader_list.append(img_transform(img, composed_transforms_ts)) + # print(img_transform(img, composed_transforms_ts)) + testloader_flip_list.append(img_transform(img, composed_transforms_ts_flip)) + # print(testloader_list) + start_time = timeit.default_timer() + # One testing epoch + net.eval() + # 1 0.5 0.75 1.25 1.5 1.75 ; flip: + + for iii, sample_batched in enumerate(zip(testloader_list, testloader_flip_list)): + inputs, labels = sample_batched[0]['image'], sample_batched[0]['label'] + inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label'] + inputs = inputs.unsqueeze(0) + inputs_f = inputs_f.unsqueeze(0) + inputs = torch.cat((inputs, inputs_f), dim=0) + if iii == 0: + _, _, h, w = inputs.size() + # assert inputs.size() == inputs_f.size() + + # Forward pass of the mini-batch + inputs = Variable(inputs, requires_grad=False) + + with torch.no_grad(): + if use_gpu >= 0: + inputs = inputs.cuda() + # outputs = net.forward(inputs) + outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda()) + outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2 + outputs = outputs.unsqueeze(0) + + if iii > 0: + outputs = F.upsample(outputs, size=(h, w), mode='bilinear', align_corners=True) + outputs_final = outputs_final + outputs + else: + outputs_final = outputs.clone() + ################ plot pic + predictions = torch.max(outputs_final, 1)[1] + results = predictions.cpu().numpy() + vis_res = decode_labels(results) + + parsing_im = Image.fromarray(vis_res[0]) + parsing_im.save(output_path+'/{}.png'.format(output_name)) + cv2.imwrite(output_path+'/{}_gray.png'.format(output_name), results[0, :, :]) + + end_time = timeit.default_timer() + print('time used for the multi-scale image inference' + ' is :' + str(end_time - start_time)) + 
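+
+# Example CLI usage (a hedged illustration mirroring the README; the checkpoint
+# path and output name below are placeholders, replace them with your own):
+#   python exp/inference/inference.py \
+#       --loadmodel ./data/pretrained_model/inference.pth \
+#       --img_path ./img/messi.jpg \
+#       --output_path ./img/ \
+#       --output_name messi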
+if __name__ == '__main__': + '''argparse begin''' + parser = argparse.ArgumentParser() + # parser.add_argument('--loadmodel',default=None,type=str) + parser.add_argument('--loadmodel', default='', type=str) + parser.add_argument('--img_path', default='', type=str) + parser.add_argument('--output_path', default='', type=str) + parser.add_argument('--output_name', default='', type=str) + parser.add_argument('--use_gpu', default=1, type=int) + opts = parser.parse_args() + + net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=20, + hidden_layers=128, + source_classes=7, ) + if not opts.loadmodel == '': + x = torch.load(opts.loadmodel) + net.load_source_model(x) + print('load model:', opts.loadmodel) + else: + print('no model load !!!!!!!!') + raise RuntimeError('No model!!!!') + + if opts.use_gpu >0 : + net.cuda() + use_gpu = True + else: + use_gpu = False + raise RuntimeError('must use the gpu!!!!') + + inference(net=net, img_path=opts.img_path,output_path=opts.output_path , output_name=opts.output_name, use_gpu=use_gpu) + diff --git a/exp/test/__init__.py b/exp/test/__init__.py new file mode 100644 index 0000000..a09a463 --- /dev/null +++ b/exp/test/__init__.py @@ -0,0 +1,3 @@ +from .test_from_disk import eval_ + +__all__ = ['eval_'] \ No newline at end of file diff --git a/exp/test/eval_show_cihp2pascal.py b/exp/test/eval_show_cihp2pascal.py new file mode 100644 index 0000000..d38bbb2 --- /dev/null +++ b/exp/test/eval_show_cihp2pascal.py @@ -0,0 +1,268 @@ +import socket +import timeit +import numpy as np +from PIL import Image +from datetime import datetime +import os +import sys +import glob +from collections import OrderedDict +sys.path.append('../../') +# PyTorch includes +import torch +import pdb +from torch.autograd import Variable +import torch.optim as optim +from torchvision import transforms +from torch.utils.data import DataLoader +from torchvision.utils import make_grid +import cv2 + +# Tensorboard include +# from tensorboardX import SummaryWriter + +# Custom includes +from dataloaders import pascal +from utils import util +from networks import deeplab_xception_transfer, graph +from dataloaders import custom_transforms as tr + +# +import argparse +import copy +import torch.nn.functional as F +from test_from_disk import eval_ + + +gpu_id = 1 + +label_colours = [(0,0,0) + # 0=background + ,(128,0,0), (0,128,0), (128,128,0), (0,0,128), (128,0,128), (0,128,128)] + + +def flip(x, dim): + indices = [slice(None)] * x.dim() + indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, + dtype=torch.long, device=x.device) + return x[tuple(indices)] + +# def flip_cihp(tail_list): +# ''' +# +# :param tail_list: tail_list size is 1 x n_class x h x w +# :return: +# ''' +# # tail_list = tail_list[0] +# tail_list_rev = [None] * 20 +# for xx in range(14): +# tail_list_rev[xx] = tail_list[xx].unsqueeze(0) +# tail_list_rev[14] = tail_list[15].unsqueeze(0) +# tail_list_rev[15] = tail_list[14].unsqueeze(0) +# tail_list_rev[16] = tail_list[17].unsqueeze(0) +# tail_list_rev[17] = tail_list[16].unsqueeze(0) +# tail_list_rev[18] = tail_list[19].unsqueeze(0) +# tail_list_rev[19] = tail_list[18].unsqueeze(0) +# return torch.cat(tail_list_rev,dim=0) + +def decode_labels(mask, num_images=1, num_classes=20): + """Decode batch of segmentation masks. + + Args: + mask: result of inference after taking argmax. + num_images: number of images to decode from the batch. + num_classes: number of classes to predict (including background). 
+ + Returns: + A batch with num_images RGB images of the same size as the input. + """ + n, h, w = mask.shape + assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images) + outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) + for i in range(num_images): + img = Image.new('RGB', (len(mask[i, 0]), len(mask[i]))) + pixels = img.load() + for j_, j in enumerate(mask[i, :, :]): + for k_, k in enumerate(j): + if k < num_classes: + pixels[k_,j_] = label_colours[k] + outputs[i] = np.array(img) + return outputs + +def get_parser(): + '''argparse begin''' + parser = argparse.ArgumentParser() + LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) + + parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--batch', default=16, type=int) + parser.add_argument('--lr', default=1e-7, type=float) + parser.add_argument('--numworker', default=12, type=int) + parser.add_argument('--step', default=30, type=int) + # parser.add_argument('--loadmodel',default=None,type=str) + parser.add_argument('--classes', default=7, type=int) + parser.add_argument('--testepoch', default=10, type=int) + parser.add_argument('--loadmodel', default='', type=str) + parser.add_argument('--txt_file', default='', type=str) + parser.add_argument('--hidden_layers', default=128, type=int) + parser.add_argument('--gpus', default=4, type=int) + parser.add_argument('--output_path', default='./results/', type=str) + parser.add_argument('--gt_path', default='./results/', type=str) + opts = parser.parse_args() + return opts + + +def main(opts): + adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float() + adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda() + + adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float()) + adj1_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda() + + cihp_adj = graph.preprocess_adj(graph.cihp_graph) + adj3_ = Variable(torch.from_numpy(cihp_adj).float()) + adj3_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda() + + p = OrderedDict() # Parameters to include in report + p['trainBatch'] = opts.batch # Training batch size + p['nAveGrad'] = 1 # Average the gradient of several iterations + p['lr'] = opts.lr # Learning rate + p['lrFtr'] = 1e-5 + p['lraspp'] = 1e-5 + p['lrpro'] = 1e-5 + p['lrdecoder'] = 1e-5 + p['lrother'] = 1e-5 + p['wd'] = 5e-4 # Weight decay + p['momentum'] = 0.9 # Momentum + p['epoch_size'] = 10 # How many epochs to change learning rate + p['num_workers'] = opts.numworker + backbone = 'xception' # Use xception or resnet as feature extractor, + + with open(opts.txt_file, 'r') as f: + img_list = f.readlines() + + max_id = 0 + save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) + exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] + runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*')) + for r in runs: + run_id = int(r.split('_')[-1]) + if run_id >= max_id: + max_id = run_id + 1 + # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 + + # Network definition + if backbone == 'xception': + net = deeplab_xception_transfer.deeplab_xception_transfer_projection(n_classes=opts.classes, os=16, + hidden_layers=opts.hidden_layers, source_classes=20, + ) + elif backbone == 'resnet': + # net = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True) + raise NotImplementedError + else: + raise NotImplementedError + 
+ if gpu_id >= 0: + net.cuda() + + # net load weights + if not opts.loadmodel =='': + x = torch.load(opts.loadmodel) + net.load_source_model(x) + print('load model:' ,opts.loadmodel) + else: + print('no model load !!!!!!!!') + + ## multi scale + scale_list=[1,0.5,0.75,1.25,1.5,1.75] + testloader_list = [] + testloader_flip_list = [] + for pv in scale_list: + composed_transforms_ts = transforms.Compose([ + tr.Scale_(pv), + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + composed_transforms_ts_flip = transforms.Compose([ + tr.Scale_(pv), + tr.HorizontalFlip(), + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts) + voc_val_f = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip) + + testloader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=p['num_workers']) + testloader_flip = DataLoader(voc_val_f, batch_size=1, shuffle=False, num_workers=p['num_workers']) + + testloader_list.append(copy.deepcopy(testloader)) + testloader_flip_list.append(copy.deepcopy(testloader_flip)) + + print("Eval Network") + + if not os.path.exists(opts.output_path + 'pascal_output_vis/'): + os.makedirs(opts.output_path + 'pascal_output_vis/') + if not os.path.exists(opts.output_path + 'pascal_output/'): + os.makedirs(opts.output_path + 'pascal_output/') + + start_time = timeit.default_timer() + # One testing epoch + total_iou = 0.0 + net.eval() + for ii, large_sample_batched in enumerate(zip(*testloader_list, *testloader_flip_list)): + print(ii) + #1 0.5 0.75 1.25 1.5 1.75 ; flip: + sample1 = large_sample_batched[:6] + sample2 = large_sample_batched[6:] + for iii,sample_batched in enumerate(zip(sample1,sample2)): + inputs, labels = sample_batched[0]['image'], sample_batched[0]['label'] + inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label'] + inputs = torch.cat((inputs,inputs_f),dim=0) + if iii == 0: + _,_,h,w = inputs.size() + # assert inputs.size() == inputs_f.size() + + # Forward pass of the mini-batch + inputs, labels = Variable(inputs, requires_grad=False), Variable(labels) + + with torch.no_grad(): + if gpu_id >= 0: + inputs, labels = inputs.cuda(), labels.cuda() + # outputs = net.forward(inputs) + # pdb.set_trace() + outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda()) + outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2 + outputs = outputs.unsqueeze(0) + + if iii>0: + outputs = F.upsample(outputs,size=(h,w),mode='bilinear',align_corners=True) + outputs_final = outputs_final + outputs + else: + outputs_final = outputs.clone() + ################ plot pic + predictions = torch.max(outputs_final, 1)[1] + prob_predictions = torch.max(outputs_final,1)[0] + results = predictions.cpu().numpy() + prob_results = prob_predictions.cpu().numpy() + vis_res = decode_labels(results) + + parsing_im = Image.fromarray(vis_res[0]) + parsing_im.save(opts.output_path + 'pascal_output_vis/{}.png'.format(img_list[ii][:-1])) + cv2.imwrite(opts.output_path + 'pascal_output/{}.png'.format(img_list[ii][:-1]), results[0,:,:]) + # np.save('../../cihp_prob_output/{}.npy'.format(img_list[ii][:-1]), prob_results[0, :, :]) + # pred_list.append(predictions.cpu()) + # label_list.append(labels.squeeze(1).cpu()) + # loss = criterion(outputs, labels, batch_average=True) + # running_loss_ts += loss.item() + + # total_iou += utils.get_iou(predictions, labels) + end_time = timeit.default_timer() + print('time use for '+str(ii) + ' is :' + str(end_time - start_time)) + + # Eval + 
pred_path = opts.output_path + 'pascal_output/' + eval_(pred_path=pred_path, gt_path=opts.gt_path,classes=opts.classes, txt_file=opts.txt_file) + +if __name__ == '__main__': + opts = get_parser() + main(opts) \ No newline at end of file diff --git a/exp/test/eval_show_pascal2cihp.py b/exp/test/eval_show_pascal2cihp.py new file mode 100644 index 0000000..cc30173 --- /dev/null +++ b/exp/test/eval_show_pascal2cihp.py @@ -0,0 +1,268 @@ +import socket +import timeit +import numpy as np +from PIL import Image +from datetime import datetime +import os +import sys +import glob +from collections import OrderedDict +sys.path.append('../../') +# PyTorch includes +import torch +import pdb +from torch.autograd import Variable +import torch.optim as optim +from torchvision import transforms +from torch.utils.data import DataLoader +from torchvision.utils import make_grid +import cv2 + +# Tensorboard include +# from tensorboardX import SummaryWriter + +# Custom includes +from dataloaders import cihp +from utils import util +from networks import deeplab_xception_transfer, graph +from dataloaders import custom_transforms as tr + +# +import argparse +import copy +import torch.nn.functional as F +from test_from_disk import eval_ + + +gpu_id = 1 + +label_colours = [(0,0,0) + , (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0) + , (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)] + + +def flip(x, dim): + indices = [slice(None)] * x.dim() + indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, + dtype=torch.long, device=x.device) + return x[tuple(indices)] + +def flip_cihp(tail_list): + ''' + + :param tail_list: tail_list size is 1 x n_class x h x w + :return: + ''' + # tail_list = tail_list[0] + tail_list_rev = [None] * 20 + for xx in range(14): + tail_list_rev[xx] = tail_list[xx].unsqueeze(0) + tail_list_rev[14] = tail_list[15].unsqueeze(0) + tail_list_rev[15] = tail_list[14].unsqueeze(0) + tail_list_rev[16] = tail_list[17].unsqueeze(0) + tail_list_rev[17] = tail_list[16].unsqueeze(0) + tail_list_rev[18] = tail_list[19].unsqueeze(0) + tail_list_rev[19] = tail_list[18].unsqueeze(0) + return torch.cat(tail_list_rev,dim=0) + +def decode_labels(mask, num_images=1, num_classes=20): + """Decode batch of segmentation masks. + + Args: + mask: result of inference after taking argmax. + num_images: number of images to decode from the batch. + num_classes: number of classes to predict (including background). + + Returns: + A batch with num_images RGB images of the same size as the input. + """ + n, h, w = mask.shape + assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' 
% (n, num_images) + outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) + for i in range(num_images): + img = Image.new('RGB', (len(mask[i, 0]), len(mask[i]))) + pixels = img.load() + for j_, j in enumerate(mask[i, :, :]): + for k_, k in enumerate(j): + if k < num_classes: + pixels[k_,j_] = label_colours[k] + outputs[i] = np.array(img) + return outputs + +def get_parser(): + '''argparse begin''' + parser = argparse.ArgumentParser() + LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) + + parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--batch', default=16, type=int) + parser.add_argument('--lr', default=1e-7, type=float) + parser.add_argument('--numworker', default=12, type=int) + parser.add_argument('--step', default=30, type=int) + # parser.add_argument('--loadmodel',default=None,type=str) + parser.add_argument('--classes', default=7, type=int) + parser.add_argument('--testepoch', default=10, type=int) + parser.add_argument('--loadmodel', default='', type=str) + parser.add_argument('--txt_file', default='', type=str) + parser.add_argument('--hidden_layers', default=128, type=int) + parser.add_argument('--gpus', default=4, type=int) + parser.add_argument('--output_path', default='./results/', type=str) + parser.add_argument('--gt_path', default='./results/', type=str) + opts = parser.parse_args() + return opts + + +def main(opts): + adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float() + adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3) + + adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float()) + adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda() + + cihp_adj = graph.preprocess_adj(graph.cihp_graph) + adj3_ = Variable(torch.from_numpy(cihp_adj).float()) + adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda() + + p = OrderedDict() # Parameters to include in report + p['trainBatch'] = opts.batch # Training batch size + p['nAveGrad'] = 1 # Average the gradient of several iterations + p['lr'] = opts.lr # Learning rate + p['lrFtr'] = 1e-5 + p['lraspp'] = 1e-5 + p['lrpro'] = 1e-5 + p['lrdecoder'] = 1e-5 + p['lrother'] = 1e-5 + p['wd'] = 5e-4 # Weight decay + p['momentum'] = 0.9 # Momentum + p['epoch_size'] = 10 # How many epochs to change learning rate + p['num_workers'] = opts.numworker + backbone = 'xception' # Use xception or resnet as feature extractor, + + with open(opts.txt_file, 'r') as f: + img_list = f.readlines() + + max_id = 0 + save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) + exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] + runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*')) + for r in runs: + run_id = int(r.split('_')[-1]) + if run_id >= max_id: + max_id = run_id + 1 + # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 + + # Network definition + if backbone == 'xception': + net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=opts.classes, os=16, + hidden_layers=opts.hidden_layers, source_classes=7, + ) + elif backbone == 'resnet': + # net = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True) + raise NotImplementedError + else: + raise NotImplementedError + + if gpu_id >= 0: + net.cuda() + + # net load weights + if not opts.loadmodel =='': + x = torch.load(opts.loadmodel) + net.load_source_model(x) + print('load model:' ,opts.loadmodel) + else: 
+ print('no model load !!!!!!!!') + + ## multi scale + scale_list=[1,0.5,0.75,1.25,1.5,1.75] + testloader_list = [] + testloader_flip_list = [] + for pv in scale_list: + composed_transforms_ts = transforms.Compose([ + tr.Scale_(pv), + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + composed_transforms_ts_flip = transforms.Compose([ + tr.Scale_(pv), + tr.HorizontalFlip(), + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + voc_val = cihp.VOCSegmentation(split='test', transform=composed_transforms_ts) + voc_val_f = cihp.VOCSegmentation(split='test', transform=composed_transforms_ts_flip) + + testloader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=p['num_workers']) + testloader_flip = DataLoader(voc_val_f, batch_size=1, shuffle=False, num_workers=p['num_workers']) + + testloader_list.append(copy.deepcopy(testloader)) + testloader_flip_list.append(copy.deepcopy(testloader_flip)) + + print("Eval Network") + + if not os.path.exists(opts.output_path + 'cihp_output_vis/'): + os.makedirs(opts.output_path + 'cihp_output_vis/') + if not os.path.exists(opts.output_path + 'cihp_output/'): + os.makedirs(opts.output_path + 'cihp_output/') + + start_time = timeit.default_timer() + # One testing epoch + total_iou = 0.0 + net.eval() + for ii, large_sample_batched in enumerate(zip(*testloader_list, *testloader_flip_list)): + print(ii) + #1 0.5 0.75 1.25 1.5 1.75 ; flip: + sample1 = large_sample_batched[:6] + sample2 = large_sample_batched[6:] + for iii,sample_batched in enumerate(zip(sample1,sample2)): + inputs, labels = sample_batched[0]['image'], sample_batched[0]['label'] + inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label'] + inputs = torch.cat((inputs,inputs_f),dim=0) + if iii == 0: + _,_,h,w = inputs.size() + # assert inputs.size() == inputs_f.size() + + # Forward pass of the mini-batch + inputs, labels = Variable(inputs, requires_grad=False), Variable(labels) + + with torch.no_grad(): + if gpu_id >= 0: + inputs, labels = inputs.cuda(), labels.cuda() + # outputs = net.forward(inputs) + # pdb.set_trace() + outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda()) + outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2 + outputs = outputs.unsqueeze(0) + + if iii>0: + outputs = F.upsample(outputs,size=(h,w),mode='bilinear',align_corners=True) + outputs_final = outputs_final + outputs + else: + outputs_final = outputs.clone() + ################ plot pic + predictions = torch.max(outputs_final, 1)[1] + prob_predictions = torch.max(outputs_final,1)[0] + results = predictions.cpu().numpy() + prob_results = prob_predictions.cpu().numpy() + vis_res = decode_labels(results) + + parsing_im = Image.fromarray(vis_res[0]) + parsing_im.save(opts.output_path + 'cihp_output_vis/{}.png'.format(img_list[ii][:-1])) + cv2.imwrite(opts.output_path + 'cihp_output/{}.png'.format(img_list[ii][:-1]), results[0,:,:]) + # np.save('../../cihp_prob_output/{}.npy'.format(img_list[ii][:-1]), prob_results[0, :, :]) + # pred_list.append(predictions.cpu()) + # label_list.append(labels.squeeze(1).cpu()) + # loss = criterion(outputs, labels, batch_average=True) + # running_loss_ts += loss.item() + + # total_iou += utils.get_iou(predictions, labels) + end_time = timeit.default_timer() + print('time use for '+str(ii) + ' is :' + str(end_time - start_time)) + + # Eval + pred_path = opts.output_path + 'cihp_output/' + eval_(pred_path=pred_path, gt_path=opts.gt_path,classes=opts.classes, txt_file=opts.txt_file) + +if __name__ == '__main__': + opts = get_parser() + 
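+ # Usage sketch (all paths and the script name below are placeholders, not files added by
+ # this diff). For CIHP the transfer model has 20 output classes, so pass --classes 20
+ # together with the checkpoint, the test id list and the ground-truth directory, e.g.
+ #   python this_eval_script.py --loadmodel ./cihp_trained.pth --classes 20 \
+ #     --txt_file ./test_id.txt --output_path ./results/ --gt_path ./cihp_gt/
+ # main() then runs 6-scale + horizontal-flip test-time augmentation, writes the parsing
+ # maps under output_path, and calls eval_() to report mIoU.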
main(opts) \ No newline at end of file diff --git a/exp/test/test_from_disk.py b/exp/test/test_from_disk.py new file mode 100644 index 0000000..2b72604 --- /dev/null +++ b/exp/test/test_from_disk.py @@ -0,0 +1,65 @@ +import sys +sys.path.append('./') +# PyTorch includes +import torch +import numpy as np + +from utils import test_human +from PIL import Image + +# +import argparse + +def get_parser(): + '''argparse begin''' + parser = argparse.ArgumentParser() + LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) + + parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--batch', default=16, type=int) + parser.add_argument('--lr', default=1e-7, type=float) + parser.add_argument('--numworker',default=12,type=int) + parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices) + parser.add_argument('--step', default=30, type=int) + parser.add_argument('--txt_file',default=None,type=str) + parser.add_argument('--pred_path',default=None,type=str) + parser.add_argument('--gt_path',default=None,type=str) + parser.add_argument('--classes', default=7, type=int) + parser.add_argument('--testepoch', default=10, type=int) + opts = parser.parse_args() + return opts + +def eval_(pred_path, gt_path, classes, txt_file): + pred_path = pred_path + gt_path = gt_path + + with open(txt_file,) as f: + lines = f.readlines() + lines = [x.strip() for x in lines] + + output_list = [] + label_list = [] + for i,file in enumerate(lines): + print(i) + file_name = file + '.png' + try: + predict_pic = np.array(Image.open(pred_path+file_name)) + gt_pic = np.array(Image.open(gt_path+file_name)) + output_list.append(torch.from_numpy(predict_pic)) + label_list.append(torch.from_numpy(gt_pic)) + except: + print(file_name,flush=True) + raise RuntimeError('no predict/gt image.') + # gt_pic = np.array(Image.open(gt_path + file_name)) + # output_list.append(torch.from_numpy(gt_pic)) + # label_list.append(torch.from_numpy(gt_pic)) + + + miou = test_human.get_iou_from_list(output_list, label_list, n_cls=classes) + + print('Validation:') + print('MIoU: %f\n' % miou) + +if __name__ == '__main__': + opts = get_parser() + eval_(pred_path=opts.pred_path, gt_path=opts.gt_path, classes=opts.classes, txt_file=opts.txt_file) \ No newline at end of file diff --git a/exp/transfer/train_cihp_from_pascal.py b/exp/transfer/train_cihp_from_pascal.py new file mode 100644 index 0000000..30e8df4 --- /dev/null +++ b/exp/transfer/train_cihp_from_pascal.py @@ -0,0 +1,331 @@ +import socket +import timeit +from datetime import datetime +import os +import sys +import glob +import numpy as np +from collections import OrderedDict +sys.path.append('../../') +sys.path.append('../../networks/') +# PyTorch includes +import torch +from torch.autograd import Variable +import torch.optim as optim +from torchvision import transforms +from torch.utils.data import DataLoader +from torchvision.utils import make_grid + + +# Tensorboard include +from tensorboardX import SummaryWriter + +# Custom includes +from dataloaders import cihp +from utils import util,get_iou_from_list +from networks import deeplab_xception_transfer, graph +from dataloaders import custom_transforms as tr + +# +import argparse + +gpu_id = 0 + +nEpochs = 100 # Number of epochs for training +resume_epoch = 0 # Default is 0, change if want to resume + +def flip(x, dim): + indices = [slice(None)] * x.dim() + indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, + 
dtype=torch.long, device=x.device) + return x[tuple(indices)] + +def flip_cihp(tail_list): + ''' + + :param tail_list: tail_list size is 1 x n_class x h x w + :return: + ''' + # tail_list = tail_list[0] + tail_list_rev = [None] * 20 + for xx in range(14): + tail_list_rev[xx] = tail_list[xx].unsqueeze(0) + tail_list_rev[14] = tail_list[15].unsqueeze(0) + tail_list_rev[15] = tail_list[14].unsqueeze(0) + tail_list_rev[16] = tail_list[17].unsqueeze(0) + tail_list_rev[17] = tail_list[16].unsqueeze(0) + tail_list_rev[18] = tail_list[19].unsqueeze(0) + tail_list_rev[19] = tail_list[18].unsqueeze(0) + return torch.cat(tail_list_rev,dim=0) + +def get_parser(): + '''argparse begin''' + parser = argparse.ArgumentParser() + LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) + + parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--batch', default=16, type=int) + parser.add_argument('--lr', default=1e-7, type=float) + parser.add_argument('--numworker',default=12,type=int) + parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices) + parser.add_argument('--step', default=10, type=int) + parser.add_argument('--classes', default=7, type=int) + parser.add_argument('--testInterval', default=10, type=int) + parser.add_argument('--loadmodel',default='',type=str) + parser.add_argument('--pretrainedModel', default='', type=str) + parser.add_argument('--hidden_layers',default=128,type=int) + parser.add_argument('--gpus',default=4, type=int) + + opts = parser.parse_args() + return opts + +def get_graphs(opts): + adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float() + adj2 = adj2_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20).transpose(2, 3).cuda() + adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).transpose(2, 3) + + adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float()) + adj3 = adj1_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 7).cuda() + adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7) + + # adj2 = torch.from_numpy(graph.cihp2pascal_adj).float() + # adj2 = adj2.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20) + cihp_adj = graph.preprocess_adj(graph.cihp_graph) + adj3_ = Variable(torch.from_numpy(cihp_adj).float()) + adj1 = adj3_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 20).cuda() + adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20) + train_graph = [adj1, adj2, adj3] + test_graph = [adj1_test, adj2_test, adj3_test] + return train_graph, test_graph + + +def val_cihp(net_, testloader, testloader_flip, test_graph, epoch, writer, criterion, classes=20): + adj1_test, adj2_test, adj3_test = test_graph + num_img_ts = len(testloader) + net_.eval() + pred_list = [] + label_list = [] + running_loss_ts = 0.0 + miou = 0 + for ii, sample_batched in enumerate(zip(testloader, testloader_flip)): + + inputs, labels = sample_batched[0]['image'], sample_batched[0]['label'] + inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label'] + inputs = torch.cat((inputs, inputs_f), dim=0) + # Forward pass of the mini-batch + inputs, labels = Variable(inputs, requires_grad=False), Variable(labels) + if gpu_id >= 0: + inputs, labels = inputs.cuda(), labels.cuda() + + with torch.no_grad(): + outputs = net_.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda()) + # pdb.set_trace() + outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2 + outputs = 
outputs.unsqueeze(0) + predictions = torch.max(outputs, 1)[1] + pred_list.append(predictions.cpu()) + label_list.append(labels.squeeze(1).cpu()) + loss = criterion(outputs, labels, batch_average=True) + running_loss_ts += loss.item() + # total_iou += utils.get_iou(predictions, labels) + # Print stuff + if ii % num_img_ts == num_img_ts - 1: + # if ii == 10: + miou = get_iou_from_list(pred_list, label_list, n_cls=classes) + running_loss_ts = running_loss_ts / num_img_ts + + print('Validation:') + print('[Epoch: %d, numImages: %5d]' % (epoch, ii * 1 + inputs.data.shape[0])) + writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch) + writer.add_scalar('data/test_miour', miou, epoch) + print('Loss: %f' % running_loss_ts) + print('MIoU: %f\n' % miou) + + +def main(opts): + p = OrderedDict() # Parameters to include in report + p['trainBatch'] = opts.batch # Training batch size + testBatch = 1 # Testing batch size + useTest = True # See evolution of the test set when training + nTestInterval = opts.testInterval # Run on test set every nTestInterval epochs + snapshot = 1 # Store a model every snapshot epochs + p['nAveGrad'] = 1 # Average the gradient of several iterations + p['lr'] = opts.lr # Learning rate + p['lrFtr'] = 1e-5 + p['lraspp'] = 1e-5 + p['lrpro'] = 1e-5 + p['lrdecoder'] = 1e-5 + p['lrother'] = 1e-5 + p['wd'] = 5e-4 # Weight decay + p['momentum'] = 0.9 # Momentum + p['epoch_size'] = opts.step # How many epochs to change learning rate + p['num_workers'] = opts.numworker + model_path = opts.pretrainedModel + backbone = 'xception' # Use xception or resnet as feature extractor, + nEpochs = opts.epochs + + max_id = 0 + save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) + exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] + runs = glob.glob(os.path.join(save_dir_root, 'run_cihp', 'run_*')) + for r in runs: + run_id = int(r.split('_')[-1]) + if run_id >= max_id: + max_id = run_id + 1 + save_dir = os.path.join(save_dir_root, 'run_cihp', 'run_' + str(max_id)) + + # Network definition + if backbone == 'xception': + net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=20, os=16, + hidden_layers=opts.hidden_layers, source_classes=7, ) + elif backbone == 'resnet': + # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True) + raise NotImplementedError + else: + raise NotImplementedError + + modelName = 'deeplabv3plus-' + backbone + '-voc'+datetime.now().strftime('%b%d_%H-%M-%S') + criterion = util.cross_entropy2d + + if gpu_id >= 0: + # torch.cuda.set_device(device=gpu_id) + net_.cuda() + + # net load weights + if not model_path == '': + x = torch.load(model_path) + net_.load_state_dict_new(x) + print('load pretrainedModel.') + else: + print('no pretrainedModel.') + if not opts.loadmodel =='': + x = torch.load(opts.loadmodel) + net_.load_source_model(x) + print('load model:' ,opts.loadmodel) + else: + print('no model load !!!!!!!!') + + log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) + writer = SummaryWriter(log_dir=log_dir) + writer.add_text('load model',opts.loadmodel,1) + writer.add_text('setting',sys.argv[0],1) + + if opts.freezeBN: + net_.freeze_bn() + + # Use the following optimizer + optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd']) + + composed_transforms_tr = transforms.Compose([ + tr.RandomSized_new(512), + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + 
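+ # The next two pipelines are for validation only: no random resizing, and the second
+ # adds a horizontal flip so val_cihp() can average the logits of the original and the
+ # mirrored image (flip_cihp swaps the left/right CIHP label channels back first).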
composed_transforms_ts = transforms.Compose([ + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + composed_transforms_ts_flip = transforms.Compose([ + tr.HorizontalFlip(), + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + voc_train = cihp.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True) + voc_val = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts) + voc_val_flip = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts_flip) + + trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=p['num_workers'],drop_last=True) + testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers']) + testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers']) + + num_img_tr = len(trainloader) + num_img_ts = len(testloader) + running_loss_tr = 0.0 + running_loss_ts = 0.0 + aveGrad = 0 + global_step = 0 + print("Training Network") + + net = torch.nn.DataParallel(net_) + train_graph, test_graph = get_graphs(opts) + adj1, adj2, adj3 = train_graph + + + # Main Training and Testing Loop + for epoch in range(resume_epoch, nEpochs): + start_time = timeit.default_timer() + + if epoch % p['epoch_size'] == p['epoch_size'] - 1: + lr_ = util.lr_poly(p['lr'], epoch, nEpochs, 0.9) + optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd']) + writer.add_scalar('data/lr_', lr_, epoch) + print('(poly lr policy) learning rate: ', lr_) + + net.train() + for ii, sample_batched in enumerate(trainloader): + + inputs, labels = sample_batched['image'], sample_batched['label'] + # Forward-Backward of the mini-batch + inputs, labels = Variable(inputs, requires_grad=True), Variable(labels) + global_step += inputs.data.shape[0] + + if gpu_id >= 0: + inputs, labels = inputs.cuda(), labels.cuda() + + outputs = net.forward(inputs, adj1, adj3, adj2) + + loss = criterion(outputs, labels, batch_average=True) + running_loss_tr += loss.item() + + # Print stuff + if ii % num_img_tr == (num_img_tr - 1): + running_loss_tr = running_loss_tr / num_img_tr + writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch) + print('[Epoch: %d, numImages: %5d]' % (epoch, ii * p['trainBatch'] + inputs.data.shape[0])) + print('Loss: %f' % running_loss_tr) + running_loss_tr = 0 + stop_time = timeit.default_timer() + print("Execution time: " + str(stop_time - start_time) + "\n") + + # Backward the averaged gradient + loss /= p['nAveGrad'] + loss.backward() + aveGrad += 1 + + # Update the weights once in p['nAveGrad'] forward passes + if aveGrad % p['nAveGrad'] == 0: + writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch) + optimizer.step() + optimizer.zero_grad() + aveGrad = 0 + + # Show 10 * 3 images results each epoch + if ii % (num_img_tr // 10) == 0: + grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True) + writer.add_image('Image', grid_image, global_step) + grid_image = make_grid(util.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False, + range=(0, 255)) + writer.add_image('Predicted label', grid_image, global_step) + grid_image = make_grid(util.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255)) + writer.add_image('Groundtruth label', grid_image, global_step) + print('loss is ', loss.cpu().item(), flush=True) + + # Save the model + if (epoch % snapshot) == snapshot - 1: + 
torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')) + print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))) + + torch.cuda.empty_cache() + + # One testing epoch + if useTest and epoch % nTestInterval == (nTestInterval - 1): + val_cihp(net_,testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph, + epoch=epoch,writer=writer,criterion=criterion) + torch.cuda.empty_cache() + + + + +if __name__ == '__main__': + opts = get_parser() + main(opts) \ No newline at end of file diff --git a/exp/universal/pascal_atr_cihp_uni.py b/exp/universal/pascal_atr_cihp_uni.py new file mode 100644 index 0000000..95057b8 --- /dev/null +++ b/exp/universal/pascal_atr_cihp_uni.py @@ -0,0 +1,493 @@ +import socket +import timeit +from datetime import datetime +import os +import sys +import glob +import numpy as np +from collections import OrderedDict +sys.path.append('./') +sys.path.append('./networks/') +# PyTorch includes +import torch +from torch.autograd import Variable +import torch.optim as optim +from torchvision import transforms +from torch.utils.data import DataLoader +from torchvision.utils import make_grid +import random + +# Tensorboard include +from tensorboardX import SummaryWriter + +# Custom includes +from dataloaders import pascal, cihp_pascal_atr +from utils import get_iou_from_list +from utils import util as ut +from networks import deeplab_xception_universal, graph +from dataloaders import custom_transforms as tr +from utils import sampler as sam +# +import argparse + +''' +source is cihp +target is pascal +''' + +gpu_id = 1 +# print('Using GPU: {} '.format(gpu_id)) + +# nEpochs = 100 # Number of epochs for training +resume_epoch = 0 # Default is 0, change if want to resume + +def flip(x, dim): + indices = [slice(None)] * x.dim() + indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, + dtype=torch.long, device=x.device) + return x[tuple(indices)] + +def flip_cihp(tail_list): + ''' + + :param tail_list: tail_list size is 1 x n_class x h x w + :return: + ''' + # tail_list = tail_list[0] + tail_list_rev = [None] * 20 + for xx in range(14): + tail_list_rev[xx] = tail_list[xx].unsqueeze(0) + tail_list_rev[14] = tail_list[15].unsqueeze(0) + tail_list_rev[15] = tail_list[14].unsqueeze(0) + tail_list_rev[16] = tail_list[17].unsqueeze(0) + tail_list_rev[17] = tail_list[16].unsqueeze(0) + tail_list_rev[18] = tail_list[19].unsqueeze(0) + tail_list_rev[19] = tail_list[18].unsqueeze(0) + return torch.cat(tail_list_rev,dim=0) + +def get_parser(): + '''argparse begin''' + parser = argparse.ArgumentParser() + LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) + + parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--batch', default=16, type=int) + parser.add_argument('--lr', default=1e-7, type=float) + parser.add_argument('--numworker',default=12,type=int) + # parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices) + parser.add_argument('--step', default=10, type=int) + # parser.add_argument('--loadmodel',default=None,type=str) + parser.add_argument('--classes', default=7, type=int) + parser.add_argument('--testepoch', default=10, type=int) + parser.add_argument('--loadmodel',default='',type=str) + parser.add_argument('--pretrainedModel', default='', type=str) + parser.add_argument('--hidden_layers',default=128,type=int) + 
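+ # Note: get_graphs() expands every adjacency matrix to a leading batch dimension of
+ # --gpus, one copy per DataParallel replica, so it is expected to match the number of
+ # GPUs actually used for training.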
parser.add_argument('--gpus',default=4, type=int) + parser.add_argument('--testInterval', default=5, type=int) + opts = parser.parse_args() + return opts + +def get_graphs(opts): + '''source is pascal; target is cihp; middle is atr''' + # target 1 + cihp_adj = graph.preprocess_adj(graph.cihp_graph) + adj1_ = Variable(torch.from_numpy(cihp_adj).float()) + adj1 = adj1_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 20).cuda() + adj1_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20) + #source 2 + adj2_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float()) + adj2 = adj2_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 7).cuda() + adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7) + # s to target 3 + adj3_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float() + adj3 = adj3_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20).transpose(2,3).cuda() + adj3_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).transpose(2,3) + # middle 4 + atr_adj = graph.preprocess_adj(graph.atr_graph) + adj4_ = Variable(torch.from_numpy(atr_adj).float()) + adj4 = adj4_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 18, 18).cuda() + adj4_test = adj4_.unsqueeze(0).unsqueeze(0).expand(1, 1, 18, 18) + # source to middle 5 + adj5_ = torch.from_numpy(graph.pascal2atr_nlp_adj).float() + adj5 = adj5_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 18).cuda() + adj5_test = adj5_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 18) + # target to middle 6 + adj6_ = torch.from_numpy(graph.cihp2atr_nlp_adj).float() + adj6 = adj6_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 18).cuda() + adj6_test = adj6_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 18) + train_graph = [adj1, adj2, adj3, adj4, adj5, adj6] + test_graph = [adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test] + return train_graph, test_graph + + +def main(opts): + # Set parameters + p = OrderedDict() # Parameters to include in report + p['trainBatch'] = opts.batch # Training batch size + testBatch = 1 # Testing batch size + useTest = True # See evolution of the test set when training + nTestInterval = opts.testInterval # Run on test set every nTestInterval epochs + snapshot = 1 # Store a model every snapshot epochs + p['nAveGrad'] = 1 # Average the gradient of several iterations + p['lr'] = opts.lr # Learning rate + p['wd'] = 5e-4 # Weight decay + p['momentum'] = 0.9 # Momentum + p['epoch_size'] = opts.step # How many epochs to change learning rate + p['num_workers'] = opts.numworker + model_path = opts.pretrainedModel + backbone = 'xception' # Use xception or resnet as feature extractor + nEpochs = opts.epochs + + max_id = 0 + save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) + exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] + runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*')) + for r in runs: + run_id = int(r.split('_')[-1]) + if run_id >= max_id: + max_id = run_id + 1 + # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 + save_dir = os.path.join(save_dir_root, 'run', 'run_' + str(max_id)) + + # Network definition + if backbone == 'xception': + net_ = deeplab_xception_universal.deeplab_xception_end2end_3d(n_classes=20, os=16, + hidden_layers=opts.hidden_layers, + source_classes=7, + middle_classes=18, ) + elif backbone == 'resnet': + # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True) + raise NotImplementedError + else: + raise NotImplementedError + + modelName = 'deeplabv3plus-' + 
backbone + '-voc'+datetime.now().strftime('%b%d_%H-%M-%S') + criterion = ut.cross_entropy2d + + if gpu_id >= 0: + # torch.cuda.set_device(device=gpu_id) + net_.cuda() + + # net load weights + if not model_path == '': + x = torch.load(model_path) + net_.load_state_dict_new(x) + print('load pretrainedModel.') + else: + print('no pretrainedModel.') + + if not opts.loadmodel =='': + x = torch.load(opts.loadmodel) + net_.load_source_model(x) + print('load model:' ,opts.loadmodel) + else: + print('no trained model load !!!!!!!!') + + log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) + writer = SummaryWriter(log_dir=log_dir) + writer.add_text('load model',opts.loadmodel,1) + writer.add_text('setting',sys.argv[0],1) + + # Use the following optimizer + optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd']) + + composed_transforms_tr = transforms.Compose([ + tr.RandomSized_new(512), + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + composed_transforms_ts = transforms.Compose([ + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + composed_transforms_ts_flip = transforms.Compose([ + tr.HorizontalFlip(), + tr.Normalize_xception_tf(), + tr.ToTensor_()]) + + all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True) + voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts) + voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip) + + num_cihp,num_pascal,num_atr = all_train.get_class_num() + ss = sam.Sampler_uni(num_cihp,num_pascal,num_atr,opts.batch) + # balance datasets based pascal + ss_balanced = sam.Sampler_uni(num_cihp,num_pascal,num_atr,opts.batch, balance_id=1) + + trainloader = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=p['num_workers'], + sampler=ss, drop_last=True) + trainloader_balanced = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=p['num_workers'], + sampler=ss_balanced, drop_last=True) + testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers']) + testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers']) + + num_img_tr = len(trainloader) + num_img_balanced = len(trainloader_balanced) + num_img_ts = len(testloader) + running_loss_tr = 0.0 + running_loss_tr_atr = 0.0 + running_loss_ts = 0.0 + aveGrad = 0 + global_step = 0 + print("Training Network") + net = torch.nn.DataParallel(net_) + + id_list = torch.LongTensor(range(opts.batch)) + pascal_iter = int(num_img_tr//opts.batch) + + # Get graphs + train_graph, test_graph = get_graphs(opts) + adj1, adj2, adj3, adj4, adj5, adj6 = train_graph + adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph + + # Main Training and Testing Loop + for epoch in range(resume_epoch, int(1.5*nEpochs)): + start_time = timeit.default_timer() + + if epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch nEpochs: + lr_ = ut.lr_poly(p['lr'], epoch-nEpochs, int(0.5*nEpochs), 0.9) + optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd']) + print('(poly lr policy) learning rate: ', lr_) + writer.add_scalar('data/lr_', lr_, epoch) + + net_.train() + if epoch < nEpochs: + for ii, sample_batched in enumerate(trainloader): + inputs, labels = sample_batched['image'], sample_batched['label'] + dataset_lbl = sample_batched['pascal'][0].item() + # 
Forward-Backward of the mini-batch + inputs, labels = Variable(inputs, requires_grad=True), Variable(labels) + global_step += 1 + + if gpu_id >= 0: + inputs, labels = inputs.cuda(), labels.cuda() + + if dataset_lbl == 0: + # 0 is cihp -- target + _, outputs,_ = net.forward(None, input_target=inputs, input_middle=None, adj1_target=adj1, adj2_source=adj2, + adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2,3), adj4_middle=adj4,adj5_transfer_s2m=adj5.transpose(2, 3), + adj6_transfer_t2m=adj6.transpose(2, 3),adj5_transfer_m2s=adj5,adj6_transfer_m2t=adj6,) + elif dataset_lbl == 1: + # pascal is source + outputs, _, _ = net.forward(inputs, input_target=None, input_middle=None, adj1_target=adj1, + adj2_source=adj2, + adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3), + adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3), + adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5, + adj6_transfer_m2t=adj6, ) + else: + # atr + _, _, outputs = net.forward(None, input_target=None, input_middle=inputs, adj1_target=adj1, + adj2_source=adj2, + adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3), + adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3), + adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5, + adj6_transfer_m2t=adj6, ) + # print(sample_batched['pascal']) + # print(outputs.size(),) + # print(labels) + loss = criterion(outputs, labels, batch_average=True) + running_loss_tr += loss.item() + + # Print stuff + if ii % num_img_tr == (num_img_tr - 1): + running_loss_tr = running_loss_tr / num_img_tr + writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch) + print('[Epoch: %d, numImages: %5d]' % (epoch, epoch)) + print('Loss: %f' % running_loss_tr) + running_loss_tr = 0 + stop_time = timeit.default_timer() + print("Execution time: " + str(stop_time - start_time) + "\n") + + # Backward the averaged gradient + loss /= p['nAveGrad'] + loss.backward() + aveGrad += 1 + + # Update the weights once in p['nAveGrad'] forward passes + if aveGrad % p['nAveGrad'] == 0: + writer.add_scalar('data/total_loss_iter', loss.item(), global_step) + if dataset_lbl == 0: + writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step) + if dataset_lbl == 1: + writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step) + if dataset_lbl == 2: + writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step) + optimizer.step() + optimizer.zero_grad() + # optimizer_gcn.step() + # optimizer_gcn.zero_grad() + aveGrad = 0 + + # Show 10 * 3 images results each epoch + if ii % (num_img_tr // 10) == 0: + grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True) + writer.add_image('Image', grid_image, global_step) + grid_image = make_grid(ut.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False, + range=(0, 255)) + writer.add_image('Predicted label', grid_image, global_step) + grid_image = make_grid(ut.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255)) + writer.add_image('Groundtruth label', grid_image, global_step) + + print('loss is ',loss.cpu().item(),flush=True) + else: + # Balanced the number of datasets + for ii, sample_batched in enumerate(trainloader_balanced): + inputs, labels = sample_batched['image'], sample_batched['label'] + dataset_lbl = sample_batched['pascal'][0].item() + # Forward-Backward of the mini-batch + inputs, labels = Variable(inputs, requires_grad=True), Variable(labels) + global_step 
+= 1 + + if gpu_id >= 0: + inputs, labels = inputs.cuda(), labels.cuda() + + if dataset_lbl == 0: + # 0 is cihp -- target + _, outputs, _ = net.forward(None, input_target=inputs, input_middle=None, adj1_target=adj1, + adj2_source=adj2, + adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3), + adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3), + adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5, + adj6_transfer_m2t=adj6, ) + elif dataset_lbl == 1: + # pascal is source + outputs, _, _ = net.forward(inputs, input_target=None, input_middle=None, adj1_target=adj1, + adj2_source=adj2, + adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3), + adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3), + adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5, + adj6_transfer_m2t=adj6, ) + else: + # atr + _, _, outputs = net.forward(None, input_target=None, input_middle=inputs, adj1_target=adj1, + adj2_source=adj2, + adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3), + adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3), + adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5, + adj6_transfer_m2t=adj6, ) + # print(sample_batched['pascal']) + # print(outputs.size(),) + # print(labels) + loss = criterion(outputs, labels, batch_average=True) + running_loss_tr += loss.item() + + # Print stuff + if ii % num_img_balanced == (num_img_balanced - 1): + running_loss_tr = running_loss_tr / num_img_balanced + writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch) + print('[Epoch: %d, numImages: %5d]' % (epoch, epoch)) + print('Loss: %f' % running_loss_tr) + running_loss_tr = 0 + stop_time = timeit.default_timer() + print("Execution time: " + str(stop_time - start_time) + "\n") + + # Backward the averaged gradient + loss /= p['nAveGrad'] + loss.backward() + aveGrad += 1 + + # Update the weights once in p['nAveGrad'] forward passes + if aveGrad % p['nAveGrad'] == 0: + writer.add_scalar('data/total_loss_iter', loss.item(), global_step) + if dataset_lbl == 0: + writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step) + if dataset_lbl == 1: + writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step) + if dataset_lbl == 2: + writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step) + optimizer.step() + optimizer.zero_grad() + + aveGrad = 0 + + # Show 10 * 3 images results each epoch + if ii % (num_img_balanced // 10) == 0: + grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True) + writer.add_image('Image', grid_image, global_step) + grid_image = make_grid( + ut.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, + normalize=False, + range=(0, 255)) + writer.add_image('Predicted label', grid_image, global_step) + grid_image = make_grid( + ut.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, + normalize=False, range=(0, 255)) + writer.add_image('Groundtruth label', grid_image, global_step) + + print('loss is ', loss.cpu().item(), flush=True) + + # Save the model + if (epoch % snapshot) == snapshot - 1: + torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')) + print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))) + + # One testing epoch + if useTest and epoch % nTestInterval == (nTestInterval - 1): + val_pascal(net_=net_, testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph, 
+ criterion=criterion, epoch=epoch, writer=writer) + + +def val_pascal(net_, testloader, testloader_flip, test_graph, criterion, epoch, writer, classes=7): + running_loss_ts = 0.0 + miou = 0 + adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph + num_img_ts = len(testloader) + net_.eval() + pred_list = [] + label_list = [] + for ii, sample_batched in enumerate(zip(testloader, testloader_flip)): + # print(ii) + inputs, labels = sample_batched[0]['image'], sample_batched[0]['label'] + inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label'] + inputs = torch.cat((inputs, inputs_f), dim=0) + # Forward pass of the mini-batch + inputs, labels = Variable(inputs, requires_grad=False), Variable(labels) + + with torch.no_grad(): + if gpu_id >= 0: + inputs, labels = inputs.cuda(), labels.cuda() + outputs, _, _ = net_.forward(inputs, input_target=None, input_middle=None, + adj1_target=adj1_test.cuda(), + adj2_source=adj2_test.cuda(), + adj3_transfer_s2t=adj3_test.cuda(), + adj3_transfer_t2s=adj3_test.transpose(2, 3).cuda(), + adj4_middle=adj4_test.cuda(), + adj5_transfer_s2m=adj5_test.transpose(2, 3).cuda(), + adj6_transfer_t2m=adj6_test.transpose(2, 3).cuda(), + adj5_transfer_m2s=adj5_test.cuda(), + adj6_transfer_m2t=adj6_test.cuda(), ) + # pdb.set_trace() + outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2 + outputs = outputs.unsqueeze(0) + predictions = torch.max(outputs, 1)[1] + pred_list.append(predictions.cpu()) + label_list.append(labels.squeeze(1).cpu()) + loss = criterion(outputs, labels, batch_average=True) + running_loss_ts += loss.item() + + # total_iou += utils.get_iou(predictions, labels) + + # Print stuff + if ii % num_img_ts == num_img_ts - 1: + # if ii == 10: + miou = get_iou_from_list(pred_list, label_list, n_cls=classes) + running_loss_ts = running_loss_ts / num_img_ts + + print('Validation:') + print('[Epoch: %d, numImages: %5d]' % (epoch, ii * 1 + inputs.data.shape[0])) + writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch) + writer.add_scalar('data/test_miour', miou, epoch) + print('Loss: %f' % running_loss_ts) + print('MIoU: %f\n' % miou) + # return miou + + +if __name__ == '__main__': + opts = get_parser() + main(opts) \ No newline at end of file diff --git a/img/messi.jpg b/img/messi.jpg new file mode 100644 index 0000000..eeac0c1 Binary files /dev/null and b/img/messi.jpg differ diff --git a/img/messi_output.png b/img/messi_output.png new file mode 100644 index 0000000..353fc9c Binary files /dev/null and b/img/messi_output.png differ diff --git a/networks/__init__.py b/networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/deeplab_xception.py b/networks/deeplab_xception.py new file mode 100644 index 0000000..b068b90 --- /dev/null +++ b/networks/deeplab_xception.py @@ -0,0 +1,684 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from torch.nn.parameter import Parameter +from collections import OrderedDict + +class SeparableConv2d(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, + groups=inplanes, bias=bias) + self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +def fixed_padding(inputs, kernel_size, rate): + 
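+ # TensorFlow-style 'SAME' padding for a (possibly dilated) convolution: the effective
+ # kernel is kernel_size + (kernel_size - 1) * (rate - 1), and the total padding of
+ # (effective - 1) is split as evenly as possible between the two sides.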
kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) + pad_total = kernel_size_effective - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) + return padded_inputs + + +class SeparableConv2d_aspp(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0): + super(SeparableConv2d_aspp, self).__init__() + + self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, + groups=inplanes, bias=bias) + self.depthwise_bn = nn.BatchNorm2d(inplanes) + self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) + self.pointwise_bn = nn.BatchNorm2d(planes) + self.relu = nn.ReLU() + + def forward(self, x): + # x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0]) + x = self.depthwise(x) + x = self.depthwise_bn(x) + x = self.relu(x) + x = self.pointwise(x) + x = self.pointwise_bn(x) + x = self.relu(x) + return x + +class Decoder_module(nn.Module): + def __init__(self, inplanes, planes, rate=1): + super(Decoder_module, self).__init__() + self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,padding=1) + + def forward(self, x): + x = self.atrous_convolution(x) + return x + +class ASPP_module(nn.Module): + def __init__(self, inplanes, planes, rate): + super(ASPP_module, self).__init__() + if rate == 1: + raise RuntimeError() + else: + kernel_size = 3 + padding = rate + self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate, + padding=padding) + + def forward(self, x): + x = self.atrous_convolution(x) + return x + +class ASPP_module_rate0(nn.Module): + def __init__(self, inplanes, planes, rate=1): + super(ASPP_module_rate0, self).__init__() + if rate == 1: + kernel_size = 1 + padding = 0 + self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, + stride=1, padding=padding, dilation=rate, bias=False) + self.bn = nn.BatchNorm2d(planes, eps=1e-5, affine=True) + self.relu = nn.ReLU() + else: + raise RuntimeError() + + def forward(self, x): + x = self.atrous_convolution(x) + x = self.bn(x) + return self.relu(x) + +class SeparableConv2d_same(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0): + super(SeparableConv2d_same, self).__init__() + + self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, + groups=inplanes, bias=bias) + self.depthwise_bn = nn.BatchNorm2d(inplanes) + self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) + self.pointwise_bn = nn.BatchNorm2d(planes) + + def forward(self, x): + x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0]) + x = self.depthwise(x) + x = self.depthwise_bn(x) + x = self.pointwise(x) + x = self.pointwise_bn(x) + return x + +class Block(nn.Module): + def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False): + super(Block, self).__init__() + + if planes != inplanes or stride != 1: + self.skip = nn.Conv2d(inplanes, planes, 1, stride=2, bias=False) + if is_last: + self.skip = nn.Conv2d(inplanes, planes, 1, stride=1, bias=False) + self.skipbn = nn.BatchNorm2d(planes) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=True) + rep = [] + + filters = inplanes + if grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, 
dilation=dilation)) +# rep.append(nn.BatchNorm2d(planes)) + filters = planes + + for i in range(reps - 1): + rep.append(self.relu) + rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(planes)) + + if not start_with_relu: + rep = rep[1:] + + if stride != 1: + rep.append(self.relu) + rep.append(SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)) + + if is_last: + rep.append(self.relu) + rep.append(SeparableConv2d_same(planes, planes, 3, stride=1,dilation=dilation)) + + + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + # print(x.size(),skip.size()) + x += skip + + return x + +class Block2(nn.Module): + def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False): + super(Block2, self).__init__() + + if planes != inplanes or stride != 1: + self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) + self.skipbn = nn.BatchNorm2d(planes) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=True) + rep = [] + + filters = inplanes + if grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(planes)) + filters = planes + + for i in range(reps - 1): + rep.append(self.relu) + rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(planes)) + + if not start_with_relu: + rep = rep[1:] + + if stride != 1: + self.block2_lastconv = nn.Sequential(*[self.relu,SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)]) + + if is_last: + rep.append(SeparableConv2d_same(planes, planes, 3, stride=1)) + + + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + low_middle = x.clone() + x1 = x + x1 = self.block2_lastconv(x1) + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x1 += skip + + return x1,low_middle + +class Xception(nn.Module): + """ + Modified Alighed Xception + """ + def __init__(self, inplanes=3, os=16, pretrained=False): + super(Xception, self).__init__() + + if os == 16: + entry_block3_stride = 2 + middle_block_rate = 1 + exit_block_rates = (1, 2) + elif os == 8: + entry_block3_stride = 1 + middle_block_rate = 2 + exit_block_rates = (2, 4) + else: + raise NotImplementedError + + + # Entry flow + self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64) + + self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False) + self.block2 = Block2(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True) + self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True) + + # Middle flow + self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + 
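+ # blocks 5-19 below repeat the same 728-channel middle-flow unit as block4
+ # (16 middle-flow blocks in total)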
self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + + # Exit flow + self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_rates[0], + start_with_relu=True, grow_first=False, is_last=True) + + self.conv3 = SeparableConv2d_aspp(1024, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1]) + # self.bn3 = nn.BatchNorm2d(1536) + + self.conv4 = SeparableConv2d_aspp(1536, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1]) + # self.bn4 = nn.BatchNorm2d(1536) + + self.conv5 = SeparableConv2d_aspp(1536, 2048, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1]) + # self.bn5 = nn.BatchNorm2d(2048) + + # Init weights + # self.__init_weight() + + # Load pretrained model + if pretrained: + self.__load_xception_pretrained() + + def forward(self, x): + # Entry flow + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + # print('conv1 ',x.size()) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.block1(x) + # print('block1',x.size()) + # low_level_feat = x + x,low_level_feat = self.block2(x) + # print('block2',x.size()) + x = self.block3(x) + # print('xception block3 ',x.size()) + + # Middle flow + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + x = self.block13(x) + x = self.block14(x) + x = self.block15(x) + x = self.block16(x) + x = self.block17(x) + x = self.block18(x) + x = self.block19(x) + + # Exit flow + x = self.block20(x) + x = self.conv3(x) + # x = self.bn3(x) + x = self.relu(x) + + x = self.conv4(x) + # x = self.bn4(x) + x = self.relu(x) + + x = self.conv5(x) + # x = self.bn5(x) + x = self.relu(x) + + return x, low_level_feat + + def __init_weight(self): + for m in 
self.modules(): + if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, math.sqrt(2. / n)) + torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def __load_xception_pretrained(self): + pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') + model_dict = {} + state_dict = self.state_dict() + + for k, v in pretrain_dict.items(): + if k in state_dict: + if 'pointwise' in k: + v = v.unsqueeze(-1).unsqueeze(-1) + if k.startswith('block12'): + model_dict[k.replace('block12', 'block20')] = v + elif k.startswith('block11'): + model_dict[k.replace('block11', 'block12')] = v + model_dict[k.replace('block11', 'block13')] = v + model_dict[k.replace('block11', 'block14')] = v + model_dict[k.replace('block11', 'block15')] = v + model_dict[k.replace('block11', 'block16')] = v + model_dict[k.replace('block11', 'block17')] = v + model_dict[k.replace('block11', 'block18')] = v + model_dict[k.replace('block11', 'block19')] = v + elif k.startswith('conv3'): + model_dict[k] = v + elif k.startswith('bn3'): + model_dict[k] = v + model_dict[k.replace('bn3', 'bn4')] = v + elif k.startswith('conv4'): + model_dict[k.replace('conv4', 'conv5')] = v + elif k.startswith('bn4'): + model_dict[k.replace('bn4', 'bn5')] = v + else: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + +class DeepLabv3_plus(nn.Module): + def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True): + if _print: + print("Constructing DeepLabv3+ model...") + print("Number of classes: {}".format(n_classes)) + print("Output stride: {}".format(os)) + print("Number of Input Channels: {}".format(nInputChannels)) + super(DeepLabv3_plus, self).__init__() + + # Atrous Conv + self.xception_features = Xception(nInputChannels, os, pretrained) + + # ASPP + if os == 16: + rates = [1, 6, 12, 18] + elif os == 8: + rates = [1, 12, 24, 36] + raise NotImplementedError + else: + raise NotImplementedError + + self.aspp1 = ASPP_module_rate0(2048, 256, rate=rates[0]) + self.aspp2 = ASPP_module(2048, 256, rate=rates[1]) + self.aspp3 = ASPP_module(2048, 256, rate=rates[2]) + self.aspp4 = ASPP_module(2048, 256, rate=rates[3]) + + self.relu = nn.ReLU() + + self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(2048, 256, 1, stride=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU() + ) + + self.concat_projection_conv1 = nn.Conv2d(1280, 256, 1, bias=False) + self.concat_projection_bn1 = nn.BatchNorm2d(256) + + # adopt [1x1, 48] for channel reduction. 
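+ # (the 256-channel low-level feature from Xception block2 is reduced to 48 channels and
+ # concatenated with the upsampled ASPP output, giving the 304-channel decoder input)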
+ self.feature_projection_conv1 = nn.Conv2d(256, 48, 1, bias=False) + self.feature_projection_bn1 = nn.BatchNorm2d(48) + + self.decoder = nn.Sequential(Decoder_module(304, 256), + Decoder_module(256, 256) + ) + self.semantic = nn.Conv2d(256, n_classes, kernel_size=1, stride=1) + + def forward(self, input): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + + def freeze_bn(self): + for m in self.xception_features.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def freeze_totally_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def freeze_aspp_bn(self): + for m in self.aspp1.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + for m in self.aspp2.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + for m in self.aspp3.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + for m in self.aspp4.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def learnable_parameters(self): + layer_features_BN = [] + layer_features = [] + layer_aspp = [] + layer_projection =[] + layer_decoder = [] + layer_other = [] + model_para = list(self.named_parameters()) + for name,para in model_para: + if 'xception' in name: + if 'bn' in name or 'downsample.1.weight' in name or 'downsample.1.bias' in name: + layer_features_BN.append(para) + else: + layer_features.append(para) + # print (name) + elif 'aspp' in name: + layer_aspp.append(para) + elif 'projection' in name: + layer_projection.append(para) + elif 'decode' in name: + layer_decoder.append(para) + elif 'global' not in name: + layer_other.append(para) + return layer_features_BN,layer_features,layer_aspp,layer_projection,layer_decoder,layer_other + + def get_backbone_para(self): + layer_features = [] + other_features = [] + model_para = list(self.named_parameters()) + for name, para in model_para: + if 'xception' in name: + layer_features.append(para) + else: + other_features.append(para) + + return layer_features, other_features + + def train_fixbn(self, mode=True, freeze_bn=True, freeze_bn_affine=False): + r"""Sets the module in training mode. + + This has any effect only on certain modules. See documentations of + particular modules for details of their behaviors in training/evaluation + mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, + etc. 
+ + Returns: + Module: self + """ + super(DeepLabv3_plus, self).train(mode) + if freeze_bn: + print("Freezing Mean/Var of BatchNorm2D.") + if freeze_bn_affine: + print("Freezing Weight/Bias of BatchNorm2D.") + if freeze_bn: + for m in self.xception_features.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + if freeze_bn_affine: + m.weight.requires_grad = False + m.bias.requires_grad = False + # for m in self.aspp1.modules(): + # if isinstance(m, nn.BatchNorm2d): + # m.eval() + # if freeze_bn_affine: + # m.weight.requires_grad = False + # m.bias.requires_grad = False + # for m in self.aspp2.modules(): + # if isinstance(m, nn.BatchNorm2d): + # m.eval() + # if freeze_bn_affine: + # m.weight.requires_grad = False + # m.bias.requires_grad = False + # for m in self.aspp3.modules(): + # if isinstance(m, nn.BatchNorm2d): + # m.eval() + # if freeze_bn_affine: + # m.weight.requires_grad = False + # m.bias.requires_grad = False + # for m in self.aspp4.modules(): + # if isinstance(m, nn.BatchNorm2d): + # m.eval() + # if freeze_bn_affine: + # m.weight.requires_grad = False + # m.bias.requires_grad = False + # for m in self.global_avg_pool.modules(): + # if isinstance(m, nn.BatchNorm2d): + # m.eval() + # if freeze_bn_affine: + # m.weight.requires_grad = False + # m.bias.requires_grad = False + # for m in self.concat_projection_bn1.modules(): + # if isinstance(m, nn.BatchNorm2d): + # m.eval() + # if freeze_bn_affine: + # m.weight.requires_grad = False + # m.bias.requires_grad = False + # for m in self.feature_projection_bn1.modules(): + # if isinstance(m, nn.BatchNorm2d): + # m.eval() + # if freeze_bn_affine: + # m.weight.requires_grad = False + # m.bias.requires_grad = False + + def __init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + # torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def load_state_dict_new(self, state_dict): + own_state = self.state_dict() + #for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.','') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print ('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + # raise + # print 'copying %s' %name + # if isinstance(param, own_state): + # backwards compatibility for serialized parameters + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + +def get_1x_lr_params(model): + """ + This generator returns all the parameters of the net except for + the last classification layer. 
Note that for each batchnorm layer, + requires_grad is set to False in deeplab_resnet.py, therefore this function does not return + any batchnorm parameter + """ + b = [model.xception_features] + for i in range(len(b)): + for k in b[i].parameters(): + if k.requires_grad: + yield k + + +def get_10x_lr_params(model): + """ + This generator returns all the parameters for the last layer of the net, + which does the classification of pixel into classes + """ + b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv] + for j in range(len(b)): + for k in b[j].parameters(): + if k.requires_grad: + yield k + + +if __name__ == "__main__": + model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True) + model.eval() + image = torch.randn(1, 3, 512, 512)*255 + with torch.no_grad(): + output = model.forward(image) + print(output.size()) + # print(output) + + + + + + diff --git a/networks/deeplab_xception_synBN.py b/networks/deeplab_xception_synBN.py new file mode 100644 index 0000000..d68312b --- /dev/null +++ b/networks/deeplab_xception_synBN.py @@ -0,0 +1,596 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from torch.nn.parameter import Parameter +from collections import OrderedDict +from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback, SynchronizedBatchNorm2d + + +def fixed_padding(inputs, kernel_size, rate): + kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) + pad_total = kernel_size_effective - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) + return padded_inputs + +class SeparableConv2d_aspp(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0): + super(SeparableConv2d_aspp, self).__init__() + + self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, + groups=inplanes, bias=bias) + self.depthwise_bn = SynchronizedBatchNorm2d(inplanes) + self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) + self.pointwise_bn = SynchronizedBatchNorm2d(planes) + self.relu = nn.ReLU() + + def forward(self, x): + # x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0]) + x = self.depthwise(x) + x = self.depthwise_bn(x) + x = self.relu(x) + x = self.pointwise(x) + x = self.pointwise_bn(x) + x = self.relu(x) + return x + +class Decoder_module(nn.Module): + def __init__(self, inplanes, planes, rate=1): + super(Decoder_module, self).__init__() + self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,padding=1) + + def forward(self, x): + x = self.atrous_convolution(x) + return x + +class ASPP_module(nn.Module): + def __init__(self, inplanes, planes, rate): + super(ASPP_module, self).__init__() + if rate == 1: + raise RuntimeError() + else: + kernel_size = 3 + padding = rate + self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate, + padding=padding) + + def forward(self, x): + x = self.atrous_convolution(x) + return x + + +class ASPP_module_rate0(nn.Module): + def __init__(self, inplanes, planes, rate=1): + super(ASPP_module_rate0, self).__init__() + if rate == 1: + kernel_size = 1 + padding = 0 + self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, + stride=1, padding=padding, dilation=rate, bias=False) + 
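+ # rate == 1 branch of ASPP: a plain 1x1 convolution (no dilation), followed by synchronized batch norm and ReLU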
self.bn = SynchronizedBatchNorm2d(planes, eps=1e-5, affine=True) + self.relu = nn.ReLU() + else: + raise RuntimeError() + + def forward(self, x): + x = self.atrous_convolution(x) + x = self.bn(x) + return self.relu(x) + + +class SeparableConv2d_same(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0): + super(SeparableConv2d_same, self).__init__() + + self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, + groups=inplanes, bias=bias) + self.depthwise_bn = SynchronizedBatchNorm2d(inplanes) + self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) + self.pointwise_bn = SynchronizedBatchNorm2d(planes) + + def forward(self, x): + x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0]) + x = self.depthwise(x) + x = self.depthwise_bn(x) + x = self.pointwise(x) + x = self.pointwise_bn(x) + return x + + +class Block(nn.Module): + def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False): + super(Block, self).__init__() + + if planes != inplanes or stride != 1: + self.skip = nn.Conv2d(inplanes, planes, 1, stride=2, bias=False) + if is_last: + self.skip = nn.Conv2d(inplanes, planes, 1, stride=1, bias=False) + self.skipbn = SynchronizedBatchNorm2d(planes) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=True) + rep = [] + + filters = inplanes + if grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(planes)) + filters = planes + + for i in range(reps - 1): + rep.append(self.relu) + rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(planes)) + + if not start_with_relu: + rep = rep[1:] + + if stride != 1: + rep.append(self.relu) + rep.append(SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)) + + if is_last: + rep.append(self.relu) + rep.append(SeparableConv2d_same(planes, planes, 3, stride=1,dilation=dilation)) + + + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + # print(x.size(),skip.size()) + x += skip + + return x + +class Block2(nn.Module): + def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False): + super(Block2, self).__init__() + + if planes != inplanes or stride != 1: + self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) + self.skipbn = SynchronizedBatchNorm2d(planes) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=True) + rep = [] + + filters = inplanes + if grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(planes)) + filters = planes + + for i in range(reps - 1): + rep.append(self.relu) + rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) +# rep.append(nn.BatchNorm2d(planes)) + + if not start_with_relu: + rep = rep[1:] + + if 
stride != 1: + self.block2_lastconv = nn.Sequential(*[self.relu,SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)]) + + if is_last: + rep.append(SeparableConv2d_same(planes, planes, 3, stride=1)) + + + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + low_middle = x.clone() + x1 = x + x1 = self.block2_lastconv(x1) + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x1 += skip + + return x1,low_middle + +class Xception(nn.Module): + """ + Modified Alighed Xception + """ + def __init__(self, inplanes=3, os=16, pretrained=False): + super(Xception, self).__init__() + + if os == 16: + entry_block3_stride = 2 + middle_block_rate = 1 + exit_block_rates = (1, 2) + elif os == 8: + entry_block3_stride = 1 + middle_block_rate = 2 + exit_block_rates = (2, 4) + else: + raise NotImplementedError + + + # Entry flow + self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False) + self.bn1 = SynchronizedBatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) + self.bn2 = SynchronizedBatchNorm2d(64) + + self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False) + self.block2 = Block2(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True) + self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True) + + # Middle flow + self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True) + + # Exit flow + self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_rates[0], + start_with_relu=True, grow_first=False, is_last=True) + + self.conv3 = SeparableConv2d_aspp(1024, 1536, 3, 
stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1]) + # self.bn3 = nn.BatchNorm2d(1536) + + self.conv4 = SeparableConv2d_aspp(1536, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1]) + # self.bn4 = nn.BatchNorm2d(1536) + + self.conv5 = SeparableConv2d_aspp(1536, 2048, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1]) + # self.bn5 = nn.BatchNorm2d(2048) + + # Init weights + # self.__init_weight() + + # Load pretrained model + if pretrained: + self.__load_xception_pretrained() + + def forward(self, x): + # Entry flow + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + # print('conv1 ',x.size()) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.block1(x) + # print('block1',x.size()) + # low_level_feat = x + x,low_level_feat = self.block2(x) + # print('block2',x.size()) + x = self.block3(x) + # print('xception block3 ',x.size()) + + # Middle flow + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + x = self.block13(x) + x = self.block14(x) + x = self.block15(x) + x = self.block16(x) + x = self.block17(x) + x = self.block18(x) + x = self.block19(x) + + # Exit flow + x = self.block20(x) + x = self.conv3(x) + # x = self.bn3(x) + x = self.relu(x) + + x = self.conv4(x) + # x = self.bn4(x) + x = self.relu(x) + + x = self.conv5(x) + # x = self.bn5(x) + x = self.relu(x) + + return x, low_level_feat + + def __init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, math.sqrt(2. / n)) + torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def __load_xception_pretrained(self): + pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') + model_dict = {} + state_dict = self.state_dict() + + for k, v in pretrain_dict.items(): + if k in state_dict: + if 'pointwise' in k: + v = v.unsqueeze(-1).unsqueeze(-1) + if k.startswith('block12'): + model_dict[k.replace('block12', 'block20')] = v + elif k.startswith('block11'): + model_dict[k.replace('block11', 'block12')] = v + model_dict[k.replace('block11', 'block13')] = v + model_dict[k.replace('block11', 'block14')] = v + model_dict[k.replace('block11', 'block15')] = v + model_dict[k.replace('block11', 'block16')] = v + model_dict[k.replace('block11', 'block17')] = v + model_dict[k.replace('block11', 'block18')] = v + model_dict[k.replace('block11', 'block19')] = v + elif k.startswith('conv3'): + model_dict[k] = v + elif k.startswith('bn3'): + model_dict[k] = v + model_dict[k.replace('bn3', 'bn4')] = v + elif k.startswith('conv4'): + model_dict[k.replace('conv4', 'conv5')] = v + elif k.startswith('bn4'): + model_dict[k.replace('bn4', 'bn5')] = v + else: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + +class DeepLabv3_plus(nn.Module): + def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True): + if _print: + print("Constructing DeepLabv3+ model...") + print("Number of classes: {}".format(n_classes)) + print("Output stride: {}".format(os)) + print("Number of Input Channels: {}".format(nInputChannels)) + super(DeepLabv3_plus, self).__init__() + + # Atrous Conv + self.xception_features = Xception(nInputChannels, os, pretrained) + + # ASPP + 
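+ # Atrous rates follow DeepLab v3+: [1, 6, 12, 18] at output stride 16, doubled to [1, 12, 24, 36] at output stride 8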
if os == 16: + rates = [1, 6, 12, 18] + elif os == 8: + rates = [1, 12, 24, 36] + else: + raise NotImplementedError + + self.aspp1 = ASPP_module_rate0(2048, 256, rate=rates[0]) + self.aspp2 = ASPP_module(2048, 256, rate=rates[1]) + self.aspp3 = ASPP_module(2048, 256, rate=rates[2]) + self.aspp4 = ASPP_module(2048, 256, rate=rates[3]) + + self.relu = nn.ReLU() + + self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(2048, 256, 1, stride=1, bias=False), + SynchronizedBatchNorm2d(256), + nn.ReLU() + ) + + self.concat_projection_conv1 = nn.Conv2d(1280, 256, 1, bias=False) + self.concat_projection_bn1 = SynchronizedBatchNorm2d(256) + + # adopt [1x1, 48] for channel reduction. + self.feature_projection_conv1 = nn.Conv2d(256, 48, 1, bias=False) + self.feature_projection_bn1 = SynchronizedBatchNorm2d(48) + + self.decoder = nn.Sequential(Decoder_module(304, 256), + Decoder_module(256, 256) + ) + self.semantic = nn.Conv2d(256, n_classes, kernel_size=1, stride=1) + + def forward(self, input): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + + def freeze_bn(self): + for m in self.xception_features.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m,SynchronizedBatchNorm2d): + m.eval() + + def freeze_aspp_bn(self): + for m in self.aspp1.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + for m in self.aspp2.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + for m in self.aspp3.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + for m in self.aspp4.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def learnable_parameters(self): + layer_features_BN = [] + layer_features = [] + layer_aspp = [] + layer_projection =[] + layer_decoder = [] + layer_other = [] + model_para = list(self.named_parameters()) + for name,para in model_para: + if 'xception' in name: + if 'bn' in name or 'downsample.1.weight' in name or 'downsample.1.bias' in name: + layer_features_BN.append(para) + else: + layer_features.append(para) + # print (name) + elif 'aspp' in name: + layer_aspp.append(para) + elif 'projection' in name: + layer_projection.append(para) + elif 'decode' in name: + layer_decoder.append(para) + else: + layer_other.append(para) + return layer_features_BN,layer_features,layer_aspp,layer_projection,layer_decoder,layer_other + + + def __init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + # torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def load_state_dict_new(self, state_dict): + own_state = self.state_dict() + #for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.','') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print ('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + # raise + # print 'copying %s' %name + # if isinstance(param, own_state): + # backwards compatibility for serialized parameters + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + + + +def get_1x_lr_params(model): + """ + This generator returns all the parameters of the net except for + the last classification layer. Note that for each batchnorm layer, + requires_grad is set to False in deeplab_resnet.py, therefore this function does not return + any batchnorm parameter + """ + b = [model.xception_features] + for i in range(len(b)): + for k in b[i].parameters(): + if k.requires_grad: + yield k + + +def get_10x_lr_params(model): + """ + This generator returns all the parameters for the last layer of the net, + which does the classification of pixel into classes + """ + b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv] + for j in range(len(b)): + for k in b[j].parameters(): + if k.requires_grad: + yield k + + +if __name__ == "__main__": + model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True) + model.eval() + # ckt = torch.load('C:\\Users\gaoyi\code_python\deeplab_v3plus.pth') + # model.load_state_dict_new(ckt) + + + image = torch.randn(1, 3, 512, 512)*255 + with torch.no_grad(): + output = model.forward(image) + print(output.size()) + # print(output) + + + + + + diff --git a/networks/deeplab_xception_transfer.py b/networks/deeplab_xception_transfer.py new file mode 100644 index 0000000..f86e424 --- /dev/null +++ b/networks/deeplab_xception_transfer.py @@ -0,0 +1,1003 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from torch.nn.parameter import Parameter +import numpy as np +from collections import OrderedDict +from torch.nn import Parameter +from networks import deeplab_xception,gcn, deeplab_xception_synBN +import pdb + +####################### +# base model +####################### + +class deeplab_xception_transfer_basemodel(deeplab_xception.DeepLabv3_plus): + def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256): + super(deeplab_xception_transfer_basemodel, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os,) + ### source graph + # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, 
hidden_layers=hidden_layers, + # nodes=n_classes) + # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # + # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels, + # hidden_layers=hidden_layers, nodes=n_classes + # ) + # self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + # nn.ReLU(True)]) + + ### target graph + self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + nodes=n_classes) + self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.target_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels, + hidden_layers=hidden_layers, nodes=n_classes + ) + self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + + def load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name and 'transpose_graph' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + def get_target_parameter(self): + l = [] + other = [] + for name, k in self.named_parameters(): + if 'target' in name or 'semantic' in name: + l.append(k) + else: + other.append(k) + return l, other + + def get_semantic_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'semantic' in name: + l.append(k) + return l + + def get_source_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'source' in name: + l.append(k) + return l + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = 
self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + + + # target graph + # print('x size',x.size(),adj1.size()) + graph = self.target_featuremap_2_graph(x) + + # graph combine + # print(graph.size(),source_2_target_graph.size()) + # graph = self.fc_graph.forward(graph,relu=True) + # print(graph.size()) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + # print(graph.size(),x.size()) + # graph = self.gcn_encode.forward(graph,relu=True) + # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True) + # graph = self.gcn_decode.forward(graph,relu=True) + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + +class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus): + def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256): + super(deeplab_xception_transfer_basemodel_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os,) + ### source graph + + ### target graph + self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + nodes=n_classes) + self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, output_channels=out_channels, + hidden_layers=hidden_layers, nodes=n_classes + ) + self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + + def load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name and 'transpose_graph' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + 
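+ # shape mismatch (e.g. a head trained for a different number of classes): keep this model's parameter and skip the checkpoint entry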
continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + def get_target_parameter(self): + l = [] + other = [] + for name, k in self.named_parameters(): + if 'target' in name or 'semantic' in name: + l.append(k) + else: + other.append(k) + return l, other + + def get_semantic_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'semantic' in name: + l.append(k) + return l + + def get_source_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'source' in name: + l.append(k) + return l + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + + + # target graph + # print('x size',x.size(),adj1.size()) + graph = self.target_featuremap_2_graph(x) + + # graph combine + # print(graph.size(),source_2_target_graph.size()) + # graph = self.fc_graph.forward(graph,relu=True) + # print(graph.size()) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + # print(graph.size(),x.size()) + # graph = self.gcn_encode.forward(graph,relu=True) + # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True) + # graph = self.gcn_decode.forward(graph,relu=True) + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + +class deeplab_xception_transfer_basemodel_synBN(deeplab_xception_synBN.DeepLabv3_plus): + def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256): + super(deeplab_xception_transfer_basemodel_synBN, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os,) + ### source graph + # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + # nodes=n_classes) + # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # + # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels, + # hidden_layers=hidden_layers, nodes=n_classes + # ) + # self.source_skip_conv = 
nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + # nn.ReLU(True)]) + + ### target graph + self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + nodes=n_classes) + self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.target_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels, + hidden_layers=hidden_layers, nodes=n_classes + ) + self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + + def load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + + if 'graph' in name and 'source' not in name and 'target' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + def get_target_parameter(self): + l = [] + other = [] + for name, k in self.named_parameters(): + if 'target' in name or 'semantic' in name: + l.append(k) + else: + other.append(k) + return l, other + + def get_semantic_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'semantic' in name: + l.append(k) + return l + + def get_source_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'source' in name: + l.append(k) + return l + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + + + # target graph + # print('x size',x.size(),adj1.size()) + graph = 
self.target_featuremap_2_graph(x) + + # graph combine + # print(graph.size(),source_2_target_graph.size()) + # graph = self.fc_graph.forward(graph,relu=True) + # print(graph.size()) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + # print(graph.size(),x.size()) + # graph = self.gcn_encode.forward(graph,relu=True) + # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True) + # graph = self.gcn_decode.forward(graph,relu=True) + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + +class deeplab_xception_transfer_basemodel_synBN_savememory(deeplab_xception_synBN.DeepLabv3_plus): + def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256): + super(deeplab_xception_transfer_basemodel_synBN_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os, ) + ### source graph + # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + # nodes=n_classes) + # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + # + # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels, + # hidden_layers=hidden_layers, nodes=n_classes + # ) + # self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + # nn.ReLU(True)]) + + ### target graph + self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + nodes=n_classes) + self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, output_channels=out_channels, + hidden_layers=hidden_layers, nodes=n_classes + ) + self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.BatchNorm2d(input_channels), + nn.ReLU(True)]) + + def load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + + if 'graph' in name and 'source' not in name and 'target' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose 
dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + def get_target_parameter(self): + l = [] + other = [] + for name, k in self.named_parameters(): + if 'target' in name or 'semantic' in name: + l.append(k) + else: + other.append(k) + return l, other + + def get_semantic_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'semantic' in name: + l.append(k) + return l + + def get_source_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'source' in name: + l.append(k) + return l + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + + + # target graph + # print('x size',x.size(),adj1.size()) + graph = self.target_featuremap_2_graph(x) + + # graph combine + # print(graph.size(),source_2_target_graph.size()) + # graph = self.fc_graph.forward(graph,relu=True) + # print(graph.size()) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + # print(graph.size(),x.size()) + # graph = self.gcn_encode.forward(graph,relu=True) + # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True) + # graph = self.gcn_decode.forward(graph,relu=True) + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + +####################### +# transfer model +####################### + +class deeplab_xception_transfer_projection(deeplab_xception_transfer_basemodel): + def __init__(self, nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256, + transfer_graph=None, source_classes=20): + super(deeplab_xception_transfer_projection, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os, input_channels=input_channels, + hidden_layers=hidden_layers, out_channels=out_channels, ) + self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + nodes=source_classes) + self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv3 = 
gcn.GraphConvolution(hidden_layers, hidden_layers) + self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph, + begin_nodes=source_classes,end_nodes=n_classes) + self.fc_graph = gcn.GraphConvolution(hidden_layers*3, hidden_layers) + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + # source graph + source_graph = self.source_featuremap_2_graph(x) + source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True) + source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True) + source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True) + + source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True) + source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True) + source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True) + + # target graph + # print('x size',x.size(),adj1.size()) + graph = self.target_featuremap_2_graph(x) + + source_2_target_graph1 = self.similarity_trans(source_graph1, graph) + # graph combine 1 + # print(graph.size()) + # print(source_2_target_graph1.size()) + # print(source_2_target_graph1_v5.size()) + graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1) + graph = self.fc_graph.forward(graph,relu=True) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + + source_2_target_graph2 = self.similarity_trans(source_graph2, graph) + # graph combine 2 + graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1) + graph = self.fc_graph.forward(graph, relu=True) + + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + + source_2_target_graph3 = self.similarity_trans(source_graph3, graph) + # graph combine 3 + graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1) + graph = self.fc_graph.forward(graph, relu=True) + + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + + # print(graph.size(),x.size()) + + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + + def similarity_trans(self,source,target): + sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2)) + sim = F.softmax(sim, dim=-1) + return torch.matmul(sim, source) + + def 
load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + + if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + +class deeplab_xception_transfer_projection_savemem(deeplab_xception_transfer_basemodel_savememory): + def __init__(self, nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256, + transfer_graph=None, source_classes=20): + super(deeplab_xception_transfer_projection_savemem, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os, input_channels=input_channels, + hidden_layers=hidden_layers, out_channels=out_channels, ) + self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + nodes=source_classes) + self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph, + begin_nodes=source_classes,end_nodes=n_classes) + self.fc_graph = gcn.GraphConvolution(hidden_layers*3, hidden_layers) + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + # source graph + source_graph = self.source_featuremap_2_graph(x) + source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True) + source_graph2 = 
self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True) + source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True) + + source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True) + source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True) + source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True) + + # target graph + # print('x size',x.size(),adj1.size()) + graph = self.target_featuremap_2_graph(x) + + source_2_target_graph1 = self.similarity_trans(source_graph1, graph) + # graph combine 1 + graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1) + graph = self.fc_graph.forward(graph,relu=True) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + + source_2_target_graph2 = self.similarity_trans(source_graph2, graph) + # graph combine 2 + graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1) + graph = self.fc_graph.forward(graph, relu=True) + + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + + source_2_target_graph3 = self.similarity_trans(source_graph3, graph) + # graph combine 3 + graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1) + graph = self.fc_graph.forward(graph, relu=True) + + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + + # print(graph.size(),x.size()) + + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + + def similarity_trans(self,source,target): + sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2)) + sim = F.softmax(sim, dim=-1) + return torch.matmul(sim, source) + + def load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + + if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + +class deeplab_xception_transfer_projection_synBN_savemem(deeplab_xception_transfer_basemodel_synBN_savememory): + def __init__(self, nInputChannels=3, n_classes=7, 
os=16,input_channels=256,hidden_layers=128,out_channels=256, + transfer_graph=None, source_classes=20): + super(deeplab_xception_transfer_projection_synBN_savemem, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os, input_channels=input_channels, + hidden_layers=hidden_layers, out_channels=out_channels, ) + self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers, + nodes=source_classes) + self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph, + begin_nodes=source_classes,end_nodes=n_classes) + self.fc_graph = gcn.GraphConvolution(hidden_layers*3 ,hidden_layers) + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + # source graph + source_graph = self.source_featuremap_2_graph(x) + source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True) + source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True) + source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True) + + source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True) + source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True) + source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True) + + # target graph + # print('x size',x.size(),adj1.size()) + graph = self.target_featuremap_2_graph(x) + + source_2_target_graph1 = self.similarity_trans(source_graph1, graph) + # graph combine 1 + graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1) + graph = self.fc_graph.forward(graph,relu=True) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + + source_2_target_graph2 = self.similarity_trans(source_graph2, graph) + # graph combine 2 + graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1) + graph = self.fc_graph.forward(graph, relu=True) + + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + + source_2_target_graph3 = self.similarity_trans(source_graph3, graph) + # graph combine 3 + graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1) + graph = self.fc_graph.forward(graph, relu=True) + + graph = 
self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + + # print(graph.size(),x.size()) + + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + + def similarity_trans(self,source,target): + sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2)) + sim = F.softmax(sim, dim=-1) + return torch.matmul(sim, source) + + def load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + + if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + +# if __name__ == '__main__': + # net = deeplab_xception_transfer_projection_v3v5_more_savemem() + # img = torch.rand((2,3,128,128)) + # net.eval() + # a = torch.rand((1,1,7,7)) + # net.forward(img, adj1_target=a) \ No newline at end of file diff --git a/networks/deeplab_xception_universal.py b/networks/deeplab_xception_universal.py new file mode 100644 index 0000000..3545581 --- /dev/null +++ b/networks/deeplab_xception_universal.py @@ -0,0 +1,1077 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import OrderedDict +from torch.nn import Parameter +from networks import deeplab_xception, gcn, deeplab_xception_synBN + + + +class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus): + def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256, + source_classes=20, transfer_graph=None): + super(deeplab_xception_transfer_basemodel_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os,) + + def load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name \ + and 'transpose_graph' not in name and 'middle' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in 
own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + def get_target_parameter(self): + l = [] + other = [] + for name, k in self.named_parameters(): + if 'target' in name or 'semantic' in name: + l.append(k) + else: + other.append(k) + return l, other + + def get_semantic_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'semantic' in name: + l.append(k) + return l + + def get_source_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'source' in name: + l.append(k) + return l + + def top_forward(self, input, adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### source graph + source_graph = self.source_featuremap_2_graph(x) + + source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True) + source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True) + source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True) + + ### target source + graph = self.target_featuremap_2_graph(x) + + # graph combine + # print(graph.size(),source_2_target_graph.size()) + # graph = self.fc_graph.forward(graph,relu=True) + # print(graph.size()) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], 
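`load_source_model` above copies a checkpoint trained on a single dataset into the transfer model: the `module.` prefix left by `DataParallel` is stripped, plain `graph` / `featuremap_2_graph` keys are renamed to their `source_*` counterparts, and keys the model does not contain are reported and skipped. A simplified, self-contained sketch of the same rename-and-partial-copy pattern (the skip-on-shape-mismatch guard is an extra safety assumption, not part of the original method):

```python
from collections import OrderedDict

def remap_source_keys(state_dict):
    """Rename generic graph keys to their source_* counterparts,
    mirroring the rule used by load_source_model."""
    remapped = OrderedDict()
    for name, param in state_dict.items():
        name = name.replace('module.', '')  # strip DataParallel prefix
        if ('graph' in name and 'source' not in name and 'target' not in name
                and 'fc_graph' not in name and 'transpose_graph' not in name
                and 'middle' not in name):
            if 'featuremap_2_graph' in name:
                name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
            else:
                name = name.replace('graph', 'source_graph')
        remapped[name] = param
    return remapped

def load_partial(model, state_dict):
    """Copy only the remapped keys whose name and shape match the model."""
    own_state = model.state_dict()
    for name, param in remap_source_keys(state_dict).items():
        if name in own_state and own_state[name].shape == param.shape:
            own_state[name].copy_(param)
        else:
            print('skipping key "{}"'.format(name))
```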
mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + + + # target graph + # print('x size',x.size(),adj1.size()) + graph = self.target_featuremap_2_graph(x) + + # graph combine + # print(graph.size(),source_2_target_graph.size()) + # graph = self.fc_graph.forward(graph,relu=True) + # print(graph.size()) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + # print(graph.size(),x.size()) + # graph = self.gcn_encode.forward(graph,relu=True) + # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True) + # graph = self.gcn_decode.forward(graph,relu=True) + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) + + return x + + +class deeplab_xception_transfer_basemodel_savememory_synbn(deeplab_xception_synBN.DeepLabv3_plus): + def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256, + source_classes=20, transfer_graph=None): + super(deeplab_xception_transfer_basemodel_savememory_synbn, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes, + os=os,) + + + def load_source_model(self,state_dict): + own_state = self.state_dict() + # for name inshop_cos own_state: + # print name + new_state_dict = OrderedDict() + for name, param in state_dict.items(): + name = name.replace('module.', '') + if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name \ + and 'transpose_graph' not in name and 'middle' not in name: + if 'featuremap_2_graph' in name: + name = name.replace('featuremap_2_graph','source_featuremap_2_graph') + else: + name = name.replace('graph','source_graph') + new_state_dict[name] = 0 + if name not in own_state: + if 'num_batch' in name: + continue + print('unexpected key "{}" in state_dict' + .format(name)) + continue + # if isinstance(param, own_state): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + try: + own_state[name].copy_(param) + except: + print('While copying the parameter named {}, whose dimensions in the model are' + ' {} and whose dimensions in the checkpoint are {}, ...'.format( + name, own_state[name].size(), param.size())) + continue # i add inshop_cos 2018/02/01 + own_state[name].copy_(param) + # print 'copying %s' %name + + missing = set(own_state.keys()) - set(new_state_dict.keys()) + if len(missing) > 0: + print('missing keys in state_dict: "{}"'.format(missing)) + + def get_target_parameter(self): + l = [] + other = [] + for name, k in self.named_parameters(): + if 'target' in name or 'semantic' in name: + l.append(k) + else: + other.append(k) + return l, other + + def get_semantic_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'semantic' in name: + l.append(k) + return l + + def get_source_parameter(self): + l = [] + for name, k in self.named_parameters(): + if 'source' in name: + l.append(k) + return l + + def 
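The `get_target_parameter`, `get_semantic_parameter` and `get_source_parameter` helpers above split parameters purely by name, which makes it easy to give the newly attached graph branches their own learning rate while the pretrained backbone is updated more gently. A hedged sketch of that usage with a tiny stand-in module (the module and the learning rates are placeholders, not values taken from the training scripts):

```python
import torch.nn as nn
import torch.optim as optim

class TinyTransferNet(nn.Module):
    """Stand-in that follows the same naming convention ('target'/'semantic' vs. rest)."""
    def __init__(self):
        super(TinyTransferNet, self).__init__()
        self.backbone_conv = nn.Conv2d(3, 8, 3)
        self.target_graph_conv1 = nn.Linear(8, 8)
        self.semantic = nn.Conv2d(8, 7, 1)

    def get_target_parameter(self):
        target, other = [], []
        for name, p in self.named_parameters():
            (target if 'target' in name or 'semantic' in name else other).append(p)
        return target, other

net = TinyTransferNet()
target_params, other_params = net.get_target_parameter()
optimizer = optim.SGD([{'params': target_params, 'lr': 1e-3},   # placeholder learning rates
                       {'params': other_params,  'lr': 1e-4}],
                      momentum=0.9)
```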
top_forward(self, input, adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### source graph + source_graph = self.source_featuremap_2_graph(x) + + source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True) + source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True) + source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True) + + ### target source + graph = self.target_featuremap_2_graph(x) + + # graph combine + # print(graph.size(),source_2_target_graph.size()) + # graph = self.fc_graph.forward(graph,relu=True) + # print(graph.size()) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + + + def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### add graph + + + # target graph + # print('x size',x.size(),adj1.size()) + graph = self.target_featuremap_2_graph(x) + + # graph combine + # print(graph.size(),source_2_target_graph.size()) + # graph = self.fc_graph.forward(graph,relu=True) + # print(graph.size()) + + graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True) + graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True) + # print(graph.size(),x.size()) + # graph = self.gcn_encode.forward(graph,relu=True) + # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True) + # graph = self.gcn_decode.forward(graph,relu=True) + graph = self.target_graph_2_fea.forward(graph, x) + x = self.target_skip_conv(x) + x = x + graph + + ### + x = self.semantic(x) + x = F.upsample(x, size=input.size()[2:], 
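The decoder paths above repeatedly resize the ASPP output to the spatial size of the low-level features (and the logits back to the input size) with `F.upsample`. That call is deprecated in newer PyTorch releases; on such versions the same resize can be written with `F.interpolate`. A small sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

aspp_out = torch.randn(2, 256, 16, 16)           # decoder input (illustrative shape)
low_level_features = torch.randn(2, 48, 64, 64)  # backbone skip features (illustrative shape)

# Same behaviour as the F.upsample(...) calls above, via the non-deprecated API.
x = F.interpolate(aspp_out, size=low_level_features.size()[2:],
                  mode='bilinear', align_corners=True)
x = torch.cat((x, low_level_features), dim=1)
print(x.shape)  # torch.Size([2, 304, 64, 64])
```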
mode='bilinear', align_corners=True) + + return x + + +class deeplab_xception_end2end_3d(deeplab_xception_transfer_basemodel_savememory): + def __init__(self, nInputChannels=3, n_classes=20, os=16, input_channels=256, hidden_layers=128, out_channels=256, + source_classes=7, middle_classes=18, transfer_graph=None): + super(deeplab_xception_end2end_3d, self).__init__(nInputChannels=nInputChannels, + n_classes=n_classes, + os=os, ) + ### source graph + self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, + hidden_layers=hidden_layers, + nodes=source_classes) + self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.source_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, + output_channels=out_channels, + hidden_layers=hidden_layers, nodes=source_classes + ) + self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + self.source_semantic = nn.Conv2d(out_channels,source_classes,1) + self.middle_semantic = nn.Conv2d(out_channels, middle_classes, 1) + + ### target graph 1 + self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, + hidden_layers=hidden_layers, + nodes=n_classes) + self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, + output_channels=out_channels, + hidden_layers=hidden_layers, nodes=n_classes + ) + self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + + ### middle + self.middle_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, + hidden_layers=hidden_layers, + nodes=middle_classes) + self.middle_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.middle_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.middle_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.middle_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, + output_channels=out_channels, + hidden_layers=hidden_layers, nodes=n_classes + ) + self.middle_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + + ### multi transpose + self.transpose_graph_source2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=source_classes, end_nodes=n_classes) + self.transpose_graph_target2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=n_classes, end_nodes=source_classes) + + self.transpose_graph_middle2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=middle_classes, end_nodes=source_classes) + self.transpose_graph_middle2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=middle_classes, end_nodes=source_classes) + + self.transpose_graph_source2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + 
adj=transfer_graph, + begin_nodes=source_classes, end_nodes=middle_classes) + self.transpose_graph_target2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=n_classes, end_nodes=middle_classes) + + + self.fc_graph_source = gcn.GraphConvolution(hidden_layers * 5, hidden_layers) + self.fc_graph_target = gcn.GraphConvolution(hidden_layers * 5, hidden_layers) + self.fc_graph_middle = gcn.GraphConvolution(hidden_layers * 5, hidden_layers) + + def freeze_totally_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + m.weight.requires_grad = False + m.bias.requires_grad = False + + def freeze_backbone_bn(self): + for m in self.xception_features.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + m.weight.requires_grad = False + m.bias.requires_grad = False + + def top_forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer_s2t=None, adj3_transfer_t2s=None, + adj4_middle=None,adj5_transfer_s2m=None,adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### source graph + source_graph = self.source_featuremap_2_graph(x) + ### target source + target_graph = self.target_featuremap_2_graph(x) + ### middle source + middle_graph = self.middle_featuremap_2_graph(x) + + ##### end2end multi task + + ### first task + # print(source_graph.size(),target_graph.size()) + source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True) + target_graph1 = self.target_graph_conv1.forward(target_graph, adj=adj1_target, relu=True) + middle_graph1 = self.target_graph_conv1.forward(middle_graph, adj=adj4_middle, relu=True) + + # source 2 target & middle + source_2_target_graph1_v5 = self.transpose_graph_source2target.forward(source_graph1, adj=adj3_transfer_s2t, + relu=True) + source_2_middle_graph1_v5 = self.transpose_graph_source2middle.forward(source_graph1,adj=adj5_transfer_s2m, + relu=True) + # target 2 source & middle + target_2_source_graph1_v5 = self.transpose_graph_target2source.forward(target_graph1, adj=adj3_transfer_t2s, + relu=True) + target_2_middle_graph1_v5 = self.transpose_graph_target2middle.forward(target_graph1, adj=adj6_transfer_t2m, + relu=True) + # middle 2 source & target + middle_2_source_graph1_v5 = self.transpose_graph_middle2source.forward(middle_graph1, adj=adj5_transfer_m2s, + relu=True) + middle_2_target_graph1_v5 = self.transpose_graph_middle2target.forward(middle_graph1, adj=adj6_transfer_m2t, + relu=True) + # source 2 middle target + source_2_target_graph1 = self.similarity_trans(source_graph1, target_graph1) + source_2_middle_graph1 = self.similarity_trans(source_graph1, middle_graph1) + 
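`freeze_totally_bn` and `freeze_backbone_bn` above put BatchNorm layers into eval mode and stop gradient updates to their affine parameters, the usual trick when fine-tuning with small per-GPU batch sizes. A self-contained sketch of the same pattern on a throwaway module:

```python
import torch.nn as nn

def freeze_batchnorm(module):
    """Freeze every BatchNorm2d: fixed running statistics, no affine updates."""
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
            m.weight.requires_grad = False
            m.bias.requires_grad = False

demo = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
freeze_batchnorm(demo)
print(demo[1].training, demo[1].weight.requires_grad)  # False False
```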
# target 2 source middle + target_2_source_graph1 = self.similarity_trans(target_graph1, source_graph1) + target_2_middle_graph1 = self.similarity_trans(target_graph1, middle_graph1) + # middle 2 source target + middle_2_source_graph1 = self.similarity_trans(middle_graph1, source_graph1) + middle_2_target_graph1 = self.similarity_trans(middle_graph1, target_graph1) + + ## concat + # print(source_graph1.size(), target_2_source_graph1.size(), ) + source_graph1 = torch.cat( + (source_graph1, target_2_source_graph1, target_2_source_graph1_v5, + middle_2_source_graph1, middle_2_source_graph1_v5), dim=-1) + source_graph1 = self.fc_graph_source.forward(source_graph1, relu=True) + # target + target_graph1 = torch.cat( + (target_graph1, source_2_target_graph1, source_2_target_graph1_v5, + middle_2_target_graph1, middle_2_target_graph1_v5), dim=-1) + target_graph1 = self.fc_graph_target.forward(target_graph1, relu=True) + # middle + middle_graph1 = torch.cat((middle_graph1, source_2_middle_graph1, source_2_middle_graph1_v5, + target_2_middle_graph1, target_2_middle_graph1_v5), dim=-1) + middle_graph1 = self.fc_graph_middle.forward(middle_graph1, relu=True) + + + ### seconde task + source_graph2 = self.source_graph_conv1.forward(source_graph1, adj=adj2_source, relu=True) + target_graph2 = self.target_graph_conv1.forward(target_graph1, adj=adj1_target, relu=True) + middle_graph2 = self.target_graph_conv1.forward(middle_graph1, adj=adj4_middle, relu=True) + + # source 2 target & middle + source_2_target_graph2_v5 = self.transpose_graph_source2target.forward(source_graph2, adj=adj3_transfer_s2t, + relu=True) + source_2_middle_graph2_v5 = self.transpose_graph_source2middle.forward(source_graph2, adj=adj5_transfer_s2m, + relu=True) + # target 2 source & middle + target_2_source_graph2_v5 = self.transpose_graph_target2source.forward(target_graph2, adj=adj3_transfer_t2s, + relu=True) + target_2_middle_graph2_v5 = self.transpose_graph_target2middle.forward(target_graph2, adj=adj6_transfer_t2m, + relu=True) + # middle 2 source & target + middle_2_source_graph2_v5 = self.transpose_graph_middle2source.forward(middle_graph2, adj=adj5_transfer_m2s, + relu=True) + middle_2_target_graph2_v5 = self.transpose_graph_middle2target.forward(middle_graph2, adj=adj6_transfer_m2t, + relu=True) + # source 2 middle target + source_2_target_graph2 = self.similarity_trans(source_graph2, target_graph2) + source_2_middle_graph2 = self.similarity_trans(source_graph2, middle_graph2) + # target 2 source middle + target_2_source_graph2 = self.similarity_trans(target_graph2, source_graph2) + target_2_middle_graph2 = self.similarity_trans(target_graph2, middle_graph2) + # middle 2 source target + middle_2_source_graph2 = self.similarity_trans(middle_graph2, source_graph2) + middle_2_target_graph2 = self.similarity_trans(middle_graph2, target_graph2) + + ## concat + # print(source_graph1.size(), target_2_source_graph1.size(), ) + source_graph2 = torch.cat( + (source_graph2, target_2_source_graph2, target_2_source_graph2_v5, + middle_2_source_graph2, middle_2_source_graph2_v5), dim=-1) + source_graph2 = self.fc_graph_source.forward(source_graph2, relu=True) + # target + target_graph2 = torch.cat( + (target_graph2, source_2_target_graph2, source_2_target_graph2_v5, + middle_2_target_graph2, middle_2_target_graph2_v5), dim=-1) + target_graph2 = self.fc_graph_target.forward(target_graph2, relu=True) + # middle + middle_graph2 = torch.cat((middle_graph2, source_2_middle_graph2, source_2_middle_graph2_v5, + target_2_middle_graph2, 
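In each of the three propagation rounds above, every graph is concatenated with four incoming messages (the similarity-transferred and the `Graph_trans`-transferred features from the other two graphs), giving `hidden_layers * 5` channels per node before `fc_graph_source/target/middle` squeezes them back to `hidden_layers`. A plain-tensor shape check of that fusion using the default `hidden_layers=128` and 20 target nodes; the `nn.Linear` stands in for the graph convolution only to show the channel arithmetic:

```python
import torch
import torch.nn as nn

hidden, n_nodes = 128, 20  # default hidden_layers, target (CIHP) node count

own            = torch.randn(1, n_nodes, hidden)
from_src_sim   = torch.randn(1, n_nodes, hidden)  # similarity_trans(source, target)
from_src_trans = torch.randn(1, n_nodes, hidden)  # transpose_graph_source2target(...)
from_mid_sim   = torch.randn(1, n_nodes, hidden)  # similarity_trans(middle, target)
from_mid_trans = torch.randn(1, n_nodes, hidden)  # transpose_graph_middle2target(...)

fused = torch.cat((own, from_src_sim, from_src_trans,
                   from_mid_sim, from_mid_trans), dim=-1)
print(fused.shape)  # torch.Size([1, 20, 640]) == hidden * 5

reduce = nn.Linear(hidden * 5, hidden)  # fc_graph_target has the same in/out widths
print(reduce(fused).shape)  # torch.Size([1, 20, 128])
```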
target_2_middle_graph2_v5), dim=-1) + middle_graph2 = self.fc_graph_middle.forward(middle_graph2, relu=True) + + + ### third task + source_graph3 = self.source_graph_conv1.forward(source_graph2, adj=adj2_source, relu=True) + target_graph3 = self.target_graph_conv1.forward(target_graph2, adj=adj1_target, relu=True) + middle_graph3 = self.target_graph_conv1.forward(middle_graph2, adj=adj4_middle, relu=True) + + # source 2 target & middle + source_2_target_graph3_v5 = self.transpose_graph_source2target.forward(source_graph3, adj=adj3_transfer_s2t, + relu=True) + source_2_middle_graph3_v5 = self.transpose_graph_source2middle.forward(source_graph3, adj=adj5_transfer_s2m, + relu=True) + # target 2 source & middle + target_2_source_graph3_v5 = self.transpose_graph_target2source.forward(target_graph3, adj=adj3_transfer_t2s, + relu=True) + target_2_middle_graph3_v5 = self.transpose_graph_target2middle.forward(target_graph3, adj=adj6_transfer_t2m, + relu=True) + # middle 2 source & target + middle_2_source_graph3_v5 = self.transpose_graph_middle2source.forward(middle_graph3, adj=adj5_transfer_m2s, + relu=True) + middle_2_target_graph3_v5 = self.transpose_graph_middle2target.forward(middle_graph3, adj=adj6_transfer_m2t, + relu=True) + # source 2 middle target + source_2_target_graph3 = self.similarity_trans(source_graph3, target_graph3) + source_2_middle_graph3 = self.similarity_trans(source_graph3, middle_graph3) + # target 2 source middle + target_2_source_graph3 = self.similarity_trans(target_graph3, source_graph3) + target_2_middle_graph3 = self.similarity_trans(target_graph3, middle_graph3) + # middle 2 source target + middle_2_source_graph3 = self.similarity_trans(middle_graph3, source_graph3) + middle_2_target_graph3 = self.similarity_trans(middle_graph3, target_graph3) + + ## concat + # print(source_graph1.size(), target_2_source_graph1.size(), ) + source_graph3 = torch.cat( + (source_graph3, target_2_source_graph3, target_2_source_graph3_v5, + middle_2_source_graph3, middle_2_source_graph3_v5), dim=-1) + source_graph3 = self.fc_graph_source.forward(source_graph3, relu=True) + # target + target_graph3 = torch.cat( + (target_graph3, source_2_target_graph3, source_2_target_graph3_v5, + middle_2_target_graph3, middle_2_target_graph3_v5), dim=-1) + target_graph3 = self.fc_graph_target.forward(target_graph3, relu=True) + # middle + middle_graph3 = torch.cat((middle_graph3, source_2_middle_graph3, source_2_middle_graph3_v5, + target_2_middle_graph3, target_2_middle_graph3_v5), dim=-1) + middle_graph3 = self.fc_graph_middle.forward(middle_graph3, relu=True) + + return source_graph3, target_graph3, middle_graph3, x + + def similarity_trans(self,source,target): + sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2)) + sim = F.softmax(sim, dim=-1) + return torch.matmul(sim, source) + + def bottom_forward_source(self, input, source_graph): + # print('input size') + # print(input.size()) + # print(source_graph.size()) + graph = self.source_graph_2_fea.forward(source_graph, input) + x = self.source_skip_conv(input) + x = x + graph + x = self.source_semantic(x) + return x + + def bottom_forward_target(self, input, target_graph): + graph = self.target_graph_2_fea.forward(target_graph, input) + x = self.target_skip_conv(input) + x = x + graph + x = self.semantic(x) + return x + + def bottom_forward_middle(self, input, target_graph): + graph = self.middle_graph_2_fea.forward(target_graph, input) + x = self.middle_skip_conv(input) + x = x + graph + x = 
self.middle_semantic(x) + return x + + def forward(self, input_source, input_target=None, input_middle=None, adj1_target=None, adj2_source=None, + adj3_transfer_s2t=None, adj3_transfer_t2s=None, adj4_middle=None,adj5_transfer_s2m=None, + adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,): + if input_source is None and input_target is not None and input_middle is None: + # target + target_batch = input_target.size(0) + input = input_target + + source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, adj2_source=adj2_source, + adj3_transfer_s2t=adj3_transfer_s2t, + adj3_transfer_t2s=adj3_transfer_t2s, + adj4_middle=adj4_middle, + adj5_transfer_s2m=adj5_transfer_s2m, + adj6_transfer_t2m=adj6_transfer_t2m, + adj5_transfer_m2s=adj5_transfer_m2s, + adj6_transfer_m2t=adj6_transfer_m2t) + + # source_x = self.bottom_forward_source(source_x, source_graph) + target_x = self.bottom_forward_target(x, target_graph) + + target_x = F.upsample(target_x, size=input.size()[2:], mode='bilinear', align_corners=True) + return None, target_x, None + + if input_source is not None and input_target is None and input_middle is None: + # source + source_batch = input_source.size(0) + source_list = range(source_batch) + input = input_source + + source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, + adj2_source=adj2_source, + adj3_transfer_s2t=adj3_transfer_s2t, + adj3_transfer_t2s=adj3_transfer_t2s, + adj4_middle=adj4_middle, + adj5_transfer_s2m=adj5_transfer_s2m, + adj6_transfer_t2m=adj6_transfer_t2m, + adj5_transfer_m2s=adj5_transfer_m2s, + adj6_transfer_m2t=adj6_transfer_m2t) + + source_x = self.bottom_forward_source(x, source_graph) + source_x = F.upsample(source_x, size=input.size()[2:], mode='bilinear', align_corners=True) + return source_x, None, None + + if input_middle is not None and input_source is None and input_target is None: + # middle + input = input_middle + + source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, + adj2_source=adj2_source, + adj3_transfer_s2t=adj3_transfer_s2t, + adj3_transfer_t2s=adj3_transfer_t2s, + adj4_middle=adj4_middle, + adj5_transfer_s2m=adj5_transfer_s2m, + adj6_transfer_t2m=adj6_transfer_t2m, + adj5_transfer_m2s=adj5_transfer_m2s, + adj6_transfer_m2t=adj6_transfer_m2t) + + middle_x = self.bottom_forward_middle(x, source_graph) + middle_x = F.upsample(middle_x, size=input.size()[2:], mode='bilinear', align_corners=True) + return None, None, middle_x + + +class deeplab_xception_end2end_3d_synbn(deeplab_xception_transfer_basemodel_savememory_synbn): + def __init__(self, nInputChannels=3, n_classes=20, os=16, input_channels=256, hidden_layers=128, out_channels=256, + source_classes=7, middle_classes=18, transfer_graph=None): + super(deeplab_xception_end2end_3d_synbn, self).__init__(nInputChannels=nInputChannels, + n_classes=n_classes, + os=os, ) + ### source graph + self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, + hidden_layers=hidden_layers, + nodes=source_classes) + self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.source_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, + output_channels=out_channels, + hidden_layers=hidden_layers, nodes=source_classes + ) + 
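The universal `forward` above is dispatched on which input is provided: pass exactly one of `input_source`, `input_target`, or `input_middle`, and the corresponding logits come back in that slot of the returned triple while the other two slots are `None`. Below is a hedged sketch of a target-only call, assuming the repo root is on `PYTHONPATH`; the uniform adjacencies are purely illustrative, each transfer adjacency is laid out as `(end_nodes x begin_nodes)` as `Graph_trans` multiplies it, and a freshly constructed, untrained model of course produces meaningless values:

```python
import torch
from networks import deeplab_xception_universal

# Defaults: target = 20 CIHP classes, source = 7 Pascal classes, middle = 18 ATR classes.
net = deeplab_xception_universal.deeplab_xception_end2end_3d()
net.eval()

img = torch.rand(1, 3, 128, 128)
n_t, n_s, n_m = 20, 7, 18

adjs = dict(
    adj1_target=torch.ones(n_t, n_t) / n_t,   # within-graph adjacencies
    adj2_source=torch.ones(n_s, n_s) / n_s,
    adj4_middle=torch.ones(n_m, n_m) / n_m,
    adj3_transfer_s2t=torch.ones(n_t, n_s),   # transfer adjacencies: end_nodes x begin_nodes
    adj3_transfer_t2s=torch.ones(n_s, n_t),
    adj5_transfer_s2m=torch.ones(n_m, n_s),
    adj6_transfer_t2m=torch.ones(n_m, n_t),
    adj5_transfer_m2s=torch.ones(n_s, n_m),
    adj6_transfer_m2t=torch.ones(n_t, n_m),   # sized on the target graph so the concat lines up
)

with torch.no_grad():
    # Target-only pass: the first and third return values are None.
    _, target_logits, _ = net(None, input_target=img, **adjs)
print(target_logits.shape)  # logits at input resolution, one channel per target class
```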
self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + self.source_semantic = nn.Conv2d(out_channels,source_classes,1) + self.middle_semantic = nn.Conv2d(out_channels, middle_classes, 1) + + ### target graph 1 + self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, + hidden_layers=hidden_layers, + nodes=n_classes) + self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, + output_channels=out_channels, + hidden_layers=hidden_layers, nodes=n_classes + ) + self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + + ### middle + self.middle_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, + hidden_layers=hidden_layers, + nodes=middle_classes) + self.middle_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.middle_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers) + self.middle_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers) + + self.middle_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, + output_channels=out_channels, + hidden_layers=hidden_layers, nodes=n_classes + ) + self.middle_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1), + nn.ReLU(True)]) + + ### multi transpose + self.transpose_graph_source2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=source_classes, end_nodes=n_classes) + self.transpose_graph_target2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=n_classes, end_nodes=source_classes) + + self.transpose_graph_middle2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=middle_classes, end_nodes=source_classes) + self.transpose_graph_middle2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=middle_classes, end_nodes=source_classes) + + self.transpose_graph_source2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=source_classes, end_nodes=middle_classes) + self.transpose_graph_target2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers, + adj=transfer_graph, + begin_nodes=n_classes, end_nodes=middle_classes) + + + self.fc_graph_source = gcn.GraphConvolution(hidden_layers * 5, hidden_layers) + self.fc_graph_target = gcn.GraphConvolution(hidden_layers * 5, hidden_layers) + self.fc_graph_middle = gcn.GraphConvolution(hidden_layers * 5, hidden_layers) + + + def top_forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer_s2t=None, adj3_transfer_t2s=None, + adj4_middle=None,adj5_transfer_s2m=None,adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,): + x, low_level_features = self.xception_features(input) + # print(x.size()) + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) + + x = torch.cat((x1, x2, 
x3, x4, x5), dim=1) + + x = self.concat_projection_conv1(x) + x = self.concat_projection_bn1(x) + x = self.relu(x) + # print(x.size()) + x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) + + low_level_features = self.feature_projection_conv1(low_level_features) + low_level_features = self.feature_projection_bn1(low_level_features) + low_level_features = self.relu(low_level_features) + # print(low_level_features.size()) + # print(x.size()) + x = torch.cat((x, low_level_features), dim=1) + x = self.decoder(x) + + ### source graph + source_graph = self.source_featuremap_2_graph(x) + ### target source + target_graph = self.target_featuremap_2_graph(x) + ### middle source + middle_graph = self.middle_featuremap_2_graph(x) + + ##### end2end multi task + + ### first task + # print(source_graph.size(),target_graph.size()) + source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True) + target_graph1 = self.target_graph_conv1.forward(target_graph, adj=adj1_target, relu=True) + middle_graph1 = self.target_graph_conv1.forward(middle_graph, adj=adj4_middle, relu=True) + + # source 2 target & middle + source_2_target_graph1_v5 = self.transpose_graph_source2target.forward(source_graph1, adj=adj3_transfer_s2t, + relu=True) + source_2_middle_graph1_v5 = self.transpose_graph_source2middle.forward(source_graph1,adj=adj5_transfer_s2m, + relu=True) + # target 2 source & middle + target_2_source_graph1_v5 = self.transpose_graph_target2source.forward(target_graph1, adj=adj3_transfer_t2s, + relu=True) + target_2_middle_graph1_v5 = self.transpose_graph_target2middle.forward(target_graph1, adj=adj6_transfer_t2m, + relu=True) + # middle 2 source & target + middle_2_source_graph1_v5 = self.transpose_graph_middle2source.forward(middle_graph1, adj=adj5_transfer_m2s, + relu=True) + middle_2_target_graph1_v5 = self.transpose_graph_middle2target.forward(middle_graph1, adj=adj6_transfer_m2t, + relu=True) + # source 2 middle target + source_2_target_graph1 = self.similarity_trans(source_graph1, target_graph1) + source_2_middle_graph1 = self.similarity_trans(source_graph1, middle_graph1) + # target 2 source middle + target_2_source_graph1 = self.similarity_trans(target_graph1, source_graph1) + target_2_middle_graph1 = self.similarity_trans(target_graph1, middle_graph1) + # middle 2 source target + middle_2_source_graph1 = self.similarity_trans(middle_graph1, source_graph1) + middle_2_target_graph1 = self.similarity_trans(middle_graph1, target_graph1) + + ## concat + # print(source_graph1.size(), target_2_source_graph1.size(), ) + source_graph1 = torch.cat( + (source_graph1, target_2_source_graph1, target_2_source_graph1_v5, + middle_2_source_graph1, middle_2_source_graph1_v5), dim=-1) + source_graph1 = self.fc_graph_source.forward(source_graph1, relu=True) + # target + target_graph1 = torch.cat( + (target_graph1, source_2_target_graph1, source_2_target_graph1_v5, + middle_2_target_graph1, middle_2_target_graph1_v5), dim=-1) + target_graph1 = self.fc_graph_target.forward(target_graph1, relu=True) + # middle + middle_graph1 = torch.cat((middle_graph1, source_2_middle_graph1, source_2_middle_graph1_v5, + target_2_middle_graph1, target_2_middle_graph1_v5), dim=-1) + middle_graph1 = self.fc_graph_middle.forward(middle_graph1, relu=True) + + + ### seconde task + source_graph2 = self.source_graph_conv1.forward(source_graph1, adj=adj2_source, relu=True) + target_graph2 = self.target_graph_conv1.forward(target_graph1, adj=adj1_target, relu=True) + middle_graph2 = 
self.target_graph_conv1.forward(middle_graph1, adj=adj4_middle, relu=True) + + # source 2 target & middle + source_2_target_graph2_v5 = self.transpose_graph_source2target.forward(source_graph2, adj=adj3_transfer_s2t, + relu=True) + source_2_middle_graph2_v5 = self.transpose_graph_source2middle.forward(source_graph2, adj=adj5_transfer_s2m, + relu=True) + # target 2 source & middle + target_2_source_graph2_v5 = self.transpose_graph_target2source.forward(target_graph2, adj=adj3_transfer_t2s, + relu=True) + target_2_middle_graph2_v5 = self.transpose_graph_target2middle.forward(target_graph2, adj=adj6_transfer_t2m, + relu=True) + # middle 2 source & target + middle_2_source_graph2_v5 = self.transpose_graph_middle2source.forward(middle_graph2, adj=adj5_transfer_m2s, + relu=True) + middle_2_target_graph2_v5 = self.transpose_graph_middle2target.forward(middle_graph2, adj=adj6_transfer_m2t, + relu=True) + # source 2 middle target + source_2_target_graph2 = self.similarity_trans(source_graph2, target_graph2) + source_2_middle_graph2 = self.similarity_trans(source_graph2, middle_graph2) + # target 2 source middle + target_2_source_graph2 = self.similarity_trans(target_graph2, source_graph2) + target_2_middle_graph2 = self.similarity_trans(target_graph2, middle_graph2) + # middle 2 source target + middle_2_source_graph2 = self.similarity_trans(middle_graph2, source_graph2) + middle_2_target_graph2 = self.similarity_trans(middle_graph2, target_graph2) + + ## concat + # print(source_graph1.size(), target_2_source_graph1.size(), ) + source_graph2 = torch.cat( + (source_graph2, target_2_source_graph2, target_2_source_graph2_v5, + middle_2_source_graph2, middle_2_source_graph2_v5), dim=-1) + source_graph2 = self.fc_graph_source.forward(source_graph2, relu=True) + # target + target_graph2 = torch.cat( + (target_graph2, source_2_target_graph2, source_2_target_graph2_v5, + middle_2_target_graph2, middle_2_target_graph2_v5), dim=-1) + target_graph2 = self.fc_graph_target.forward(target_graph2, relu=True) + # middle + middle_graph2 = torch.cat((middle_graph2, source_2_middle_graph2, source_2_middle_graph2_v5, + target_2_middle_graph2, target_2_middle_graph2_v5), dim=-1) + middle_graph2 = self.fc_graph_middle.forward(middle_graph2, relu=True) + + + ### third task + source_graph3 = self.source_graph_conv1.forward(source_graph2, adj=adj2_source, relu=True) + target_graph3 = self.target_graph_conv1.forward(target_graph2, adj=adj1_target, relu=True) + middle_graph3 = self.target_graph_conv1.forward(middle_graph2, adj=adj4_middle, relu=True) + + # source 2 target & middle + source_2_target_graph3_v5 = self.transpose_graph_source2target.forward(source_graph3, adj=adj3_transfer_s2t, + relu=True) + source_2_middle_graph3_v5 = self.transpose_graph_source2middle.forward(source_graph3, adj=adj5_transfer_s2m, + relu=True) + # target 2 source & middle + target_2_source_graph3_v5 = self.transpose_graph_target2source.forward(target_graph3, adj=adj3_transfer_t2s, + relu=True) + target_2_middle_graph3_v5 = self.transpose_graph_target2middle.forward(target_graph3, adj=adj6_transfer_t2m, + relu=True) + # middle 2 source & target + middle_2_source_graph3_v5 = self.transpose_graph_middle2source.forward(middle_graph3, adj=adj5_transfer_m2s, + relu=True) + middle_2_target_graph3_v5 = self.transpose_graph_middle2target.forward(middle_graph3, adj=adj6_transfer_m2t, + relu=True) + # source 2 middle target + source_2_target_graph3 = self.similarity_trans(source_graph3, target_graph3) + source_2_middle_graph3 = 
self.similarity_trans(source_graph3, middle_graph3) + # target 2 source middle + target_2_source_graph3 = self.similarity_trans(target_graph3, source_graph3) + target_2_middle_graph3 = self.similarity_trans(target_graph3, middle_graph3) + # middle 2 source target + middle_2_source_graph3 = self.similarity_trans(middle_graph3, source_graph3) + middle_2_target_graph3 = self.similarity_trans(middle_graph3, target_graph3) + + ## concat + # print(source_graph1.size(), target_2_source_graph1.size(), ) + source_graph3 = torch.cat( + (source_graph3, target_2_source_graph3, target_2_source_graph3_v5, + middle_2_source_graph3, middle_2_source_graph3_v5), dim=-1) + source_graph3 = self.fc_graph_source.forward(source_graph3, relu=True) + # target + target_graph3 = torch.cat( + (target_graph3, source_2_target_graph3, source_2_target_graph3_v5, + middle_2_target_graph3, middle_2_target_graph3_v5), dim=-1) + target_graph3 = self.fc_graph_target.forward(target_graph3, relu=True) + # middle + middle_graph3 = torch.cat((middle_graph3, source_2_middle_graph3, source_2_middle_graph3_v5, + target_2_middle_graph3, target_2_middle_graph3_v5), dim=-1) + middle_graph3 = self.fc_graph_middle.forward(middle_graph3, relu=True) + + return source_graph3, target_graph3, middle_graph3, x + + def similarity_trans(self,source,target): + sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2)) + sim = F.softmax(sim, dim=-1) + return torch.matmul(sim, source) + + def bottom_forward_source(self, input, source_graph): + # print('input size') + # print(input.size()) + # print(source_graph.size()) + graph = self.source_graph_2_fea.forward(source_graph, input) + x = self.source_skip_conv(input) + x = x + graph + x = self.source_semantic(x) + return x + + def bottom_forward_target(self, input, target_graph): + graph = self.target_graph_2_fea.forward(target_graph, input) + x = self.target_skip_conv(input) + x = x + graph + x = self.semantic(x) + return x + + def bottom_forward_middle(self, input, target_graph): + graph = self.middle_graph_2_fea.forward(target_graph, input) + x = self.middle_skip_conv(input) + x = x + graph + x = self.middle_semantic(x) + return x + + def forward(self, input_source, input_target=None, input_middle=None, adj1_target=None, adj2_source=None, + adj3_transfer_s2t=None, adj3_transfer_t2s=None, adj4_middle=None,adj5_transfer_s2m=None, + adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,): + + if input_source is None and input_target is not None and input_middle is None: + # target + target_batch = input_target.size(0) + input = input_target + + source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, adj2_source=adj2_source, + adj3_transfer_s2t=adj3_transfer_s2t, + adj3_transfer_t2s=adj3_transfer_t2s, + adj4_middle=adj4_middle, + adj5_transfer_s2m=adj5_transfer_s2m, + adj6_transfer_t2m=adj6_transfer_t2m, + adj5_transfer_m2s=adj5_transfer_m2s, + adj6_transfer_m2t=adj6_transfer_m2t) + + # source_x = self.bottom_forward_source(source_x, source_graph) + target_x = self.bottom_forward_target(x, target_graph) + + target_x = F.upsample(target_x, size=input.size()[2:], mode='bilinear', align_corners=True) + return None, target_x, None + + if input_source is not None and input_target is None and input_middle is None: + # source + source_batch = input_source.size(0) + source_list = range(source_batch) + input = input_source + + source_graph, target_graph, middle_graph, x = self.top_forward(input, 
adj1_target=adj1_target, + adj2_source=adj2_source, + adj3_transfer_s2t=adj3_transfer_s2t, + adj3_transfer_t2s=adj3_transfer_t2s, + adj4_middle=adj4_middle, + adj5_transfer_s2m=adj5_transfer_s2m, + adj6_transfer_t2m=adj6_transfer_t2m, + adj5_transfer_m2s=adj5_transfer_m2s, + adj6_transfer_m2t=adj6_transfer_m2t) + + source_x = self.bottom_forward_source(x, source_graph) + source_x = F.upsample(source_x, size=input.size()[2:], mode='bilinear', align_corners=True) + return source_x, None, None + + if input_middle is not None and input_source is None and input_target is None: + # middle + input = input_middle + + source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, + adj2_source=adj2_source, + adj3_transfer_s2t=adj3_transfer_s2t, + adj3_transfer_t2s=adj3_transfer_t2s, + adj4_middle=adj4_middle, + adj5_transfer_s2m=adj5_transfer_s2m, + adj6_transfer_t2m=adj6_transfer_t2m, + adj5_transfer_m2s=adj5_transfer_m2s, + adj6_transfer_m2t=adj6_transfer_m2t) + + middle_x = self.bottom_forward_middle(x, source_graph) + middle_x = F.upsample(middle_x, size=input.size()[2:], mode='bilinear', align_corners=True) + return None, None, middle_x + + +if __name__ == '__main__': + net = deeplab_xception_end2end_3d() + net.freeze_totally_bn() + img1 = torch.rand((1,3,128,128)) + img2 = torch.rand((1, 3, 128, 128)) + a1 = torch.ones((1,1,7,20)) + a2 = torch.ones((1,1,20,7)) + net.eval() + net.forward(img1,img2,adj3_transfer_t2s=a2,adj3_transfer_s2t=a1) \ No newline at end of file diff --git a/networks/gcn.py b/networks/gcn.py new file mode 100644 index 0000000..cf65805 --- /dev/null +++ b/networks/gcn.py @@ -0,0 +1,271 @@ +import math +import torch +from torch.nn.parameter import Parameter +import torch.nn as nn +import torch.nn.functional as F +from networks import graph +# import pdb + +class GraphConvolution(nn.Module): + + def __init__(self,in_features,out_features,bias=False): + super(GraphConvolution, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter(torch.FloatTensor(in_features,out_features)) + if bias: + self.bias = Parameter(torch.FloatTensor(out_features)) + else: + self.register_parameter('bias',None) + # self.reset_parameters() + + def reset_parameters(self): + # stdv = 1./math.sqrt(self.weight(1)) + # self.weight.data.uniform_(-stdv,stdv) + torch.nn.init.xavier_uniform_(self.weight) + # if self.bias is not None: + # self.bias.data.uniform_(-stdv,stdv) + + def forward(self, input,adj=None,relu=False): + support = torch.matmul(input,self.weight) + # print(support.size(),adj.size()) + if adj is not None: + output = torch.matmul(adj,support) + else: + output = support + # print(output.size()) + if self.bias is not None: + return output + self.bias + else: + if relu: + return F.relu(output) + else: + return output + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + str(self.in_features) + ' -> ' \ + + str(self.out_features) + ')' + +class Featuremaps_to_Graph(nn.Module): + + def __init__(self,input_channels,hidden_layers,nodes=7): + super(Featuremaps_to_Graph, self).__init__() + self.pre_fea = Parameter(torch.FloatTensor(input_channels,nodes)) + self.weight = Parameter(torch.FloatTensor(input_channels,hidden_layers)) + # self.reset_parameters() + + def forward(self, input): + n,c,h,w = input.size() + # print('fea input',input.size()) + input1 = input.view(n,c,h*w) + input1 = input1.transpose(1,2) # n x hw x c + # print('fea input1', input1.size()) + ############## Feature maps to node 
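`GraphConvolution` above is the basic building block: a linear projection of node features, optionally mixed by an adjacency matrix and passed through ReLU. Note that `reset_parameters()` is commented out in `__init__`, so the weight tensor holds uninitialized memory until the caller (or the enclosing model) initializes it. A minimal usage sketch, assuming the repo root is on `PYTHONPATH`:

```python
import torch
from networks.gcn import GraphConvolution

gc = GraphConvolution(128, 128)
torch.nn.init.xavier_uniform_(gc.weight)  # reset_parameters() is not called in __init__

nodes = torch.randn(1, 7, 128)                 # batch x nodes x features
adj = torch.softmax(torch.ones(7, 7), dim=-1)  # illustrative normalized adjacency
out = gc.forward(nodes, adj=adj, relu=True)
print(out.shape)  # torch.Size([1, 7, 128])
```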
################ + fea_node = torch.matmul(input1,self.pre_fea) # n x hw x n_classes + weight_node = torch.matmul(input1,self.weight) # n x hw x hidden_layer + # softmax fea_node + fea_node = F.softmax(fea_node,dim=-1) + # print(fea_node.size(),weight_node.size()) + graph_node = F.relu(torch.matmul(fea_node.transpose(1,2),weight_node)) + return graph_node # n x n_class x hidden_layer + + def reset_parameters(self): + for ww in self.parameters(): + torch.nn.init.xavier_uniform_(ww) + # if self.bias is not None: + # self.bias.data.uniform_(-stdv,stdv) + +class Featuremaps_to_Graph_transfer(nn.Module): + + def __init__(self,input_channels,hidden_layers,nodes=7, source_nodes=20): + super(Featuremaps_to_Graph_transfer, self).__init__() + self.pre_fea = Parameter(torch.FloatTensor(input_channels,nodes)) + self.weight = Parameter(torch.FloatTensor(input_channels,hidden_layers)) + self.pre_fea_transfer = nn.Sequential(*[nn.Linear(source_nodes, source_nodes),nn.LeakyReLU(True), + nn.Linear(source_nodes, nodes), nn.LeakyReLU(True)]) + # self.reset_parameters() + + def forward(self, input, source_pre_fea): + self.pre_fea.data = self.pre_fea_learn(source_pre_fea) + n,c,h,w = input.size() + # print('fea input',input.size()) + input1 = input.view(n,c,h*w) + input1 = input1.transpose(1,2) # n x hw x c + # print('fea input1', input1.size()) + ############## Feature maps to node ################ + fea_node = torch.matmul(input1,self.pre_fea) # n x hw x n_classes + weight_node = torch.matmul(input1,self.weight) # n x hw x hidden_layer + # softmax fea_node + fea_node = F.softmax(fea_node,dim=-1) + # print(fea_node.size(),weight_node.size()) + graph_node = F.relu(torch.matmul(fea_node.transpose(1,2),weight_node)) + return graph_node # n x n_class x hidden_layer + + def pre_fea_learn(self, input): + pre_fea = self.pre_fea_transfer.forward(input.unsqueeze(0)).squeeze(0) + return self.pre_fea.data + pre_fea + +class Graph_to_Featuremaps(nn.Module): + # this is a special version + def __init__(self,input_channels,output_channels,hidden_layers,nodes=7): + super(Graph_to_Featuremaps, self).__init__() + self.node_fea = Parameter(torch.FloatTensor(input_channels+hidden_layers,1)) + self.weight = Parameter(torch.FloatTensor(hidden_layers,output_channels)) + # self.reset_parameters() + + def reset_parameters(self): + for ww in self.parameters(): + torch.nn.init.xavier_uniform_(ww) + + def forward(self, input, res_feature): + ''' + + :param input: 1 x batch x nodes x hidden_layer + :param res_feature: batch x channels x h x w + :return: + ''' + batchi,channeli,hi,wi = res_feature.size() + # print(res_feature.size()) + # print(input.size()) + try: + _,batch,nodes,hidden = input.size() + except: + # print(input.size()) + input = input.unsqueeze(0) + _,batch, nodes, hidden = input.size() + + assert batch == batchi + input1 = input.transpose(0,1).expand(batch,hi*wi,nodes,hidden) + res_feature_after_view = res_feature.view(batch,channeli,hi*wi).transpose(1,2) + res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch,hi*wi,nodes,channeli) + new_fea = torch.cat((res_feature_after_view1,input1),dim=3) + + # print(self.node_fea.size(),new_fea.size()) + new_node = torch.matmul(new_fea, self.node_fea) # batch x hw x nodes x 1 + new_weight = torch.matmul(input, self.weight) # batch x node x channel + new_node = new_node.view(batch, hi*wi, nodes) + feature_out = torch.matmul(new_node,new_weight) + # print(feature_out.size()) + feature_out = feature_out.transpose(2,3).contiguous().view(res_feature.size()) + return 
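`Featuremaps_to_Graph` above turns a feature map into node features in two steps: a softmax over `pre_fea` soft-assigns every pixel to a node, and `weight` projects pixel features into the hidden dimension; each node is then the assignment-weighted sum of the projected pixels. The same math restated with plain tensors (all shapes illustrative):

```python
import torch
import torch.nn.functional as F

n, c, h, w = 1, 256, 32, 32
hidden, nodes = 128, 7

feat = torch.randn(n, c, h, w)
pre_fea = torch.randn(c, nodes)   # pixel -> node assignment weights
weight = torch.randn(c, hidden)   # pixel-feature projection

pixels = feat.view(n, c, h * w).transpose(1, 2)             # n x hw x c
assign = F.softmax(torch.matmul(pixels, pre_fea), dim=-1)   # n x hw x nodes
proj = torch.matmul(pixels, weight)                         # n x hw x hidden
graph_nodes = F.relu(torch.matmul(assign.transpose(1, 2), proj))
print(graph_nodes.shape)  # torch.Size([1, 7, 128]): one hidden vector per node
```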
F.relu(feature_out) + +class Graph_to_Featuremaps_savemem(nn.Module): + # this is a special version for saving gpu memory. The process is same as Graph_to_Featuremaps. + def __init__(self, input_channels, output_channels, hidden_layers, nodes=7): + super(Graph_to_Featuremaps_savemem, self).__init__() + self.node_fea_for_res = Parameter(torch.FloatTensor(input_channels, 1)) + self.node_fea_for_hidden = Parameter(torch.FloatTensor(hidden_layers, 1)) + self.weight = Parameter(torch.FloatTensor(hidden_layers,output_channels)) + # self.reset_parameters() + + def reset_parameters(self): + for ww in self.parameters(): + torch.nn.init.xavier_uniform_(ww) + + def forward(self, input, res_feature): + ''' + + :param input: 1 x batch x nodes x hidden_layer + :param res_feature: batch x channels x h x w + :return: + ''' + batchi,channeli,hi,wi = res_feature.size() + # print(res_feature.size()) + # print(input.size()) + try: + _,batch,nodes,hidden = input.size() + except: + # print(input.size()) + input = input.unsqueeze(0) + _,batch, nodes, hidden = input.size() + + assert batch == batchi + input1 = input.transpose(0,1).expand(batch,hi*wi,nodes,hidden) + res_feature_after_view = res_feature.view(batch,channeli,hi*wi).transpose(1,2) + res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch,hi*wi,nodes,channeli) + # new_fea = torch.cat((res_feature_after_view1,input1),dim=3) + ## sim + new_node1 = torch.matmul(res_feature_after_view1, self.node_fea_for_res) + new_node2 = torch.matmul(input1, self.node_fea_for_hidden) + new_node = new_node1 + new_node2 + ## sim end + # print(self.node_fea.size(),new_fea.size()) + # new_node = torch.matmul(new_fea, self.node_fea) # batch x hw x nodes x 1 + new_weight = torch.matmul(input, self.weight) # batch x node x channel + new_node = new_node.view(batch, hi*wi, nodes) + feature_out = torch.matmul(new_node,new_weight) + # print(feature_out.size()) + feature_out = feature_out.transpose(2,3).contiguous().view(res_feature.size()) + return F.relu(feature_out) + + +class Graph_trans(nn.Module): + + def __init__(self,in_features,out_features,begin_nodes=7,end_nodes=2,bias=False,adj=None): + super(Graph_trans, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter(torch.FloatTensor(in_features,out_features)) + if adj is not None: + h,w = adj.size() + assert (h == end_nodes) and (w == begin_nodes) + self.adj = torch.autograd.Variable(adj,requires_grad=False) + else: + self.adj = Parameter(torch.FloatTensor(end_nodes,begin_nodes)) + if bias: + self.bias = Parameter(torch.FloatTensor(out_features)) + else: + self.register_parameter('bias',None) + # self.reset_parameters() + + def reset_parameters(self): + # stdv = 1./math.sqrt(self.weight(1)) + # self.weight.data.uniform_(-stdv,stdv) + torch.nn.init.xavier_uniform_(self.weight) + # if self.bias is not None: + # self.bias.data.uniform_(-stdv,stdv) + + def forward(self, input, relu=False, adj_return=False, adj=None): + support = torch.matmul(input,self.weight) + # print(support.size(),self.adj.size()) + if adj is None: + adj = self.adj + adj1 = self.norm_trans_adj(adj) + output = torch.matmul(adj1,support) + if adj_return: + output1 = F.normalize(output,p=2,dim=-1) + self.adj_mat = torch.matmul(output1,output1.transpose(-2,-1)) + if self.bias is not None: + return output + self.bias + else: + if relu: + return F.relu(output) + else: + return output + + def get_adj_mat(self): + adj = graph.normalize_adj_torch(F.relu(self.adj_mat)) + return adj + + def 
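`Graph_trans` above re-projects a graph with `begin_nodes` onto one with `end_nodes`: the transfer adjacency is ReLU'd and row-softmaxed by `norm_trans_adj`, then used to mix the linearly projected node features. A usage sketch for a CIHP(20) -> Pascal(7) transfer; the all-ones adjacency is illustrative, and a real run would use something like the `cihp2pascal_adj` matrix defined in `networks/graph.py` (assumes the repo root is on `PYTHONPATH`):

```python
import torch
from networks import gcn

trans = gcn.Graph_trans(in_features=128, out_features=128,
                        begin_nodes=20, end_nodes=7)
torch.nn.init.xavier_uniform_(trans.weight)  # init is commented out in __init__

cihp_nodes = torch.randn(1, 20, 128)  # 20-node source graph
adj = torch.ones(7, 20)               # transfer adjacency: end_nodes x begin_nodes
pascal_nodes = trans.forward(cihp_nodes, adj=adj, relu=True)
print(pascal_nodes.shape)  # torch.Size([1, 7, 128])
```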
get_encode_adj(self): + return self.adj + + def norm_trans_adj(self,adj): # maybe can use softmax + adj = F.relu(adj) + r = F.softmax(adj,dim=-1) + # print(adj.size()) + # row_sum = adj.sum(-1).unsqueeze(-1) + # d_mat = row_sum.expand(adj.size()) + # r = torch.div(row_sum,d_mat) + # r[torch.isnan(r)] = 0 + + return r + + +if __name__ == '__main__': + + graph = torch.randn((7,128)) + pred = (torch.rand((7,7))*7).int() + # a = en.forward(graph,pred) + # print(a.size()) \ No newline at end of file diff --git a/networks/graph.py b/networks/graph.py new file mode 100644 index 0000000..8e49059 --- /dev/null +++ b/networks/graph.py @@ -0,0 +1,261 @@ +import numpy as np +import pickle as pkl +import networkx as nx +import scipy.sparse as sp +import torch + +pascal_graph = {0:[0], + 1:[1, 2], + 2:[1, 2, 3, 5], + 3:[2, 3, 4], + 4:[3, 4], + 5:[2, 5, 6], + 6:[5, 6]} + +cihp_graph = {0: [], + 1: [2, 13], + 2: [1, 13], + 3: [14, 15], + 4: [13], + 5: [6, 7, 9, 10, 11, 12, 14, 15], + 6: [5, 7, 10, 11, 14, 15, 16, 17], + 7: [5, 6, 9, 10, 11, 12, 14, 15], + 8: [16, 17, 18, 19], + 9: [5, 7, 10, 16, 17, 18, 19], + 10:[5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17], + 11:[5, 6, 7, 10, 13], + 12:[5, 7, 10, 16, 17], + 13:[1, 2, 4, 10, 11], + 14:[3, 5, 6, 7, 10], + 15:[3, 5, 6, 7, 10], + 16:[6, 8, 9, 10, 12, 18], + 17:[6, 8, 9, 10, 12, 19], + 18:[8, 9, 16], + 19:[8, 9, 17]} + +atr_graph = {0: [], + 1: [2, 11], + 2: [1, 11], + 3: [11], + 4: [5, 6, 7, 11, 14, 15, 17], + 5: [4, 6, 7, 8, 12, 13], + 6: [4,5,7,8,9,10,12,13], + 7: [4,11,12,13,14,15], + 8: [5,6], + 9: [6, 12], + 10:[6, 13], + 11:[1,2,3,4,7,14,15,17], + 12:[5,6,7,9], + 13:[5,6,7,10], + 14:[4,7,11,16], + 15:[4,7,11,16], + 16:[14,15], + 17:[4,11], + } + +cihp2pascal_adj = np.array([[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) + +cihp2pascal_nlp_adj = \ + np.array([[ 1., 0.35333052, 0.32727194, 0.17418084, 0.18757584, + 0.40608522, 0.37503981, 0.35448462, 0.22598555, 0.23893579, + 0.33064262, 0.28923404, 0.27986573, 0.4211553 , 0.36915778, + 0.41377746, 0.32485771, 0.37248222, 0.36865639, 0.41500332], + [ 0.39615879, 0.46201529, 0.52321467, 0.30826114, 0.25669527, + 0.54747773, 0.3670523 , 0.3901983 , 0.27519473, 0.3433325 , + 0.52728509, 0.32771333, 0.34819325, 0.63882953, 0.68042925, + 0.69368576, 0.63395791, 0.65344337, 0.59538781, 0.6071375 ], + [ 0.16373166, 0.21663339, 0.3053872 , 0.28377612, 0.1372435 , + 0.4448808 , 0.29479995, 0.31092595, 0.22703953, 0.33983576, + 0.75778818, 0.2619818 , 0.37069392, 0.35184867, 0.49877512, + 0.49979437, 0.51853277, 0.52517541, 0.32517741, 0.32377309], + [ 0.32687232, 0.38482461, 0.37693463, 0.41610834, 0.20415749, + 0.76749079, 0.35139853, 0.3787411 , 0.28411737, 0.35155421, + 0.58792618, 0.31141718, 0.40585111, 0.51189218, 0.82042737, + 0.8342413 , 0.70732188, 0.72752501, 0.60327325, 0.61431337], + [ 0.34069369, 0.34817292, 0.37525998, 0.36497069, 0.17841617, + 0.69746208, 0.31731463, 0.34628951, 0.25167277, 0.32072379, + 0.56711286, 0.24894776, 0.37000453, 0.52600859, 0.82483993, + 0.84966274, 0.7033991 , 0.73449378, 0.56649608, 0.58888791], + [ 0.28477487, 0.35139564, 0.42742352, 0.41664321, 0.20004676, + 0.78566833, 0.42237487, 0.41048549, 
0.37933812, 0.46542516, + 0.62444759, 0.3274493 , 0.49466009, 0.49314658, 0.71244233, + 0.71497003, 0.8234787 , 0.83566589, 0.62597135, 0.62626812], + [ 0.3011378 , 0.31775977, 0.42922647, 0.36896257, 0.17597556, + 0.72214655, 0.39162804, 0.38137872, 0.34980296, 0.43818419, + 0.60879174, 0.26762545, 0.46271161, 0.51150476, 0.72318109, + 0.73678399, 0.82620388, 0.84942166, 0.5943811 , 0.60607602]]) + +pascal2atr_nlp_adj = \ + np.array([[ 1., 0.35333052, 0.32727194, 0.18757584, 0.40608522, + 0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639, + 0.41500332, 0.4211553 , 0.32485771, 0.37248222, 0.36915778, + 0.41377746, 0.32006291, 0.28923404], + [ 0.39615879, 0.46201529, 0.52321467, 0.25669527, 0.54747773, + 0.34819325, 0.3433325 , 0.26603942, 0.45162929, 0.59538781, + 0.6071375 , 0.63882953, 0.63395791, 0.65344337, 0.68042925, + 0.69368576, 0.44354613, 0.32771333], + [ 0.16373166, 0.21663339, 0.3053872 , 0.1372435 , 0.4448808 , + 0.37069392, 0.33983576, 0.26563416, 0.35443504, 0.32517741, + 0.32377309, 0.35184867, 0.51853277, 0.52517541, 0.49877512, + 0.49979437, 0.21750868, 0.2619818 ], + [ 0.32687232, 0.38482461, 0.37693463, 0.20415749, 0.76749079, + 0.40585111, 0.35155421, 0.28271333, 0.52684576, 0.60327325, + 0.61431337, 0.51189218, 0.70732188, 0.72752501, 0.82042737, + 0.8342413 , 0.40137029, 0.31141718], + [ 0.34069369, 0.34817292, 0.37525998, 0.17841617, 0.69746208, + 0.37000453, 0.32072379, 0.27268885, 0.47426719, 0.56649608, + 0.58888791, 0.52600859, 0.7033991 , 0.73449378, 0.82483993, + 0.84966274, 0.37830796, 0.24894776], + [ 0.28477487, 0.35139564, 0.42742352, 0.20004676, 0.78566833, + 0.49466009, 0.46542516, 0.32662614, 0.55780359, 0.62597135, + 0.62626812, 0.49314658, 0.8234787 , 0.83566589, 0.71244233, + 0.71497003, 0.41223219, 0.3274493 ], + [ 0.3011378 , 0.31775977, 0.42922647, 0.17597556, 0.72214655, + 0.46271161, 0.43818419, 0.3192333 , 0.50979216, 0.5943811 , + 0.60607602, 0.51150476, 0.82620388, 0.84942166, 0.72318109, + 0.73678399, 0.39259827, 0.26762545]]) + +cihp2atr_nlp_adj = np.array([[ 1., 0.35333052, 0.32727194, 0.18757584, 0.40608522, + 0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639, + 0.41500332, 0.4211553 , 0.32485771, 0.37248222, 0.36915778, + 0.41377746, 0.32006291, 0.28923404], + [ 0.35333052, 1. , 0.39206695, 0.42143438, 0.4736689 , + 0.47139544, 0.51999208, 0.38354847, 0.45628529, 0.46514124, + 0.50083501, 0.4310595 , 0.39371443, 0.4319752 , 0.42938598, + 0.46384034, 0.44833757, 0.6153155 ], + [ 0.32727194, 0.39206695, 1. , 0.32836702, 0.52603065, + 0.39543695, 0.3622627 , 0.43575346, 0.33866223, 0.45202552, + 0.48421 , 0.53669903, 0.47266611, 0.50925436, 0.42286557, + 0.45403656, 0.37221304, 0.40999322], + [ 0.17418084, 0.46892601, 0.25774838, 0.31816231, 0.39330317, + 0.34218382, 0.48253904, 0.22084125, 0.41335728, 0.52437572, + 0.5191713 , 0.33576117, 0.44230914, 0.44250678, 0.44330833, + 0.43887264, 0.50693611, 0.39278795], + [ 0.18757584, 0.42143438, 0.32836702, 1. , 0.35030067, + 0.30110947, 0.41055555, 0.34338879, 0.34336307, 0.37704433, + 0.38810141, 0.34702081, 0.24171562, 0.25433078, 0.24696241, + 0.2570884 , 0.4465962 , 0.45263213], + [ 0.40608522, 0.4736689 , 0.52603065, 0.35030067, 1. 
, + 0.54372584, 0.58300258, 0.56674191, 0.555266 , 0.66599594, + 0.68567555, 0.55716359, 0.62997328, 0.65638548, 0.61219615, + 0.63183318, 0.54464151, 0.44293752], + [ 0.37503981, 0.50675565, 0.4761106 , 0.37561813, 0.60419403, + 0.77912403, 0.64595517, 0.85939662, 0.46037144, 0.52348817, + 0.55875094, 0.37741886, 0.455671 , 0.49434392, 0.38479954, + 0.41804074, 0.47285709, 0.57236283], + [ 0.35448462, 0.50576632, 0.51030446, 0.35841033, 0.55106903, + 0.50257274, 0.52591451, 0.4283053 , 0.39991808, 0.42327211, + 0.42853819, 0.42071825, 0.41240559, 0.42259136, 0.38125352, + 0.3868255 , 0.47604934, 0.51811717], + [ 0.22598555, 0.5053299 , 0.36301185, 0.38002282, 0.49700941, + 0.45625243, 0.62876479, 0.4112051 , 0.33944371, 0.48322639, + 0.50318714, 0.29207815, 0.38801966, 0.41119094, 0.29199072, + 0.31021029, 0.41594871, 0.54961962], + [ 0.23893579, 0.51999208, 0.3622627 , 0.41055555, 0.58300258, + 0.68874251, 1. , 0.56977937, 0.49918447, 0.48484363, + 0.51615925, 0.41222306, 0.49535971, 0.53134951, 0.3807616 , + 0.41050298, 0.48675801, 0.51112664], + [ 0.33064262, 0.306412 , 0.60679935, 0.25592294, 0.58738706, + 0.40379627, 0.39679161, 0.33618385, 0.39235148, 0.45474013, + 0.4648476 , 0.59306762, 0.58976007, 0.60778661, 0.55400397, + 0.56551297, 0.3698029 , 0.33860535], + [ 0.28923404, 0.6153155 , 0.40999322, 0.45263213, 0.44293752, + 0.60359359, 0.51112664, 0.46578181, 0.45656936, 0.38142307, + 0.38525582, 0.33327223, 0.35360175, 0.36156453, 0.3384992 , + 0.34261229, 0.49297863, 1. ], + [ 0.27986573, 0.47139544, 0.39543695, 0.30110947, 0.54372584, + 1. , 0.68874251, 0.67765588, 0.48690078, 0.44010641, + 0.44921156, 0.32321099, 0.48311542, 0.4982002 , 0.39378102, + 0.40297733, 0.45309735, 0.60359359], + [ 0.4211553 , 0.4310595 , 0.53669903, 0.34702081, 0.55716359, + 0.32321099, 0.41222306, 0.25721705, 0.36633509, 0.5397475 , + 0.56429928, 1. , 0.55796926, 0.58842844, 0.57930828, + 0.60410597, 0.41615326, 0.33327223], + [ 0.36915778, 0.42938598, 0.42286557, 0.24696241, 0.61219615, + 0.39378102, 0.3807616 , 0.28089866, 0.48450394, 0.77400821, + 0.68813814, 0.57930828, 0.8856886 , 0.81673412, 1. , + 0.92279623, 0.46969152, 0.3384992 ], + [ 0.41377746, 0.46384034, 0.45403656, 0.2570884 , 0.63183318, + 0.40297733, 0.41050298, 0.332879 , 0.48799542, 0.69231828, + 0.77015091, 0.60410597, 0.79788484, 0.88232104, 0.92279623, + 1. , 0.45685017, 0.34261229], + [ 0.32485771, 0.39371443, 0.47266611, 0.24171562, 0.62997328, + 0.48311542, 0.49535971, 0.32477932, 0.51486622, 0.79353556, + 0.69768738, 0.55796926, 1. , 0.92373745, 0.8856886 , + 0.79788484, 0.47883134, 0.35360175], + [ 0.37248222, 0.4319752 , 0.50925436, 0.25433078, 0.65638548, + 0.4982002 , 0.53134951, 0.38057074, 0.52403969, 0.72035243, + 0.78711147, 0.58842844, 0.92373745, 1. , 0.81673412, + 0.88232104, 0.47109935, 0.36156453], + [ 0.36865639, 0.46514124, 0.45202552, 0.37704433, 0.66599594, + 0.44010641, 0.48484363, 0.39636574, 0.50175258, 1. , + 0.91320249, 0.5397475 , 0.79353556, 0.72035243, 0.77400821, + 0.69231828, 0.59087008, 0.38142307], + [ 0.41500332, 0.50083501, 0.48421, 0.38810141, 0.68567555, + 0.44921156, 0.51615925, 0.45156472, 0.50438158, 0.91320249, + 1., 0.56429928, 0.69768738, 0.78711147, 0.68813814, + 0.77015091, 0.57698754, 0.38525582]]) + + + +def normalize_adj(adj): + """Symmetrically normalize adjacency matrix.""" + adj = sp.coo_matrix(adj) + rowsum = np.array(adj.sum(1)) + d_inv_sqrt = np.power(rowsum, -0.5).flatten() + d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 
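+    # rowsum holds the node degrees, so d_inv_sqrt is the diagonal of D^{-1/2};
+    # zeroing the inf entries keeps isolated (degree-0) nodes finite, and the
+    # return below yields the symmetric normalization D^{-1/2} A D^{-1/2}.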
+ d_mat_inv_sqrt = sp.diags(d_inv_sqrt) + return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() + +def preprocess_adj(adj): + """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation.""" + adj = nx.adjacency_matrix(nx.from_dict_of_lists(adj)) # return a adjacency matrix of adj ( type is numpy) + adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) # + # return sparse_to_tuple(adj_normalized) + return adj_normalized.todense() + +def row_norm(inputs): + outputs = [] + for x in inputs: + xsum = x.sum() + x = x / xsum + outputs.append(x) + return outputs + + +def normalize_adj_torch(adj): + # print(adj.size()) + if len(adj.size()) == 4: + new_r = torch.zeros(adj.size()).type_as(adj) + for i in range(adj.size(1)): + adj_item = adj[0,i] + rowsum = adj_item.sum(1) + d_inv_sqrt = rowsum.pow_(-0.5) + d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0 + d_mat_inv_sqrt = torch.diag(d_inv_sqrt) + r = torch.matmul(torch.matmul(d_mat_inv_sqrt, adj_item), d_mat_inv_sqrt) + new_r[0,i,...] = r + return new_r + rowsum = adj.sum(1) + d_inv_sqrt = rowsum.pow_(-0.5) + d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0 + d_mat_inv_sqrt = torch.diag(d_inv_sqrt) + r = torch.matmul(torch.matmul(d_mat_inv_sqrt,adj),d_mat_inv_sqrt) + return r + +# def row_norm(adj): + + + + +if __name__ == '__main__': + a= row_norm(cihp2pascal_adj) + print(a) + print(cihp2pascal_adj) + # print(a.shape) diff --git a/requirements b/requirements new file mode 100644 index 0000000..ba99367 --- /dev/null +++ b/requirements @@ -0,0 +1,7 @@ +torchvision +scipy +tensorboardX +numpy +opencv-python +matplotlib +networkx \ No newline at end of file diff --git a/sync_batchnorm/__init__.py b/sync_batchnorm/__init__.py new file mode 100644 index 0000000..bc8709d --- /dev/null +++ b/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/sync_batchnorm/batchnorm.py b/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000..5f4e763 --- /dev/null +++ b/sync_batchnorm/batchnorm.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
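+#
+# The layers below compute batch statistics across all replicas rather than per
+# GPU. They only synchronize when the model is wrapped in DataParallelWithCallback
+# (or patched with patch_replication_callback) and is in training mode; otherwise
+# they fall back to the standard F.batch_norm path.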
+ +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. 
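+        # (the sort below orders the per-device messages by the GPU id of their
+        # sum tensor, which gives that consistent order)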
+ # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. 
+ + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/sync_batchnorm/comm.py b/sync_batchnorm/comm.py new file mode 100644 index 0000000..922f8c4 --- /dev/null +++ b/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' 
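+            # store the result and wake the single consumer blocked in get()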
+ self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' 
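+        # results[0] is the master's own message; each remaining entry is pushed
+        # back through the slave pipe that registered it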
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/sync_batchnorm/replicate.py b/sync_batchnorm/replicate.py new file mode 100644 index 0000000..b71c7b8 --- /dev/null +++ b/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. 
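+    The original `replicate` method is kept and wrapped with `functools.wraps`,
+    so the callback execution is simply added on top of the existing behaviour.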
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/sync_batchnorm/unittest.py b/sync_batchnorm/unittest.py new file mode 100644 index 0000000..0675c02 --- /dev/null +++ b/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/train_transfer_cihp.sh b/train_transfer_cihp.sh new file mode 100644 index 0000000..42d2c96 --- /dev/null +++ b/train_transfer_cihp.sh @@ -0,0 +1,2 @@ +python ./exp/transfer/train_cihp_from_pascal.py \ + --batch 24 --gpus 8 --pretrainedModel './pascal_base_trained.pth' \ No newline at end of file diff --git a/train_universal.sh b/train_universal.sh new file mode 100644 index 0000000..f71aff6 --- /dev/null +++ b/train_universal.sh @@ -0,0 +1,3 @@ +python ./exp/universal/pascal_atr_cihp_uni.py \ + --batch 24 --gpus 8 \ + --pretrainedModel './data/pretrained_model/deeplab_v3plus_v3.pth' \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..c83f266 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,5 @@ +from .test_human import get_iou_from_list +import utils + + +__all__ = ['get_iou_from_list','utils'] \ No newline at end of file diff --git a/utils/sampler.py b/utils/sampler.py new file mode 100644 index 0000000..754986e --- /dev/null +++ b/utils/sampler.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import random +import math +from torchvision import transforms +from PIL import Image + +__all__ = ['cusSampler','Sampler_uni'] + +'''common N-pairs sampler''' +def index_dataset(dataset): + ''' + get the index according to the dataset type(e.g. 
pascal or atr or cihp) + :param dataset: + :return: + ''' + return_dict = {} + for i in range(len(dataset)): + tmp_lbl = dataset.datasets_lbl[i] + if tmp_lbl in return_dict: + return_dict[tmp_lbl].append(i) + else : + return_dict[tmp_lbl] = [i] + return return_dict + +def sample_from_class(dataset,class_id): + return dataset[class_id][random.randrange(len(dataset[class_id]))] + +def sampler_npair_K(batch_size,dataset,K=2,label_random_list = [0,0,1,1,2,2,2]): + images_by_class = index_dataset(dataset) + for batch_idx in range(int(math.ceil(len(dataset) * 1.0 / batch_size))): + example_indices = [sample_from_class(images_by_class, class_label_ind) for _ in range(batch_size) + for class_label_ind in [label_random_list[random.randrange(len(label_random_list))]] + ] + yield example_indices[:batch_size] + +def sampler_(images_by_class,batch_size,dataset,K=2,label_random_list = [0,0,1,1,]): + # images_by_class = index_dataset(dataset) + a = label_random_list[random.randrange(len(label_random_list))] + # print(a) + example_indices = [sample_from_class(images_by_class, a) for _ in range(batch_size) + for class_label_ind in [a] + ] + return example_indices[:batch_size] + +class cusSampler(torch.utils.data.sampler.Sampler): + r"""Samples elements randomly from a given list of indices, without replacement. + + Arguments: + indices (sequence): a sequence of indices + """ + + def __init__(self, dataset, batchsize, label_random_list=[0,1,1,1,2,2,2]): + self.images_by_class = index_dataset(dataset) + self.batch_size = batchsize + self.dataset = dataset + self.label_random_list = label_random_list + self.len = int(math.ceil(len(dataset) * 1.0 / batchsize)) + + def __iter__(self): + # return [sample_from_class(self.images_by_class, class_label_ind) for _ in range(self.batchsize) + # for class_label_ind in [self.label_random_list[random.randrange(len(self.label_random_list))]] + # ] + # print(sampler_(self.images_by_class,self.batch_size,self.dataset)) + return iter(sampler_(self.images_by_class,self.batch_size,self.dataset,self.label_random_list)) + + def __len__(self): + return self.len + +def shuffle_cus(d1=20,d2=10,d3=5,batch=2): + return_list = [] + total_num = d1 + d2 + d3 + list1 = list(range(d1)) + batch1 = d1//batch + list2 = list(range(d1,d1+d2)) + batch2 = d2//batch + list3 = list(range(d1+d2,d1+d2+d3)) + batch3 = d3// batch + random.shuffle(list1) + random.shuffle(list2) + random.shuffle(list3) + random_list = list(range(batch1+batch2+batch3)) + random.shuffle(random_list) + for random_batch_index in random_list: + if random_batch_index < batch1: + random_batch_index1 = random_batch_index + return_list += list1[random_batch_index1*batch : (random_batch_index1+1)*batch] + elif random_batch_index < batch1 + batch2: + random_batch_index1 = random_batch_index - batch1 + return_list += list2[random_batch_index1*batch : (random_batch_index1+1)*batch] + else: + random_batch_index1 = random_batch_index - batch1 - batch2 + return_list += list3[random_batch_index1*batch : (random_batch_index1+1)*batch] + return return_list + +def shuffle_cus_balance(d1=20,d2=10,d3=5,batch=2,balance_index=1): + return_list = [] + total_num = d1 + d2 + d3 + list1 = list(range(d1)) + # batch1 = d1//batch + list2 = list(range(d1,d1+d2)) + # batch2 = d2//batch + list3 = list(range(d1+d2,d1+d2+d3)) + # batch3 = d3// batch + random.shuffle(list1) + random.shuffle(list2) + random.shuffle(list3) + total_list = [list1,list2,list3] + target_list = total_list[balance_index] + for index,list_item in enumerate(total_list): + if index == 
balance_index: + continue + if len(list_item) > len(target_list): + list_item = list_item[:len(target_list)] + total_list[index] = list_item + list1 = total_list[0] + list2 = total_list[1] + list3 = total_list[2] + # list1 = list(range(d1)) + d1 = len(list1) + batch1 = d1 // batch + # list2 = list(range(d1, d1 + d2)) + d2 = len(list2) + batch2 = d2 // batch + # list3 = list(range(d1 + d2, d1 + d2 + d3)) + d3 = len(list3) + batch3 = d3 // batch + + random_list = list(range(batch1+batch2+batch3)) + random.shuffle(random_list) + for random_batch_index in random_list: + if random_batch_index < batch1: + random_batch_index1 = random_batch_index + return_list += list1[random_batch_index1*batch : (random_batch_index1+1)*batch] + elif random_batch_index < batch1 + batch2: + random_batch_index1 = random_batch_index - batch1 + return_list += list2[random_batch_index1*batch : (random_batch_index1+1)*batch] + else: + random_batch_index1 = random_batch_index - batch1 - batch2 + return_list += list3[random_batch_index1*batch : (random_batch_index1+1)*batch] + return return_list + +class Sampler_uni(torch.utils.data.sampler.Sampler): + def __init__(self, num1, num2, num3, batchsize,balance_id=None): + self.num1 = num1 + self.num2 = num2 + self.num3 = num3 + self.batchsize = batchsize + self.balance_id = balance_id + + def __iter__(self): + if self.balance_id is not None: + rlist = shuffle_cus_balance(self.num1, self.num2, self.num3, self.batchsize, balance_index=self.balance_id) + else: + rlist = shuffle_cus(self.num1, self.num2, self.num3, self.batchsize) + return iter(rlist) + + + def __len__(self): + if self.balance_id is not None: + return self.num1*3 + return self.num1+self.num2+self.num3 diff --git a/utils/test_human.py b/utils/test_human.py new file mode 100644 index 0000000..6243544 --- /dev/null +++ b/utils/test_human.py @@ -0,0 +1,167 @@ +import os +import numpy as np +from PIL import Image + + +def main(): + image_paths, label_paths = init_path() + hist = compute_hist(image_paths, label_paths) + show_result(hist) + + +def init_path(): + list_file = './human/list/val_id.txt' + file_names = [] + with open(list_file, 'rb') as f: + for fn in f: + file_names.append(fn.strip()) + + image_dir = './human/features/attention/val/results/' + label_dir = './human/data/labels/' + + image_paths = [] + label_paths = [] + for file_name in file_names: + image_paths.append(os.path.join(image_dir, file_name + '.png')) + label_paths.append(os.path.join(label_dir, file_name + '.png')) + return image_paths, label_paths + + +def fast_hist(lbl, pred, n_cls): + ''' + compute the miou + :param lbl: label + :param pred: output + :param n_cls: num of class + :return: + ''' + # print(n_cls) + k = (lbl >= 0) & (lbl < n_cls) + return np.bincount(n_cls * lbl[k].astype(int) + pred[k], minlength=n_cls ** 2).reshape(n_cls, n_cls) + + +def compute_hist(images, labels,n_cls=20): + hist = np.zeros((n_cls, n_cls)) + for img_path, label_path in zip(images, labels): + label = Image.open(label_path) + label_array = np.array(label, dtype=np.int32) + image = Image.open(img_path) + image_array = np.array(image, dtype=np.int32) + + gtsz = label_array.shape + imgsz = image_array.shape + if not gtsz == imgsz: + image = image.resize((gtsz[1], gtsz[0]), Image.ANTIALIAS) + image_array = np.array(image, dtype=np.int32) + + hist += fast_hist(label_array, image_array, n_cls) + + return hist + + +def show_result(hist): + classes = ['background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', + 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 
'scarf', 'skirt', + 'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe', + 'rightShoe'] + # num of correct pixels + num_cor_pix = np.diag(hist) + # num of gt pixels + num_gt_pix = hist.sum(1) + print('=' * 50) + + # @evaluation 1: overall accuracy + acc = num_cor_pix.sum() / hist.sum() + print('>>>', 'overall accuracy', acc) + print('-' * 50) + + # @evaluation 2: mean accuracy & per-class accuracy + print('Accuracy for each class (pixel accuracy):') + for i in range(20): + print('%-15s: %f' % (classes[i], num_cor_pix[i] / num_gt_pix[i])) + acc = num_cor_pix / num_gt_pix + print('>>>', 'mean accuracy', np.nanmean(acc)) + print('-' * 50) + + # @evaluation 3: mean IU & per-class IU + union = num_gt_pix + hist.sum(0) - num_cor_pix + for i in range(20): + print('%-15s: %f' % (classes[i], num_cor_pix[i] / union[i])) + iu = num_cor_pix / (num_gt_pix + hist.sum(0) - num_cor_pix) + print('>>>', 'mean IU', np.nanmean(iu)) + print('-' * 50) + + # @evaluation 4: frequency weighted IU + freq = num_gt_pix / hist.sum() + print('>>>', 'fwavacc', (freq[freq > 0] * iu[freq > 0]).sum()) + print('=' * 50) + +def get_iou(pred,lbl,n_cls): + ''' + need tensor cpu + :param pred: + :param lbl: + :param n_cls: + :return: + ''' + hist = np.zeros((n_cls,n_cls)) + for i,j in zip(range(pred.size(0)),range(lbl.size(0))): + pred_item = pred[i].data.numpy() + lbl_item = lbl[j].data.numpy() + hist += fast_hist(lbl_item, pred_item, n_cls) + # num of correct pixels + num_cor_pix = np.diag(hist) + # num of gt pixels + num_gt_pix = hist.sum(1) + union = num_gt_pix + hist.sum(0) - num_cor_pix + # for i in range(20): + # print('%-15s: %f' % (classes[i], num_cor_pix[i] / union[i])) + iu = num_cor_pix / (num_gt_pix + hist.sum(0) - num_cor_pix) + print('>>>', 'mean IU', np.nanmean(iu)) + miou = np.nanmean(iu) + print('-' * 50) + return miou + +def get_iou_from_list(pred,lbl,n_cls): + ''' + need tensor cpu + :param pred: list + :param lbl: list + :param n_cls: + :return: + ''' + hist = np.zeros((n_cls,n_cls)) + for i,j in zip(range(len(pred)),range(len(lbl))): + pred_item = pred[i].data.numpy() + lbl_item = lbl[j].data.numpy() + # print(pred_item.shape,lbl_item.shape) + hist += fast_hist(lbl_item, pred_item, n_cls) + + # num of correct pixels + num_cor_pix = np.diag(hist) + # num of gt pixels + num_gt_pix = hist.sum(1) + union = num_gt_pix + hist.sum(0) - num_cor_pix + # for i in range(20): + acc = num_cor_pix.sum() / hist.sum() + print('>>>', 'overall accuracy', acc) + print('-' * 50) + # print('%-15s: %f' % (classes[i], num_cor_pix[i] / union[i])) + iu = num_cor_pix / (num_gt_pix + hist.sum(0) - num_cor_pix) + print('>>>', 'mean IU', np.nanmean(iu)) + miou = np.nanmean(iu) + print('-' * 50) + + acc = num_cor_pix / num_gt_pix + print('>>>', 'mean accuracy', np.nanmean(acc)) + print('-' * 50) + + return miou + + +if __name__ == '__main__': + import torch + pred = torch.autograd.Variable(torch.ones((2,1,32,32)).int())*20 + pred2 = torch.autograd.Variable(torch.zeros((2,1, 32, 32)).int()) + # lbl = [torch.zeros((32,32)).int() for _ in range(len(pred))] + get_iou(pred,pred2,7) diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..35c7bb9 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,244 @@ +import os + +import torch +import random +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt + + +def recursive_glob(rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix 
is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + +def get_cityscapes_labels(): + return np.array([ + # [ 0, 0, 0], + [128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [0, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32]]) + +def get_pascal_labels(): + """Load the mapping that associates pascal classes with label colors + Returns: + np.ndarray with dimensions (21, 3) + """ + return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], + [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], + [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], + [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], + [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], + [0, 64, 128]]) + + +def get_mhp_labels(): + """Load the mapping that associates pascal classes with label colors + Returns: + np.ndarray with dimensions (21, 3) + """ + return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], + [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], + [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], + [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], + [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], + [0, 64, 128], # 21 + [96, 0, 0], [0, 96, 0], [96, 96, 0], + [0, 0, 96], [96, 0, 96], [0, 96, 96], [96, 96, 96], + [32, 0, 0], [160, 0, 0], [32, 96, 0], [160, 96, 0], + [32, 0, 96], [160, 0, 96], [32, 96, 96], [160, 96, 96], + [0, 32, 0], [96, 32, 0], [0, 160, 0], [96, 160, 0], + [0, 32, 96], # 41 + [48, 0, 0], [0, 48, 0], [48, 48, 0], + [0, 0, 96], [48, 0, 48], [0, 48, 48], [48, 48, 48], + [16, 0, 0], [80, 0, 0], [16, 48, 0], [80, 48, 0], + [16, 0, 48], [80, 0, 48], [16, 48, 48], [80, 48, 48], + [0, 16, 0], [48, 16, 0], [0, 80, 0], # 59 + + ]) + +def encode_segmap(mask): + """Encode segmentation label images as pascal classes + Args: + mask (np.ndarray): raw segmentation label image of dimension + (M, N, 3), in which the Pascal classes are encoded as colours. + Returns: + (np.ndarray): class map with dimensions (M,N), where the value at + a given location is the integer denoting the class index. + """ + mask = mask.astype(int) + label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) + for ii, label in enumerate(get_pascal_labels()): + label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii + label_mask = label_mask.astype(int) + return label_mask + + +def decode_seg_map_sequence(label_masks, dataset='pascal'): + rgb_masks = [] + for label_mask in label_masks: + rgb_mask = decode_segmap(label_mask, dataset) + rgb_masks.append(rgb_mask) + rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) + return rgb_masks + +def decode_segmap(label_mask, dataset, plot=False): + """Decode segmentation class labels into a color image + Args: + label_mask (np.ndarray): an (M,N) array of integer values denoting + the class label at each spatial location. + plot (bool, optional): whether to show the resulting color image + in a figure. + Returns: + (np.ndarray, optional): the resulting decoded color image. 
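+            Each class index ll is painted with label_colours[ll] and the channels
+            are scaled to [0, 1]; if plot is True the image is shown and nothing
+            is returned.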
+ """ + if dataset == 'pascal': + n_classes = 21 + label_colours = get_pascal_labels() + elif dataset == 'cityscapes': + n_classes = 19 + label_colours = get_cityscapes_labels() + elif dataset == 'mhp': + n_classes = 59 + label_colours = get_mhp_labels() + else: + raise NotImplementedError + + r = label_mask.copy() + g = label_mask.copy() + b = label_mask.copy() + for ll in range(0, n_classes): + r[label_mask == ll] = label_colours[ll, 0] + g[label_mask == ll] = label_colours[ll, 1] + b[label_mask == ll] = label_colours[ll, 2] + rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) + rgb[:, :, 0] = r / 255.0 + rgb[:, :, 1] = g / 255.0 + rgb[:, :, 2] = b / 255.0 + if plot: + plt.imshow(rgb) + plt.show() + else: + return rgb + +def generate_param_report(logfile, param): + log_file = open(logfile, 'w') + for key, val in param.items(): + log_file.write(key + ':' + str(val) + '\n') + log_file.close() + +def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): + n, c, h, w = logit.size() + # logit = logit.permute(0, 2, 3, 1) + target = target.squeeze(1) + if weight is None: + criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index,size_average=size_average) + else: + criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=size_average) + loss = criterion(logit, target.long()) + + return loss + +def cross_entropy2d_dataparallel(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): + n, c, h, w = logit.size() + # logit = logit.permute(0, 2, 3, 1) + target = target.squeeze(1) + if weight is None: + criterion = nn.DataParallel(nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index,size_average=size_average)) + else: + criterion = nn.DataParallel(nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=size_average)) + loss = criterion(logit, target.long()) + + return loss.sum() + +def lr_poly(base_lr, iter_, max_iter=100, power=0.9): + return base_lr * ((1 - float(iter_) / max_iter) ** power) + + +def get_iou(pred, gt, n_classes=21): + total_iou = 0.0 + for i in range(len(pred)): + pred_tmp = pred[i] + gt_tmp = gt[i] + + intersect = [0] * n_classes + union = [0] * n_classes + for j in range(n_classes): + match = (pred_tmp == j) + (gt_tmp == j) + + it = torch.sum(match == 2).item() + un = torch.sum(match > 0).item() + + intersect[j] += it + union[j] += un + + iou = [] + for k in range(n_classes): + if union[k] == 0: + continue + iou.append(intersect[k] / union[k]) + + img_iou = (sum(iou) / len(iou)) + total_iou += img_iou + + return total_iou + +def scale_tensor(input,size=512,mode='bilinear'): + print(input.size()) + # b,h,w = input.size() + _, _, h, w = input.size() + if mode == 'nearest': + if h == 512 and w == 512: + return input + return F.upsample_nearest(input,size=(size,size)) + if h>512 and w > 512: + return F.upsample(input, size=(size,size), mode=mode, align_corners=True) + return F.upsample(input, size=(size,size), mode=mode, align_corners=True) + +def scale_tensor_list(input,): + + output = [] + for i in range(len(input)-1): + output_item = [] + for j in range(len(input[i])): + _, _, h, w = input[-1][j].size() + output_item.append(F.upsample(input[i][j], size=(h,w), mode='bilinear', align_corners=True)) + output.append(output_item) + output.append(input[-1]) + return output + +def scale_tensor_list_0(input,base_input): + + output = [] + assert 
len(input) == len(base_input) + for j in range(len(input)): + _, _, h, w = base_input[j].size() + after_size = F.upsample(input[j], size=(h,w), mode='bilinear', align_corners=True) + base_input[j] = base_input[j] + after_size + # output.append(output_item) + # output.append(input[-1]) + return base_input + +if __name__ == '__main__': + print(lr_poly(0.007,iter_=99,max_iter=150)) \ No newline at end of file