Skip to content

Commit f4cbce9

Browse files
committed
新增cornernet
1 parent 522040b commit f4cbce9

File tree

14 files changed

+1791
-7
lines changed

14 files changed

+1791
-7
lines changed

README.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ mmdetection无疑是非常优异的目标检测框架,但是其整个框架代
5757
- [x] sabl
5858
- [x] reppoints
5959
- [x] reppointsv2
60+
- [ ] cornernet
6061

6162

6263
## 4 模型仓库
@@ -180,7 +181,9 @@ python image_demo.py demo.jpg ../configs/retinanet/retinanet_r50_fpn_coco.py ../
180181
[第十七篇:mmdetection最小复刻版(十七):语义导向anchor生成](https://www.zybuluo.com/huanghaian/note/1753795)
181182
[第十八篇:mmdetection最小复刻版(十八):Side-Aware边界框定位](https://www.zybuluo.com/huanghaian/note/1753776)
182183
[第十九篇:mmdetection最小复刻版(十九):点集表示法RepPoints](https://www.zybuluo.com/huanghaian/note/1754350)
183-
[第二十篇:mmdetection最小复刻版(二十):加入验证任务的RepPointsV2](https://www.zybuluo.com/huanghaian/note/1754857)
184+
[第二十篇:mmdetection最小复刻版(二十):加入验证任务的RepPointsV2](https://www.zybuluo.com/huanghaian/note/1754857)
185+
[第二十一篇:mmdetection最小复刻版(二十一):关键点检测思路CornerNet分析](https://www.zybuluo.com/huanghaian/note/1755495)
186+
184187

185188
## other
186189

configs/cornernet/README.md

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# CornerNet
2+
3+
## Introduction
4+
```
5+
@inproceedings{law2018cornernet,
6+
title={Cornernet: Detecting objects as paired keypoints},
7+
author={Law, Hei and Deng, Jia},
8+
booktitle={15th European Conference on Computer Vision, ECCV 2018},
9+
pages={765--781},
10+
year={2018},
11+
organization={Springer Verlag}
12+
}
13+
```
14+
15+
## Results and models
16+
17+
| Backbone | Batch Size | Step/Total Epochs | Mem (GB) | Inf time (fps) | box AP | Model | Log |
18+
| :-------------: | :--------: |:----------------: | :------: | :------------: | :----: | :------: | :------: |
19+
| HourglassNet-104 | [10 x 5](./cornernet_hourglass104_mstest_10x5_210e_coco.py) | 180/210 | 13.9 | 4.2 | 41.2 | [model](http://download.openmmlab.com/mmdetection/v2.0/cornernet/cornernet_hourglass104_mstest_10x5_210e_coco/cornernet_hourglass104_mstest_10x5_210e_coco_20200824_185720-5fefbf1c.pth) | [log](http://download.openmmlab.com/mmdetection/v2.0/cornernet/cornernet_hourglass104_mstest_10x5_210e_coco/cornernet_hourglass104_mstest_10x5_210e_coco_20200824_185720.log.json) |
20+
| HourglassNet-104 | [8 x 6](./cornernet_hourglass104_mstest_8x6_210e_coco.py) | 180/210 | 15.9 | 4.2 | 41.2 | [model](http://download.openmmlab.com/mmdetection/v2.0/cornernet/cornernet_hourglass104_mstest_8x6_210e_coco/cornernet_hourglass104_mstest_8x6_210e_coco_20200825_150618-79b44c30.pth) | [log](http://download.openmmlab.com/mmdetection/v2.0/cornernet/cornernet_hourglass104_mstest_8x6_210e_coco/cornernet_hourglass104_mstest_8x6_210e_coco_20200825_150618.log.json) |
21+
| HourglassNet-104 | [32 x 3](./cornernet_hourglass104_mstest_32x3_210e_coco.py) | 180/210 | 9.5 | 3.9 | 40.4 | [model](http://download.openmmlab.com/mmdetection/v2.0/cornernet/cornernet_hourglass104_mstest_32x3_210e_coco/cornernet_hourglass104_mstest_32x3_210e_coco_20200819_203110-1efaea91.pth) | [log](http://download.openmmlab.com/mmdetection/v2.0/cornernet/cornernet_hourglass104_mstest_32x3_210e_coco/cornernet_hourglass104_mstest_32x3_210e_coco_20200819_203110.log.json) |
22+
23+
Note:
24+
- TTA setting is single-scale and `flip=True`.
25+
- Experiments with `images_per_gpu=6` are conducted on Tesla V100-SXM2-32GB, and those with `images_per_gpu=3` are conducted on GeForce GTX 1080 Ti.
26+
- Here are the descriptions of each experiment setting:
27+
- 10 x 5: 10 GPUs with 5 images per gpu. This is the same setting as that reported in the original paper.
28+
- 8 x 6: 8 GPUs with 6 images per gpu. The total batch size is similar to that of the paper, and training needs only 1 node.
29+
- 32 x 3: 32 GPUs with 3 images per gpu. This is the default setting for the 1080 Ti and needs 4 nodes to train.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# CornerNet config: Hourglass backbone + CornerHead, trained for 210 epochs
# with a single LR step at epoch 180.
# Settings not overridden here come from the two base configs below.
_base_ = [
    '../_base_/default_runtime.py', '../_base_/datasets/coco_detection.py'
]

# model settings
model = dict(
    type='CornerNet',
    backbone=dict(
        type='HourglassNet',
        downsample_times=5,
        # Two stacked hourglass modules; the HourglassNet docstring calls
        # this variant Hourglass-104.
        num_stacks=2,
        # Both lists need more than `downsample_times` entries and must have
        # equal length (asserted in HourglassNet.__init__).
        stage_channels=[256, 256, 384, 384, 384, 512],
        stage_blocks=[2, 2, 2, 2, 2, 4],
        norm_cfg=dict(type='BN', requires_grad=True)),
    neck=None,
    bbox_head=dict(
        type='CornerHead',
        num_classes=80,
        # Equals HourglassNet's default `feat_channel` (256).
        in_channels=256,
        # HourglassNet.forward returns one feature map per stack, so this
        # equals `num_stacks` above.
        num_feat_levels=2,
        corner_emb_channels=1,
        loss_heatmap=dict(
            type='GaussianFocalLoss', alpha=2.0, gamma=4.0, loss_weight=1),
        loss_embedding=dict(
            type='AssociativeEmbeddingLoss',
            pull_weight=0.10,
            push_weight=0.10),
        loss_offset=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1)))
# data settings
# Normalization constants shared by the crop/pad and Normalize transforms.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PhotoMetricDistortion',
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18),
    dict(
        type='RandomCenterCropPad',
        crop_size=(511, 511),
        ratios=(0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3),
        test_mode=False,
        test_pad_mode=None,
        **img_norm_cfg),
    dict(type='Resize', img_scale=(511, 511), keep_ratio=False),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
# Test-time pipeline: original scale (scale_factor=1.0) plus flip augmentation.
test_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(
        type='MultiScaleFlipAug',
        scale_factor=1.0,
        flip=True,
        transforms=[
            dict(type='Resize'),
            dict(
                type='RandomCenterCropPad',
                crop_size=None,
                ratios=None,
                border=None,
                test_mode=True,
                test_pad_mode=['logical_or', 127],
                **img_norm_cfg),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(
                type='Collect',
                keys=['img'],
                # 'border' is collected in addition to the usual meta keys.
                # NOTE(review): presumably needed to map predictions back
                # from the padded image -- confirm against CornerHead.
                meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape',
                           'scale_factor', 'flip', 'img_norm_cfg', 'border')),
        ])
]
data = dict(
    # 3 images per GPU; presumably the "32 x 3" setting described in the
    # CornerNet README -- TODO confirm against the config file name.
    samples_per_gpu=3,
    # NOTE(review): 0 data-loading workers -- confirm this is intentional and
    # not a debugging leftover.
    workers_per_gpu=0,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
# training and testing settings
train_cfg = None
test_cfg = dict(
    corner_topk=100,
    local_maximum_kernel=3,
    distance_threshold=0.5,
    score_thr=0.05,
    max_per_img=100,
    nms_cfg=dict(type='soft_nms', iou_threshold=0.5, method='gaussian'))
# optimizer
optimizer = dict(type='Adam', lr=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
# Linear warmup for the first 500 iterations, then one step at epoch 180 of
# 210 (the "180/210" schedule listed in the README table).
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[180])
total_epochs = 210

info

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ cuda编译: python setup.py develop
2020
../configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py
2121
../configs/reppoints/bbox_r50_grid_center_fpn_gn-neck+head_1x_coco.py
2222
../configs/reppointv2/reppoints_v2_r50_fpn_1x_coco.py
23+
../configs/cornernet/cornernet_hourglass104_mstest_32x3_210e_coco.py
2324

2425

2526

mmdet/models/backbones/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from .rr_tiny_yolov4_backbone import RRTinyYolov4Backbone
77
from .rr_yolov5_backbone import RRYoloV5Backbone
88
from .ssd_vgg import SSDVGG
9+
from .hourglass import HourglassNet
910

1011
__all__ = [
1112
'ResNet', 'ResNetV1d', 'Darknet', 'RRDarknet53', 'RRTinyYolov3Backbone',
12-
'RRCSPDarknet53', 'RRTinyYolov4Backbone', 'SSDVGG'
13+
'RRCSPDarknet53', 'RRTinyYolov4Backbone', 'SSDVGG', 'HourglassNet'
1314
]

mmdet/models/backbones/hourglass.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import torch.nn as nn
2+
from mmdet.cv_core.cnn import ConvModule
3+
4+
from ..builder import BACKBONES
5+
from ..utils import ResLayer
6+
from .resnet import BasicBlock
7+
8+
9+
class HourglassModule(nn.Module):
    """Hourglass Module for HourglassNet backbone.

    Generate module recursively and use BasicBlock as the base unit.

    The module has two branches that are summed in ``forward``: ``up1``
    keeps the input resolution, while ``low1 -> low2 -> low3 -> up2``
    downsamples by a stride of 2, processes the features (recursively when
    ``depth > 1``) and upsamples them back by a factor of 2.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in current and
            follow-up HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            forwarded to every nested HourglassModule as well.
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 stage_blocks,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(HourglassModule, self).__init__()

        self.depth = depth

        # Index 0 describes this level, index 1 the next (inner) level.
        cur_block = stage_blocks[0]
        next_block = stage_blocks[1]

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # Same-resolution branch.
        self.up1 = ResLayer(
            BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)

        # Downsampling entry of the low-resolution branch.
        self.low1 = ResLayer(
            BasicBlock,
            cur_channel,
            next_channel,
            cur_block,
            stride=2,
            norm_cfg=norm_cfg)

        if self.depth > 1:
            # Recurse with the remaining stages. norm_cfg must be passed on
            # explicitly; otherwise inner levels would silently use the
            # default BN config regardless of what the caller asked for.
            self.low2 = HourglassModule(
                depth - 1,
                stage_channels[1:],
                stage_blocks[1:],
                norm_cfg=norm_cfg)
        else:
            # Innermost level: a plain residual layer.
            self.low2 = ResLayer(
                BasicBlock,
                next_channel,
                next_channel,
                next_block,
                norm_cfg=norm_cfg)

        # Map channels back from next_channel to cur_channel before
        # upsampling so the result can be added to the up1 branch.
        self.low3 = ResLayer(
            BasicBlock,
            next_channel,
            cur_channel,
            cur_block,
            norm_cfg=norm_cfg,
            downsample_first=False)

        self.up2 = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Forward function.

        Returns the sum of the same-resolution branch and the
        downsample-process-upsample branch.
        """
        up1 = self.up1(x)
        low1 = self.low1(x)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        return up1 + up2
78+
79+
80+
@BACKBONES.register_module()
class HourglassNet(nn.Module):
    """HourglassNet backbone.

    Stacked Hourglass Networks for Human Pose Estimation.
    More details can be found in the `paper
    <https://arxiv.org/abs/1603.06937>`_ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in a
            HourglassModule.
        feat_channel (int): Feature channel of conv after a HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            applied to the stem, the hourglass modules and all connecting
            convolutions.

    Example:
        >>> from mmdet.models import HourglassNet
        >>> import torch
        >>> self = HourglassNet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 256, 128, 128)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 downsample_times=5,
                 num_stacks=2,
                 stage_channels=(256, 256, 384, 384, 384, 512),
                 stage_blocks=(2, 2, 2, 2, 2, 4),
                 feat_channel=256,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(HourglassNet, self).__init__()

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) == len(stage_blocks)
        # Each hourglass level reads entries [0] and [1] of the (sliced)
        # lists, so one more entry than downsample_times is required.
        assert len(stage_channels) > downsample_times

        cur_channel = stage_channels[0]

        # The stem must emit `cur_channel` channels because its output feeds
        # the first hourglass module directly. With the default
        # stage_channels this is 256, as before.
        self.stem = nn.Sequential(
            ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg),
            ResLayer(
                BasicBlock, 128, cur_channel, 1, stride=2, norm_cfg=norm_cfg))

        # norm_cfg is forwarded so the hourglass modules honour the
        # configured norm layer instead of their own default.
        self.hourglass_modules = nn.ModuleList([
            HourglassModule(
                downsample_times,
                stage_channels,
                stage_blocks,
                norm_cfg=norm_cfg) for _ in range(num_stacks)
        ])

        # Residual blocks applied between consecutive stacks; indexed per
        # stack in forward().
        self.inters = ResLayer(
            BasicBlock,
            cur_channel,
            cur_channel,
            num_stacks - 1,
            norm_cfg=norm_cfg)

        # 1x1 projections of the features entering a stack, used for the
        # inter-stack connection.
        self.conv1x1s = nn.ModuleList([
            ConvModule(
                cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        # One output conv per stack, producing `feat_channel` channels.
        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
            for _ in range(num_stacks)
        ])

        # 1x1 convs mapping a stack's output back to `cur_channel` so it can
        # be added to the projected input of that stack.
        self.remap_convs = nn.ModuleList([
            ConvModule(
                feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self, pretrained=None):
        """Init module weights.

        No pretrained HourglassNet weights are loaded here; the
        ``pretrained`` argument is accepted but not used. Every ``nn.Conv2d``
        in the network is re-initialized with its default
        ``reset_parameters()``.

        Detector's __init__() will call backbone's init_weights() with
        pretrained as input, so we keep this function.

        Args:
            pretrained: Unused. Kept for interface compatibility.
        """
        # Training Centripetal Model needs to reset parameters for Conv2d
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.reset_parameters()

    def forward(self, x):
        """Forward function.

        Returns:
            list: One output feature map per stacked hourglass module, in
            stack order.
        """
        inter_feat = self.stem(x)
        out_feats = []

        for ind in range(self.num_stacks):
            single_hourglass = self.hourglass_modules[ind]
            out_conv = self.out_convs[ind]

            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            # Build the input of the next stack from the current stack's
            # input and output; skipped after the last stack.
            if ind < self.num_stacks - 1:
                inter_feat = self.conv1x1s[ind](
                    inter_feat) + self.remap_convs[ind](
                        out_feat)
                inter_feat = self.inters[ind](self.relu(inter_feat))

        return out_feats

mmdet/models/dense_heads/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
from .sabl_retina_head import SABLRetinaHead
2121
from .reppoints_head import RepPointsHead
2222
from .reppoints_v2_head import RepPointsV2Head
23+
from .corner_head import CornerHead
2324

2425
__all__ = [
2526
'RPNHead', 'RPNTestMixin', 'AnchorHead', 'RetinaHead', 'YOLOV3Head', 'RRYolov3Head', 'RRTinyYolov3Head',
2627
'RRTinyYolov4Head', 'RRYolov5Head', 'SSDHead', 'VFNetHead', 'GARetinaHead', 'GuidedAnchorHead',
2728
'AnchorFreeHead', 'FCOSHead', 'ATSSHead', 'GFLHead', 'PISARetinaHead', 'PAAHead', 'SABLRetinaHead',
28-
'RepPointsHead', 'RepPointsV2Head'
29+
'RepPointsHead', 'RepPointsV2Head', 'CornerHead'
2930
]

0 commit comments

Comments
 (0)