forked from Gaoyiminggithub/Graphonomy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b2ffd28
commit 366b075
Showing
40 changed files
with
7,881 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,111 @@ | ||
# Graphonomy-Universal-Human-Parsing-via-Graph-Transfer-Learning | ||
coming soon. | ||
# Graphonomy: Universal Human Parsing via Graph Transfer Learning | ||
|
||
This repository contains the code for the paper: | ||
|
||
[**Graphonomy: Universal Human Parsing via Graph Transfer Learning**](https://arxiv.org/abs/1904.04536) | ||
,Ke Gong, Yiming Gao, Xiaodan Liang, Xiaohui Shen, Meng Wang, Liang Lin. | ||
|
||
# Environment and installation | ||
+ Pytorch = 0.4.0 | ||
+ torchvision | ||
+ scipy | ||
+ tensorboardX | ||
+ numpy | ||
+ opencv-python | ||
+ matplotlib | ||
+ networkx | ||
|
||
you can install above package by using `pip install -r requirements.txt` | ||
|
||
# Getting Started | ||
### Data Preparation | ||
+ You need to download the human parsing dataset, prepare the images and store in `/data/datasets/dataset_name/`. | ||
We recommend to symlink the path to the dataets to `/data/dataset/` as follows | ||
|
||
``` | ||
# symlink the Pascal-Person-Part dataset for example | ||
ln -s /path_to_Pascal_Person_Part/* data/datasets/pascal/ | ||
``` | ||
+ The file structure should look like: | ||
``` | ||
/Graphonomy | ||
/data | ||
/datasets | ||
/pascal | ||
/JPEGImages | ||
/list | ||
/SegmentationPart | ||
/CIHP_4w | ||
/Images | ||
/lists | ||
... | ||
``` | ||
|
||
### Inference | ||
We provide a simply script to get the visualization result on the CIHP dataset using [trained](https://drive.google.com/file/d/1O9YD4kHgs3w2DUcWxtHiEFyWjCBeS_Vc/view?usp=sharing) | ||
models as follows : | ||
```shell | ||
# Example of inference | ||
python exp/inference/inference.py \ | ||
--loadmodel /path_to_inference_model \ | ||
--img_path ./img/messi.jpg \ | ||
--output_path ./img/ \ | ||
--output_name /output_file_name | ||
``` | ||
|
||
### Training | ||
#### Transfer learning | ||
1. Download the Pascal pretrained model(avaliable soon). | ||
2. Run the `sh train_transfer_cihp.sh`. | ||
3. The results and models are saved in exp/transfer/run/. | ||
4. Evaluation and visualization script is eval_cihp.sh. You only need to change the attribute of `--loadmodel` before you run it. | ||
|
||
#### Universal training | ||
1. Download the [pretrained](https://drive.google.com/file/d/18WiffKnxaJo50sCC9zroNyHjcnTxGCbk/view?usp=sharing) model and store in /data/pretrained_model/. | ||
2. Run the `sh train_universal.sh`. | ||
3. The results and models are saved in exp/universal/run/. | ||
|
||
### Testing | ||
If you want to evaluate the performance of a pre-trained model on PASCAL-Person-Part or CIHP val/test set, | ||
simply run the script: `sh eval_cihp/pascal.sh`. | ||
Specify the specific model. And we provide the final model that you can download and store it in /data/pretrained_model/. | ||
|
||
### Models | ||
**Pascal-Person-Part trained model** | ||
|
||
|Model|Google Cloud|Baidu Yun| | ||
|--------|--------------|-----------| | ||
|Graphonomy(CIHP)| [Download](https://drive.google.com/file/d/1cwEhlYEzC7jIShENNLnbmcBR0SNlZDE6/view?usp=sharing)| Avaliable soon| | ||
|
||
**CIHP trained model** | ||
|
||
|Model|Google Cloud|Baidu Yun| | ||
|--------|--------------|-----------| | ||
|Graphonomy(PASCAL)| [Download](https://drive.google.com/file/d/1O9YD4kHgs3w2DUcWxtHiEFyWjCBeS_Vc/view?usp=sharing)| Avaliable soon| | ||
|
||
**Universal trained model** | ||
|
||
|Model|Google Cloud|Baidu Yun| | ||
|--------|--------------|-----------| | ||
|Universal|Avaliable soon|Avaliable soon| | ||
|
||
### Todo: | ||
- [ ] release pretrained and trained models | ||
- [ ] update universal eval code&script | ||
|
||
# Citation | ||
|
||
``` | ||
@inproceedings{Gong2019Graphonomy, | ||
author = {Ke Gong and Yiming Gao and Xiaodan Liang and Xiaohui Shen and Meng Wang and Liang Lin}, | ||
title = {Graphonomy: Universal Human Parsing via Graph Transfer Learning}, | ||
booktitle = {CVPR}, | ||
year = {2019}, | ||
} | ||
``` | ||
|
||
# Contact | ||
if you have any questions about this repo, please feel free to contact | ||
[[email protected]](mailto:[email protected]). | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from __future__ import print_function, division | ||
import os | ||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
from .mypath_atr import Path | ||
import random | ||
from PIL import ImageFile | ||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
class VOCSegmentation(Dataset): | ||
""" | ||
ATR dataset | ||
""" | ||
|
||
def __init__(self, | ||
base_dir=Path.db_root_dir('atr'), | ||
split='train', | ||
transform=None, | ||
flip=False, | ||
): | ||
""" | ||
:param base_dir: path to ATR dataset directory | ||
:param split: train/val | ||
:param transform: transform to apply | ||
""" | ||
super(VOCSegmentation).__init__() | ||
self._flip_flag = flip | ||
|
||
self._base_dir = base_dir | ||
self._image_dir = os.path.join(self._base_dir, 'JPEGImages') | ||
self._cat_dir = os.path.join(self._base_dir, 'SegmentationClassAug') | ||
self._flip_dir = os.path.join(self._base_dir,'SegmentationClassAug_rev') | ||
|
||
if isinstance(split, str): | ||
self.split = [split] | ||
else: | ||
split.sort() | ||
self.split = split | ||
|
||
self.transform = transform | ||
|
||
_splits_dir = os.path.join(self._base_dir, 'list') | ||
|
||
self.im_ids = [] | ||
self.images = [] | ||
self.categories = [] | ||
self.flip_categories = [] | ||
|
||
for splt in self.split: | ||
with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f: | ||
lines = f.read().splitlines() | ||
|
||
for ii, line in enumerate(lines): | ||
|
||
_image = os.path.join(self._image_dir, line+'.jpg' ) | ||
_cat = os.path.join(self._cat_dir, line +'.png') | ||
_flip = os.path.join(self._flip_dir,line + '.png') | ||
# print(self._image_dir,_image) | ||
assert os.path.isfile(_image) | ||
# print(_cat) | ||
assert os.path.isfile(_cat) | ||
assert os.path.isfile(_flip) | ||
self.im_ids.append(line) | ||
self.images.append(_image) | ||
self.categories.append(_cat) | ||
self.flip_categories.append(_flip) | ||
|
||
|
||
assert (len(self.images) == len(self.categories)) | ||
assert len(self.flip_categories) == len(self.categories) | ||
|
||
# Display stats | ||
print('Number of images in {}: {:d}'.format(split, len(self.images))) | ||
|
||
def __len__(self): | ||
return len(self.images) | ||
|
||
|
||
def __getitem__(self, index): | ||
_img, _target= self._make_img_gt_point_pair(index) | ||
sample = {'image': _img, 'label': _target} | ||
|
||
if self.transform is not None: | ||
sample = self.transform(sample) | ||
|
||
return sample | ||
|
||
def _make_img_gt_point_pair(self, index): | ||
# Read Image and Target | ||
# _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32) | ||
# _target = np.array(Image.open(self.categories[index])).astype(np.float32) | ||
|
||
_img = Image.open(self.images[index]).convert('RGB') # return is RGB pic | ||
if self._flip_flag: | ||
if random.random() < 0.5: | ||
_target = Image.open(self.flip_categories[index]) | ||
_img = _img.transpose(Image.FLIP_LEFT_RIGHT) | ||
else: | ||
_target = Image.open(self.categories[index]) | ||
else: | ||
_target = Image.open(self.categories[index]) | ||
|
||
return _img, _target | ||
|
||
def __str__(self): | ||
return 'ATR(split=' + str(self.split) + ')' | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from __future__ import print_function, division | ||
import os | ||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
from .mypath_cihp import Path | ||
import random | ||
|
||
class VOCSegmentation(Dataset): | ||
""" | ||
CIHP dataset | ||
""" | ||
|
||
def __init__(self, | ||
base_dir=Path.db_root_dir('cihp'), | ||
split='train', | ||
transform=None, | ||
flip=False, | ||
): | ||
""" | ||
:param base_dir: path to CIHP dataset directory | ||
:param split: train/val/test | ||
:param transform: transform to apply | ||
""" | ||
super(VOCSegmentation).__init__() | ||
self._flip_flag = flip | ||
|
||
self._base_dir = base_dir | ||
self._image_dir = os.path.join(self._base_dir, 'Images') | ||
self._cat_dir = os.path.join(self._base_dir, 'Category_ids') | ||
self._flip_dir = os.path.join(self._base_dir,'Category_rev_ids') | ||
|
||
if isinstance(split, str): | ||
self.split = [split] | ||
else: | ||
split.sort() | ||
self.split = split | ||
|
||
self.transform = transform | ||
|
||
_splits_dir = os.path.join(self._base_dir, 'lists') | ||
|
||
self.im_ids = [] | ||
self.images = [] | ||
self.categories = [] | ||
self.flip_categories = [] | ||
|
||
for splt in self.split: | ||
with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f: | ||
lines = f.read().splitlines() | ||
|
||
for ii, line in enumerate(lines): | ||
|
||
_image = os.path.join(self._image_dir, line+'.jpg' ) | ||
_cat = os.path.join(self._cat_dir, line +'.png') | ||
_flip = os.path.join(self._flip_dir,line + '.png') | ||
# print(self._image_dir,_image) | ||
assert os.path.isfile(_image) | ||
# print(_cat) | ||
assert os.path.isfile(_cat) | ||
assert os.path.isfile(_flip) | ||
self.im_ids.append(line) | ||
self.images.append(_image) | ||
self.categories.append(_cat) | ||
self.flip_categories.append(_flip) | ||
|
||
|
||
assert (len(self.images) == len(self.categories)) | ||
assert len(self.flip_categories) == len(self.categories) | ||
|
||
# Display stats | ||
print('Number of images in {}: {:d}'.format(split, len(self.images))) | ||
|
||
def __len__(self): | ||
return len(self.images) | ||
|
||
|
||
def __getitem__(self, index): | ||
_img, _target= self._make_img_gt_point_pair(index) | ||
sample = {'image': _img, 'label': _target} | ||
|
||
if self.transform is not None: | ||
sample = self.transform(sample) | ||
|
||
return sample | ||
|
||
def _make_img_gt_point_pair(self, index): | ||
# Read Image and Target | ||
# _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32) | ||
# _target = np.array(Image.open(self.categories[index])).astype(np.float32) | ||
|
||
_img = Image.open(self.images[index]).convert('RGB') # return is RGB pic | ||
if self._flip_flag: | ||
if random.random() < 0.5: | ||
_target = Image.open(self.flip_categories[index]) | ||
_img = _img.transpose(Image.FLIP_LEFT_RIGHT) | ||
else: | ||
_target = Image.open(self.categories[index]) | ||
else: | ||
_target = Image.open(self.categories[index]) | ||
|
||
return _img, _target | ||
|
||
def __str__(self): | ||
return 'CIHP(split=' + str(self.split) + ')' | ||
|
||
|
||
|
Oops, something went wrong.