Commit 50b53e6: final submit
HiKapok committed May 30, 2018
1 parent e90c5b0 commit 50b53e6
Showing 15 changed files with 282 additions and 837 deletions.
64 changes: 61 additions & 3 deletions README.md
@@ -8,11 +8,23 @@
The pre-trained models of backbone networks can be found here:
- [SE-ResNet50](https://github.com/HiKapok/TF_Se_ResNe_t)
- [SE-ResNeXt50](https://github.com/HiKapok/TF_Se_ResNe_t)

## Introduction

The main goal of this competition is to detect keypoints in clothing images collected from Alibaba's e-commerce platforms. There are tens of thousands of images across five categories: blouse, outwear, trousers, skirt, and dress. The keypoints for each category are defined as follows.

![](demos/outline.jpg "The Keypoints for Each Category")

- All the code was written by myself and tested under TensorFlow 1.6, Python 3.5, and Ubuntu 16.04. I tried to follow TensorFlow's latest best-practice paradigms, such as [tf.estimator](https://www.tensorflow.org/api_docs/python/tf/estimator) and [tf.layers](https://www.tensorflow.org/api_docs/python/tf/layers). Almost no py_func ops were used in my code, to maximize performance. Augmentations such as flipping, rotation, random cropping, and color distortion were used to reduce overfitting. The current model reaches ~0.4% Normalized Error, which placed roughly 20th in the second stage of the competition.
+ Almost all the code was written by myself and tested under TensorFlow 1.6, Python 3.5, and Ubuntu 16.04. I tried to follow TensorFlow's latest best-practice paradigms, such as [tf.estimator](https://www.tensorflow.org/api_docs/python/tf/estimator) and [tf.layers](https://www.tensorflow.org/api_docs/python/tf/layers). Almost no py_func ops were used in my code, to maximize performance. Augmentations such as flipping, rotation, random cropping, and color distortion were used to reduce overfitting. The current model reaches ~0.4% Normalized Error, which placed roughly 20th in the second stage of the competition.
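
As a quick illustration of that paradigm (not this repo's actual `model_fn`, which is far more involved; every name and hyperparameter below is made up), a keypoint-heatmap estimator skeleton looks roughly like this:

```python
import tensorflow as tf  # TF 1.x, matching the repo

def model_fn(features, labels, mode, params):
    # stand-in for the real backbone + cascaded pyramid head
    heatmaps = tf.layers.conv2d(features, params['num_keypoints'], 1, name='head')
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={'heatmaps': heatmaps})
    loss = tf.losses.mean_squared_error(labels, heatmaps)
    train_op = tf.train.AdamOptimizer(params['lr']).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='./model/demo',
    params={'num_keypoints': 24, 'lr': 1e-3})
```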

About the model:

- DetNet is better; it performs almost the same as SEResNeXt, while SEResNet showed little improvement over ResNet
- Forcing the loss of invisible keypoints to zero gave better performance (see the sketch after this list)
- OHKM (Online Hard Keypoint Mining) is useful
- Gaussian-blurring the predicted heatmaps hurts performance, but Gaussian-blurring the target heatmaps of the lower-level predictions helps
- Ensembling the heatmaps of flipped images is worse than ensembling the decoded predictions of flipped images, and applying a quarter-pixel offset correction is also useful
- Running cascaded prediction through the whole network removes the need for a separate clothes-detection network or for larger input images
- The vanilla hourglass model performed worst here but still has great potential; see the top solutions [here](http://human-pose.mpi-inf.mpg.de/#results)
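
To make the loss-masking, OHKM, quarter-offset, and flip-ensemble points concrete, here is a minimal sketch. It is illustrative code in the spirit of the repo, not its actual implementation; every name (`ohkm_heatmap_loss`, `decode_keypoint`, `flip_ensemble`, `lr_pairs`, `topk`) is hypothetical.

```python
import numpy as np
import tensorflow as tf  # TF 1.x, as used by this repo

def ohkm_heatmap_loss(pred, target, visible, topk=8):
    """Per-keypoint MSE with the loss of invisible keypoints forced to zero,
    plus OHKM: only the topk hardest keypoints contribute to the loss.
    pred, target: [batch, num_keypoints, h, w]; visible: [batch, num_keypoints]."""
    per_kp = tf.reduce_mean(tf.squared_difference(pred, target), axis=[2, 3])
    per_kp = per_kp * visible                 # zero out invisible keypoints
    hardest, _ = tf.nn.top_k(per_kp, k=topk)  # OHKM: keep the hardest keypoints
    return tf.reduce_mean(hardest)

def decode_keypoint(heatmap):
    """Argmax decoding with a quarter-pixel offset toward the neighbouring
    bin with the higher response."""
    h, w = heatmap.shape
    y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    fx, fy = float(x), float(y)
    if 0 < x < w - 1:
        fx += 0.25 * np.sign(heatmap[y, x + 1] - heatmap[y, x - 1])
    if 0 < y < h - 1:
        fy += 0.25 * np.sign(heatmap[y + 1, x] - heatmap[y - 1, x])
    return fx, fy

def flip_ensemble(coords, coords_flipped, width, lr_pairs):
    """Average the decoded predictions of the original and the horizontally
    flipped image (rather than averaging their heatmaps).
    coords, coords_flipped: [num_keypoints, 2] arrays of (x, y)."""
    back = coords_flipped.copy()
    back[:, 0] = width - 1 - back[:, 0]  # undo the horizontal flip
    for l, r in lr_pairs:                # swap left/right keypoint indices
        back[[l, r]] = back[[r, l]]
    return (coords + back) / 2.0
```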

There are other ways to further improve performance, but I didn't try them in this competition because of their limited applicability, for example:

@@ -25,10 +37,56 @@

If you find this repo useful for your research or competitions, any contribution or star is welcome.

By the way, I'm currently looking for a computer-vision-related job. I look forward to hearing from you if you are interested.
## Usage
- Download [fashionAI Dataset](https://tianchi.aliyun.com/competition/information.htm?spm=5176.11165261.5678.2.34b72ec5iFguTn&raceId=231648&_lang=en_US) and reorganize the directory as follows:
```
DATA_DIR/
|->train_0/
| |->Annotations/
| | |->annotations.csv
| |->Images/
| | |->blouse
| | |->...
|->train_1/
| |->Annotations/
| | |->annotations.csv
| |->Images/
| | |->blouse
| | |->...
|->...
|->test_0/
| |->test.csv
| |->Images/
| | |->blouse
| | |->...
```
DATA_DIR is the root path of your fashionAI dataset.
- train_0 -> [update] warm_up_train_20180222.tar
- train_1 -> fashionAI_key_points_train_20180227.tar.gz
- train_2 -> fashionAI_key_points_test_a_20180227.tar
- train_3 -> fashionAI_key_points_test_b_20180418.tgz
- test_0 -> round2_fashionAI_key_points_test_a_20180426.tar
- test_1 -> round2_fashionAI_key_points_test_b_20180601.tar

- Set your local dataset path in [config.py](https://github.com/HiKapok/tf.fashionAI/blob/e90c5b0072338fa638c56ae788f7146d3f36cb1f/config.py#L20).
- Create a folder named 'model' under the root path of the code, download all the pre-trained weights of the backbone networks, and put them into sub-folders named 'resnet50', 'seresnet50' and 'seresnext50'. Then start training (set RECORDS_DATA_DIR and TEST_RECORDS_DATA_DIR according to your [config.py](https://github.com/HiKapok/tf.fashionAI/blob/e90c5b0072338fa638c56ae788f7146d3f36cb1f/config.py#L20)):
```sh
python train_detxt_cpn_onebyone.py --run_on_cloud=False --data_dir=RECORDS_DATA_DIR
python eval_all_cpn_onepass.py --run_on_cloud=False --backbone=detnext50_cpn --data_dir=TEST_RECORDS_DATA_DIR
```
Submitting the generated 'detnext50_cpn_sub.csv' should give you ~0.0427.
```sh
python train_senet_cpn_onebyone.py --run_on_cloud=False --data_dir=RECORDS_DATA_DIR
python eval_all_cpn_onepass.py --run_on_cloud=False --backbone=seresnext50_cpn --data_dir=TEST_RECORDS_DATA_DIR
```
Submitting the generated 'seresnext50_cpn_sub.csv' should give you ~0.0424.

Copy both 'detnext50_cpn_sub.csv' and 'seresnext50_cpn_sub.csv' to a new folder and modify the path and filename in [ensemble_from_csv.py](https://github.com/HiKapok/tf.fashionAI/blob/e90c5b0072338fa638c56ae788f7146d3f36cb1f/ensemble_from_csv.py#L27) accordingly, then run 'python ensemble_from_csv.py'. Submitting the generated 'ensmeble.csv' should give you ~0.0407. A simplified sketch of this coordinate averaging follows the list below.
- Training deeper backbone networks gives slightly better results (~0.001 improvement).
- Training the hourglass model follows almost the same procedure as above but gave inferior performance.
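
For reference, here is a minimal sketch of what the CSV ensembling step does conceptually: average the predicted coordinates of two submissions. It assumes the FashionAI submission format, where every keypoint cell is an 'x_y_v' string ('-1_-1_-1' when absent); the real ensemble_from_csv.py differs in its details, and `average_submissions` is a hypothetical name.

```python
import pandas as pd

def average_submissions(csv_a, csv_b, out_csv):
    df_a = pd.read_csv(csv_a, encoding='utf-8')
    df_b = pd.read_csv(csv_b, encoding='utf-8')
    merged = df_a.copy()
    kp_cols = [c for c in df_a.columns if c not in ('image_id', 'image_category')]
    for col in kp_cols:
        for i in range(len(df_a)):
            xa, ya, va = (int(v) for v in str(df_a.at[i, col]).split('_'))
            xb, yb, vb = (int(v) for v in str(df_b.at[i, col]).split('_'))
            if va >= 0 and vb >= 0:  # both models produced a prediction
                merged.at[i, col] = '{}_{}_{}'.format((xa + xb) // 2, (ya + yb) // 2, va)
    merged.to_csv(out_csv, encoding='utf-8', index=False)

average_submissions('detnext50_cpn_sub.csv', 'seresnext50_cpn_sub.csv', 'ensmeble.csv')
```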

## ##
- Some Detection Results:
+ Some Detection Results (stage one):

- Cascaded Pyramid Network:

2 changes: 1 addition & 1 deletion ensemble_from_csv.py
@@ -32,7 +32,7 @@
# 'sub_2_hg_4_256_64-half_epoch.csv',
# 'sub_2_hg_8_256_64_v1-half_epoch.csv']#['cpn_2_320_160_1e-3.csv', 'sub_2_hg_4_256_64.csv', 'sub_2_cpn_320_100_1e-3.csv', 'sub_2_hg_8_256_64.csv']

-ensemble_subs = ['sext_cpn_flip.csv', 'detxt_cpn_flip.csv']
+ensemble_subs = ['large_seresnext_cpn_sub.csv', 'large_detnext_cpn_sub.csv']


def parse_comma_list(args):
29 changes: 15 additions & 14 deletions eval_all_cpn_onepass.py
@@ -112,9 +112,9 @@
'seresnet50_cpn': {'backbone': seresnet_cpn.cascaded_pyramid_net, 'logs_sub_dir': 'logs_se_cpn'},
'seresnext50_cpn': {'backbone': seresnet_cpn.xt_cascaded_pyramid_net, 'logs_sub_dir': 'logs_sext_cpn'},
'detnext50_cpn': {'backbone': detxt_cpn.cascaded_pyramid_net, 'logs_sub_dir': 'logs_detxt_cpn'},
-'large_seresnext_cpn': {'backbone': lambda inputs, output_channals, heatmap_size, istraining, data_format : seresnet_cpn.xt_cascaded_pyramid_net(inputs, output_channals, heatmap_size, istraining, data_format, net_depth=50),
+'large_seresnext_cpn': {'backbone': lambda inputs, output_channals, heatmap_size, istraining, data_format : seresnet_cpn.xt_cascaded_pyramid_net(inputs, output_channals, heatmap_size, istraining, data_format, net_depth=101),
'logs_sub_dir': 'logs_large_sext_cpn'},
-'large_detnext_cpn': {'backbone': lambda inputs, output_channals, heatmap_size, istraining, data_format : detxt_cpn.cascaded_pyramid_net(inputs, output_channals, heatmap_size, istraining, data_format, net_depth=50),
+'large_detnext_cpn': {'backbone': lambda inputs, output_channals, heatmap_size, istraining, data_format : detxt_cpn.cascaded_pyramid_net(inputs, output_channals, heatmap_size, istraining, data_format, net_depth=101),
'logs_sub_dir': 'logs_large_detxt_cpn'},
'head_seresnext50_cpn': {'backbone': seresnet_cpn.head_xt_cascaded_pyramid_net, 'logs_sub_dir': 'logs_head_sext_cpn'},
}
@@ -164,7 +164,7 @@ def save_image_with_heatmap(image, height, width, heatmap_size, heatmap, predict
imsave(os.path.join(config.EVAL_DEBUG_DIR, file_name), img.astype(np.uint8))
return save_image_with_heatmap.counter

-def get_keypoint(image, predictions, heatmap_size, height, width, category, clip_at_zero=True, data_format='channels_last', name=None):
+def get_keypoint(image, predictions, heatmap_size, height, width, category, clip_at_zero=False, data_format='channels_last', name=None):
# expand_border = 10
# pad_pred = tf.pad(predictions, tf.constant([[0, 0], [0, 0], [expand_border, expand_border], [expand_border, expand_border]]),
# mode='CONSTANT', name='pred_padding', constant_values=0)
@@ -242,7 +242,7 @@ def keypoint_model_fn(features, labels, mode, params):
if params['data_format'] == 'channels_last':
pred_outputs = [tf.transpose(pred_outputs[ind], [0, 3, 1, 2], name='outputs_trans_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]

-pred_x_first_stage, pred_y_first_stage = get_keypoint(image, pred_outputs[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])
+pred_x_first_stage, pred_y_first_stage = get_keypoint(image, pred_outputs[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=False, data_format=params['data_format'])
else:
# test augumentation on the fly
if params['data_format'] == 'channels_last':
@@ -270,8 +270,8 @@ def cond_flip(heatmap_ind):
pred_outputs = [tf.split(_, 2) for _ in pred_outputs]
pred_outputs_1 = [_[0] for _ in pred_outputs]
pred_outputs_2 = [_[1] for _ in pred_outputs]
-pred_x_first_stage1, pred_y_first_stage1 = get_keypoint(image, pred_outputs_1[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])
-pred_x_first_stage2, pred_y_first_stage2 = get_keypoint(image, pred_outputs_2[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])
+pred_x_first_stage1, pred_y_first_stage1 = get_keypoint(image, pred_outputs_1[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=False, data_format=params['data_format'])
+pred_x_first_stage2, pred_y_first_stage2 = get_keypoint(image, pred_outputs_2[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=False, data_format=params['data_format'])

dist = tf.pow(tf.pow(pred_x_first_stage1 - pred_x_first_stage2, 2.) + tf.pow(pred_y_first_stage1 - pred_y_first_stage2, 2.), .5)

@@ -318,7 +318,7 @@ def cond_flip(heatmap_ind):
if params['data_format'] == 'channels_last':
pred_outputs = [tf.transpose(pred_outputs[ind], [0, 3, 1, 2], name='outputs_trans_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]

-pred_x, pred_y = get_keypoint(image, pred_outputs[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])
+pred_x, pred_y = get_keypoint(image, pred_outputs[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=False, data_format=params['data_format'])
else:
# test augumentation on the fly
with tf.name_scope("refine_prediction"):
@@ -347,8 +347,9 @@ def cond_flip(heatmap_ind):
pred_outputs = [tf.split(_, 2) for _ in pred_outputs]
pred_outputs_1 = [_[0] for _ in pred_outputs]
pred_outputs_2 = [_[1] for _ in pred_outputs]
-pred_x_first_stage1, pred_y_first_stage1 = get_keypoint(image, pred_outputs_1[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])
-pred_x_first_stage2, pred_y_first_stage2 = get_keypoint(image, pred_outputs_2[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])
+#pred_outputs_1[-1] = tf.Print(pred_outputs_1[-1], [pred_outputs_1[-1]], summarize=10000)
+pred_x_first_stage1, pred_y_first_stage1 = get_keypoint(image, pred_outputs_1[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=False, data_format=params['data_format'])
+pred_x_first_stage2, pred_y_first_stage2 = get_keypoint(image, pred_outputs_2[-1], params['heatmap_size'], shape[0][0], shape[0][1], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=False, data_format=params['data_format'])

dist = tf.pow(tf.pow(pred_x_first_stage1 - pred_x_first_stage2, 2.) + tf.pow(pred_y_first_stage1 - pred_y_first_stage2, 2.), .5)

@@ -435,17 +436,17 @@ def main(_):
#Images/blouse/ab669925e96490ec698af976586f0b2f.jpg
df.loc[cur_record] = [filename, m] + temp_list
cur_record = cur_record + 1
-df.to_csv('./{}.csv'.format(m), encoding='utf-8', index=False)
+df.to_csv('./{}_{}.csv'.format(FLAGS.backbone.strip(), m), encoding='utf-8', index=False)

# merge dataframe
-df_list = [pd.read_csv('./{}.csv'.format(model_to_eval[0]), encoding='utf-8')]
+df_list = [pd.read_csv('./{}_{}.csv'.format(FLAGS.backbone.strip(), model_to_eval[0]), encoding='utf-8')]
for m in model_to_eval[1:]:
if m == '': continue
-df_list.append(pd.read_csv('./{}.csv'.format(m), encoding='utf-8'))
-pd.concat(df_list, ignore_index=True).to_csv('./sub.csv', encoding='utf-8', index=False)
+df_list.append(pd.read_csv('./{}_{}.csv'.format(FLAGS.backbone.strip(), m), encoding='utf-8'))
+pd.concat(df_list, ignore_index=True).to_csv('./{}_sub.csv'.format(FLAGS.backbone.strip()), encoding='utf-8', index=False)

if FLAGS.run_on_cloud:
-tf.gfile.Copy('./sub.csv', os.path.join(full_model_dir, 'sub.csv'), overwrite=True)
+tf.gfile.Copy('./{}_sub.csv'.format(FLAGS.backbone.strip()), os.path.join(full_model_dir, '{}_sub.csv'.format(FLAGS.backbone.strip())), overwrite=True)

if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
