
Commit

add DHN and the fucking stage2_b
HiKapok committed Jun 1, 2018
1 parent 50b53e6 commit bc879b6
Showing 13 changed files with 1,228 additions and 1,895 deletions.
9 changes: 5 additions & 4 deletions README.md
@@ -1,6 +1,6 @@
# Hourglass and CPN model in TensorFlow for 2018-FashionAI Key Points Detection of Apparel at TianChi
# Hourglass, DHN and CPN models in TensorFlow for 2018-FashionAI Key Points Detection of Apparel at TianChi

This repository contains codes of the re-implementent of [Stacked Hourglass Networks for Human Pose Estimation](https://arxiv.org/abs/1603.06937) and [Cascaded Pyramid Network for Multi-Person Pose Estimation](https://arxiv.org/abs/1711.07319) in TensorFlow for [FashionAI Global Challenge 2018 - Key Points Detection of Apparel](https://tianchi.aliyun.com/competition/introduction.htm?spm=5176.11409106.5678.1.95b62e48Im9JVH&raceId=231648). The CPN(Cascaded Pyramid Network) here has several different backbones: ResNet50, SE-ResNet50, SE-ResNeXt50, [DetNet](https://arxiv.org/abs/1804.06215) or DetResNeXt50. I have also tried [Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407) to ensemble models on the fly, although limited improvement was achieved.
This repository contains code for re-implementations of [Stacked Hourglass Networks for Human Pose Estimation](https://arxiv.org/abs/1603.06937), [Simple Baselines for Human Pose Estimation and Tracking (Deconvolution Head Network)](https://arxiv.org/abs/1804.06208) and [Cascaded Pyramid Network for Multi-Person Pose Estimation](https://arxiv.org/abs/1711.07319) in TensorFlow for [FashionAI Global Challenge 2018 - Key Points Detection of Apparel](https://tianchi.aliyun.com/competition/introduction.htm?spm=5176.11409106.5678.1.95b62e48Im9JVH&raceId=231648). Both the CPN (Cascaded Pyramid Network) and the DHN (Deconvolution Head Network) come with several different backbones: ResNet50, SE-ResNet50, SE-ResNeXt50, [DetNet](https://arxiv.org/abs/1804.06215) or DetResNeXt50. I have also tried [Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407) to ensemble models on the fly, although only limited improvement was achieved.
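For readers unfamiliar with the weight-averaging idea, the sketch below shows the running average the SWA paper describes; it is a toy illustration with made-up checkpoint values, not the averaging code used in this repository.

```python
import numpy as np

def swa_update(avg_weights, new_weights, n_averaged):
    # Running average of per-layer weights: after folding in n_averaged checkpoints,
    # the next one enters with weight 1 / (n_averaged + 1).
    return [(a * n_averaged + w) / (n_averaged + 1.0)
            for a, w in zip(avg_weights, new_weights)]

# Toy example: three "checkpoints", each a single-layer model.
checkpoints = [[np.array([1.0, 2.0])], [np.array([3.0, 4.0])], [np.array([5.0, 6.0])]]
avg = checkpoints[0]
for n, ckpt in enumerate(checkpoints[1:], start=1):
    avg = swa_update(avg, ckpt, n)
print(avg)  # [array([3., 4.])] -- the element-wise mean of the three checkpoints
```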

The pre-trained models of backbone networks can be found here:

@@ -19,6 +19,7 @@ Almost all the codes was writen by myself and tested under TensorFlow 1.6, Pytho
About the model:

- DetNet is better, performing almost the same as SE-ResNeXt, while SE-ResNet showed little improvement over ResNet
- DHN performs at least as well as CPN, but it lacks thorough testing due to limited time
- Forcing the loss of invisible keypoints to zero gave better performance (see the sketch after this list)
- OHKM (online hard keypoints mining) is useful
- Applying Gaussian blur to the predicted heatmaps hurts performance, but applying Gaussian blur to the target heatmaps of the lower-level predictions helps
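A minimal sketch of the invisible-keypoint masking mentioned above, assuming TF 1.x and a plain MSE heatmap loss; the tensor names and shapes are illustrative, not the variables used in this repository.

```python
import tensorflow as tf

def masked_heatmap_loss(pred_heatmaps, target_heatmaps, visible):
    # pred_heatmaps, target_heatmaps: [batch, height, width, num_keypoints]
    # visible: [batch, num_keypoints], 1.0 for visible/annotated keypoints, 0.0 otherwise
    per_pixel = tf.squared_difference(pred_heatmaps, target_heatmaps)
    # Broadcast the visibility mask over the spatial dimensions so that
    # invisible keypoints contribute exactly zero to the loss.
    mask = visible[:, tf.newaxis, tf.newaxis, :]
    return tf.reduce_mean(per_pixel * mask)
```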
@@ -66,9 +67,9 @@ If you find it's useful to your research or competitions, any contribution or st
- train_2 -> fashionAI_key_points_test_a_20180227.tar
- train_3 -> fashionAI_key_points_test_b_20180418.tgz
- test_0 -> round2_fashionAI_key_points_test_a_20180426.tar
- test_1 -> round2_fashionAI_key_points_test_b_20180601.tar
- test_1 -> round2_fashionAI_key_points_test_b_20180530.zip.zip

- set your local dataset path in [config.py](https://github.com/HiKapok/tf.fashionAI/blob/e90c5b0072338fa638c56ae788f7146d3f36cb1f/config.py#L20)
- set your local dataset path in [config.py](https://github.com/HiKapok/tf.fashionAI/blob/e90c5b0072338fa638c56ae788f7146d3f36cb1f/config.py#L20), and then run convert_tfrecords.py to generate *.tfrecords
- create a folder named 'model' under the root path of your code, download all the pre-trained weights of the backbone networks, and put them into sub-folders named 'resnet50', 'seresnet50' and 'seresnext50'. Then start training (set RECORDS_DATA_DIR and TEST_RECORDS_DATA_DIR according to your [config.py](https://github.com/HiKapok/tf.fashionAI/blob/e90c5b0072338fa638c56ae788f7146d3f36cb1f/config.py#L20)):
```sh
python train_detxt_cpn_onebyone.py --run_on_cloud=False --data_dir=RECORDS_DATA_DIR
```
26 changes: 14 additions & 12 deletions config.py
@@ -20,6 +20,7 @@
DATA_DIR = '../Datasets'
RECORDS_DATA_DIR = '../Datasets/tfrecords'
TEST_RECORDS_DATA_DIR = '../Datasets/tfrecords_test'
TEST_RECORDS_STAGE2 = '../Datasets/tfrecords_test_stage2'

CATEGORIES = ['blouse', 'dress', 'outwear', 'skirt', 'trousers']
SPLITS = ['test_0', 'train_1', 'train_2', 'train_3']#'train_0',
@@ -300,30 +301,31 @@
# {'trousers': 10251, 'skirt': 11649, 'blouse': 11109, 'dress': 9002, 'outwear': 9586} 51597
# warm-up {'trousers': 2795, 'skirt': 2292, 'blouse': 2997, 'dress': 2312, 'outwear': 2138} 12534
# test_a {'trousers': 2631, 'skirt': 2683, 'blouse': 2586, 'dress': 2693, 'outwear': 2508} 13101
# test_b {'outwear': 10906, 'trousers': 10618, 'dress': 11096, 'skirt': 11154, 'blouse': 10670} 54444
split_size = {
'*': {'train': 51597+12534,
'val': 0,
'test': 13101,
'test_a': 9970},
'test': 54444,
'test_a': 13101},
'blouse': {'train': 11109+2997,
'val': 0,
'test': 2586,
'test_a': 1974},
'test': 10670,
'test_a': 2586},
'dress': {'train': 9002+2312,
'val': 0,
'test': 2693,
'test_a': 2052},
'test': 11096,
'test_a': 2693},
'outwear': {'train': 9586+2138,
'val': 0,
'test': 2508,
'test_a': 1947},
'test': 10906,
'test_a': 2508},
'skirt': {'train': 11649+2292,
'val': 0,
'test': 2683,
'test_a': 2051},
'test': 11154,
'test_a': 2683},
'trousers': {'train': 10251+2795,
'val': 0,
'test': 2631,
'test_a': 1946},
'test': 10618,
'test_a': 2631},
}
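One way these counts can be consumed (the helper below is a hypothetical illustration, not code from this repository) is to size a full evaluation pass over the stage-2 test records:

```python
import math

import config

def eval_steps(category, batch_size, split='test'):
    # Number of batches needed to cover every example of the given split once.
    return int(math.ceil(config.split_size[category][split] / float(batch_size)))

# e.g. with a batch size of 10, covering the 10,670 blouse test images takes 1,067 batches
print(eval_steps('blouse', batch_size=10))
```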

31 changes: 19 additions & 12 deletions convert_tfrecords.py
@@ -313,17 +313,24 @@ def count_split_examples(split_path, file_pattern=''):

if __name__ == '__main__':
np.random.seed(RANDOM_SEED)
#convert_test('../Datasets/tfrecords_test_stage1_b', splits=['test_stage1_b'])
os.mkdir(config.RECORDS_DATA_DIR)
convert_train(config.RECORDS_DATA_DIR, val_per=0.)
convert_train(config.RECORDS_DATA_DIR, val_per=0., all_splits=config.WARM_UP_SPLITS, file_idx_start=1000)
os.mkdir(config.TEST_RECORDS_DATA_DIR)
convert_test(config.TEST_RECORDS_DATA_DIR)
print('blouse', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='blouse_0000_val')
, 'outwear', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='outwear_0000_val')
, 'dress', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='dress_0000_val')
, 'skirt', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='skirt_0000_val')
, 'trousers', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='trousers_0000_val')
, 'all', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='val'))
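    # Convert only the round-2 test_b split ('test_1') into the stage-2 TFRecords
    # directory, then print the per-category example counts as a sanity check.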
convert_test(config.TEST_RECORDS_STAGE2, splits=['test_1'])
print('blouse', count_split_examples(config.TEST_RECORDS_STAGE2, file_pattern='blouse')
, 'outwear', count_split_examples(config.TEST_RECORDS_STAGE2, file_pattern='outwear')
, 'dress', count_split_examples(config.TEST_RECORDS_STAGE2, file_pattern='dress')
, 'skirt', count_split_examples(config.TEST_RECORDS_STAGE2, file_pattern='skirt')
, 'trousers', count_split_examples(config.TEST_RECORDS_STAGE2, file_pattern='trousers')
, 'all', count_split_examples(config.TEST_RECORDS_STAGE2, file_pattern='_'))

# os.mkdir(config.RECORDS_DATA_DIR)
# convert_train(config.RECORDS_DATA_DIR, val_per=0.)
# convert_train(config.RECORDS_DATA_DIR, val_per=0., all_splits=config.WARM_UP_SPLITS, file_idx_start=1000)
# os.mkdir(config.TEST_RECORDS_DATA_DIR)
# convert_test(config.TEST_RECORDS_DATA_DIR)
# print('blouse', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='blouse_0000_val')
# , 'outwear', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='outwear_0000_val')
# , 'dress', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='dress_0000_val')
# , 'skirt', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='skirt_0000_val')
# , 'trousers', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='trousers_0000_val')
# , 'all', count_split_examples(config.RECORDS_DATA_DIR, file_pattern='val'))
# test_dataset()

167 changes: 167 additions & 0 deletions depth_conv2d.py
@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# modified from tensorflow/contrib/layers/python/layers/layers.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope

DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
DATA_FORMAT_NCDHW = 'NCDHW'
DATA_FORMAT_NDHWC = 'NDHWC'

def _model_variable_getter(getter,
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
rename=None,
use_resource=None,
**_):
"""Getter that uses model_variable for compatibility with core layers."""
short_name = name.split('/')[-1]
if rename and short_name in rename:
name_components = name.split('/')
name_components[-1] = rename[short_name]
name = '/'.join(name_components)
return variables.model_variable(
name,
shape=shape,
dtype=dtype,
initializer=initializer,
regularizer=regularizer,
collections=collections,
trainable=trainable,
caching_device=caching_device,
partitioner=partitioner,
custom_getter=getter,
use_resource=use_resource)


def _build_variable_getter(rename=None):
"""Build a model variable getter that respects scope getter and renames."""

# VariableScope will nest the getters
def layer_variable_getter(getter, *args, **kwargs):
kwargs['rename'] = rename
return _model_variable_getter(getter, *args, **kwargs)

return layer_variable_getter

def depth_conv2d(
inputs,
kernel_size,
stride=1,
channel_multiplier=1,
padding='SAME',
data_format=DATA_FORMAT_NHWC,
rate=1,
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
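  """Depthwise 2-D convolution in the style of tf.contrib.layers.

  Applies only the depthwise step of a separable convolution (no pointwise
  projection), optionally followed by a normalizer function or a bias add,
  and then an activation function.
  """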

if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
layer_variable_getter = _build_variable_getter({
'bias': 'biases',
'depthwise_kernel': 'depthwise_weights'
})

with variable_scope.variable_scope(
scope,
'SeparableConv2d', [inputs],
reuse=reuse,
custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)

df = ('channels_first'
if data_format and data_format.startswith('NC') else 'channels_last')

# Actually apply depthwise conv instead of separable conv.
dtype = inputs.dtype.base_dtype
kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
stride_h, stride_w = utils.two_element_tuple(stride)
num_filters_in = utils.channel_dimension(
inputs.get_shape(), df, min_rank=4)
weights_collections = utils.get_variable_collections(
variables_collections, 'weights')

depthwise_shape = [kernel_h, kernel_w, num_filters_in, channel_multiplier]
depthwise_weights = variables.model_variable(
'depthwise_weights',
shape=depthwise_shape,
dtype=dtype,
initializer=weights_initializer,
regularizer=weights_regularizer,
trainable=trainable,
collections=weights_collections)
strides = [1, 1, stride_h, stride_w] if data_format.startswith('NC') else [1, stride_h, stride_w, 1]

outputs = nn.depthwise_conv2d(
inputs,
depthwise_weights,
strides,
padding,
rate=utils.two_element_tuple(rate),
data_format=data_format)
num_outputs = num_filters_in

if normalizer_fn is not None:
normalizer_params = normalizer_params or {}
outputs = normalizer_fn(outputs, **normalizer_params)
else:
if biases_initializer is not None:
biases_collections = utils.get_variable_collections(
variables_collections, 'biases')
biases = variables.model_variable(
'biases',
shape=[
num_outputs,
],
dtype=dtype,
initializer=biases_initializer,
regularizer=biases_regularizer,
trainable=trainable,
collections=biases_collections)
outputs = nn.bias_add(outputs, biases, data_format=data_format)

if activation_fn is not None:
outputs = activation_fn(outputs)
return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
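A short usage sketch of the layer defined above, assuming it lives in this depth_conv2d.py module and runs under TF 1.x; the input shape and scope name are arbitrary placeholders.

```python
import tensorflow as tf

from depth_conv2d import depth_conv2d

# A 3x3 depthwise convolution over NHWC feature maps; channel_multiplier=1
# keeps the channel count, and the default activation is ReLU.
features = tf.placeholder(tf.float32, [None, 64, 48, 256], name='features')
outputs = depth_conv2d(features, kernel_size=3, stride=1,
                       padding='SAME', scope='depthwise_3x3')
print(outputs.get_shape().as_list())  # expected: [None, 64, 48, 256]
```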
15 changes: 11 additions & 4 deletions eval_all_cpn_onepass.py
@@ -17,6 +17,7 @@
from __future__ import print_function

import os
import time
import sys
import numpy as np
import pandas as pd
@@ -27,6 +28,7 @@
from net import detxt_cpn
from net import seresnet_cpn
from net import cpn
from net import simple_xt

from utility import train_helper

@@ -48,7 +50,7 @@
'gpu_memory_fraction', 1., 'GPU memory fraction to use.')
# scaffold related configuration
tf.app.flags.DEFINE_string(
'data_dir', '../Datasets/tfrecords_test',#tfrecords_test tfrecords_test_stage1_b
'data_dir', '../Datasets/tfrecords_test_stage2',#tfrecords_test tfrecords_test_stage1_b tfrecords_test_stage2
'The directory where the dataset input data is stored.')
tf.app.flags.DEFINE_string(
'dataset_name', '{}_*.tfrecord', 'The pattern of the dataset name to load.')
@@ -97,7 +99,7 @@
'model_scope', 'blouse',
'Model scope name used to replace the name_scope in checkpoint.')
tf.app.flags.DEFINE_boolean(
'run_on_cloud', True,
'run_on_cloud', False,
'Wether we will train on cloud.')
tf.app.flags.DEFINE_string(
'model_to_eval', 'blouse, dress, outwear, skirt, trousers', #'all, blouse, dress, outwear, skirt, trousers', 'skirt, dress, outwear, trousers',
@@ -106,6 +108,7 @@
#--model_scope=blouse --checkpoint_path=./logs/blouse
FLAGS = tf.app.flags.FLAGS

#print(FLAGS.data_dir)
all_models = {
'resnet50_cpn': {'backbone': cpn.cascaded_pyramid_net, 'logs_sub_dir': 'logs_cpn'},
'detnet50_cpn': {'backbone': detnet_cpn.cascaded_pyramid_net, 'logs_sub_dir': 'logs_detnet_cpn'},
Expand All @@ -116,6 +119,8 @@
'logs_sub_dir': 'logs_large_sext_cpn'},
'large_detnext_cpn': {'backbone': lambda inputs, output_channals, heatmap_size, istraining, data_format : detxt_cpn.cascaded_pyramid_net(inputs, output_channals, heatmap_size, istraining, data_format, net_depth=101),
'logs_sub_dir': 'logs_large_detxt_cpn'},
'simple_net': {'backbone': lambda inputs, output_channals, heatmap_size, istraining, data_format : simple_xt.simple_net(inputs, output_channals, heatmap_size, istraining, data_format, net_depth=101),
'logs_sub_dir': 'logs_simple_net'},
'head_seresnext50_cpn': {'backbone': seresnet_cpn.head_xt_cascaded_pyramid_net, 'logs_sub_dir': 'logs_head_sext_cpn'},
}
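Every entry in this table exposes the same call signature (inputs, output_channals, heatmap_size, istraining, data_format). The snippet below is a hypothetical illustration of looking one up; the image tensor, keypoint count, heatmap size and data format are placeholder values, not the script's real pipeline settings.

```python
# Hypothetical lookup of the newly added DHN-style entry and its log directory.
backbone_fn = all_models['simple_net']['backbone']
logs_sub_dir = all_models['simple_net']['logs_sub_dir']
# Placeholder call showing the shared backbone signature.
images = tf.placeholder(tf.float32, [1, 384, 384, 3])
heatmaps = backbone_fn(images, 24, [96, 96], False, 'channels_last')
```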

@@ -443,10 +448,12 @@ def main(_):
for m in model_to_eval[1:]:
if m == '': continue
df_list.append(pd.read_csv('./{}_{}.csv'.format(FLAGS.backbone.strip(), m), encoding='utf-8'))
pd.concat(df_list, ignore_index=True).to_csv('./{}_sub.csv'.format(FLAGS.backbone.strip()), encoding='utf-8', index=False)

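    # Tag the merged submission file with a timestamp so that repeated evaluation
    # runs do not overwrite earlier results.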
time_stamps = int(time.time())
pd.concat(df_list, ignore_index=True).to_csv('./{}_sub_{}.csv'.format(FLAGS.backbone.strip(), time_stamps), encoding='utf-8', index=False)

if FLAGS.run_on_cloud:
tf.gfile.Copy('./{}_sub.csv'.format(FLAGS.backbone.strip()), os.path.join(full_model_dir, '{}_sub.csv'.format(FLAGS.backbone.strip())), overwrite=True)
tf.gfile.Copy('./{}_sub_{}.csv'.format(FLAGS.backbone.strip(), time_stamps), os.path.join(full_model_dir, '{}_sub_{}.csv'.format(FLAGS.backbone.strip(), time_stamps)), overwrite=True)

if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
