Skip to content

Commit 1376e77

Browse files
authored
Release of v2.24.0
2 parents c72bc70 + 7d1c097 commit 1376e77

File tree

144 files changed

+3651
-376
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

144 files changed

+3651
-376
lines changed

.dev_scripts/gather_models.py

+85-38
Original file line numberDiff line numberDiff line change
@@ -48,44 +48,85 @@ def process_checkpoint(in_file, out_file):
4848
return final_file
4949

5050

51-
def get_final_epoch(config):
51+
def is_by_epoch(config):
5252
cfg = mmcv.Config.fromfile('./configs/' + config)
53-
return cfg.runner.max_epochs
53+
return cfg.runner.type == 'EpochBasedRunner'
5454

5555

56-
def get_best_epoch(exp_dir):
57-
best_epoch_full_path = list(
56+
def get_final_epoch_or_iter(config):
57+
cfg = mmcv.Config.fromfile('./configs/' + config)
58+
if cfg.runner.type == 'EpochBasedRunner':
59+
return cfg.runner.max_epochs
60+
else:
61+
return cfg.runner.max_iters
62+
63+
64+
def get_best_epoch_or_iter(exp_dir):
65+
best_epoch_iter_full_path = list(
5866
sorted(glob.glob(osp.join(exp_dir, 'best_*.pth'))))[-1]
59-
best_epoch_model_path = best_epoch_full_path.split('/')[-1]
60-
best_epoch = best_epoch_model_path.split('_')[-1].split('.')[0]
61-
return best_epoch_model_path, int(best_epoch)
67+
best_epoch_or_iter_model_path = best_epoch_iter_full_path.split('/')[-1]
68+
best_epoch_or_iter = best_epoch_or_iter_model_path.\
69+
split('_')[-1].split('.')[0]
70+
return best_epoch_or_iter_model_path, int(best_epoch_or_iter)
6271

6372

64-
def get_real_epoch(config):
73+
def get_real_epoch_or_iter(config):
6574
cfg = mmcv.Config.fromfile('./configs/' + config)
66-
epoch = cfg.runner.max_epochs
67-
if cfg.data.train.type == 'RepeatDataset':
68-
epoch *= cfg.data.train.times
69-
return epoch
75+
if cfg.runner.type == 'EpochBasedRunner':
76+
epoch = cfg.runner.max_epochs
77+
if cfg.data.train.type == 'RepeatDataset':
78+
epoch *= cfg.data.train.times
79+
return epoch
80+
else:
81+
return cfg.runner.max_iters
7082

7183

72-
def get_final_results(log_json_path, epoch, results_lut):
84+
def get_final_results(log_json_path,
85+
epoch_or_iter,
86+
results_lut,
87+
by_epoch=True):
7388
result_dict = dict()
89+
last_val_line = None
90+
last_train_line = None
91+
last_val_line_idx = -1
92+
last_train_line_idx = -1
7493
with open(log_json_path, 'r') as f:
75-
for line in f.readlines():
94+
for i, line in enumerate(f.readlines()):
7695
log_line = json.loads(line)
7796
if 'mode' not in log_line.keys():
7897
continue
7998

80-
if log_line['mode'] == 'train' and log_line['epoch'] == epoch:
81-
result_dict['memory'] = log_line['memory']
82-
83-
if log_line['mode'] == 'val' and log_line['epoch'] == epoch:
84-
result_dict.update({
85-
key: log_line[key]
86-
for key in results_lut if key in log_line
87-
})
88-
return result_dict
99+
if by_epoch:
100+
if (log_line['mode'] == 'train'
101+
and log_line['epoch'] == epoch_or_iter):
102+
result_dict['memory'] = log_line['memory']
103+
104+
if (log_line['mode'] == 'val'
105+
and log_line['epoch'] == epoch_or_iter):
106+
result_dict.update({
107+
key: log_line[key]
108+
for key in results_lut if key in log_line
109+
})
110+
return result_dict
111+
else:
112+
if log_line['mode'] == 'train':
113+
last_train_line_idx = i
114+
last_train_line = log_line
115+
116+
if log_line and log_line['mode'] == 'val':
117+
last_val_line_idx = i
118+
last_val_line = log_line
119+
120+
# bug: max_iters = 768, last_train_line['iter'] = 750
121+
assert last_val_line_idx == last_train_line_idx + 1, \
122+
'Log file is incomplete'
123+
result_dict['memory'] = last_train_line['memory']
124+
result_dict.update({
125+
key: last_val_line[key]
126+
for key in results_lut if key in last_val_line
127+
})
128+
129+
return result_dict
89130

90131

91132
def get_dataset_name(config):
@@ -116,10 +157,12 @@ def convert_model_info_to_pwc(model_infos):
116157

117158
# get metadata
118159
memory = round(model['results']['memory'] / 1024, 1)
119-
epochs = get_real_epoch(model['config'])
120160
meta_data = OrderedDict()
121161
meta_data['Training Memory (GB)'] = memory
122-
meta_data['Epochs'] = epochs
162+
if 'epochs' in model:
163+
meta_data['Epochs'] = get_real_epoch_or_iter(model['config'])
164+
else:
165+
meta_data['Iterations'] = get_real_epoch_or_iter(model['config'])
123166
pwc_model_info['Metadata'] = meta_data
124167

125168
# get dataset name
@@ -200,12 +243,14 @@ def main():
200243
model_infos = []
201244
for used_config in used_configs:
202245
exp_dir = osp.join(models_root, used_config)
246+
by_epoch = is_by_epoch(used_config)
203247
# check whether the exps is finished
204248
if args.best is True:
205-
final_model, final_epoch = get_best_epoch(exp_dir)
249+
final_model, final_epoch_or_iter = get_best_epoch_or_iter(exp_dir)
206250
else:
207-
final_epoch = get_final_epoch(used_config)
208-
final_model = 'epoch_{}.pth'.format(final_epoch)
251+
final_epoch_or_iter = get_final_epoch_or_iter(used_config)
252+
final_model = '{}_{}.pth'.format('epoch' if by_epoch else 'iter',
253+
final_epoch_or_iter)
209254

210255
model_path = osp.join(exp_dir, final_model)
211256
# skip if the model is still training
@@ -225,21 +270,23 @@ def main():
225270
for i, key in enumerate(results_lut):
226271
if 'mAP' not in key and 'PQ' not in key:
227272
results_lut[i] = key + 'm_AP'
228-
model_performance = get_final_results(log_json_path, final_epoch,
229-
results_lut)
273+
model_performance = get_final_results(log_json_path,
274+
final_epoch_or_iter, results_lut,
275+
by_epoch)
230276

231277
if model_performance is None:
232278
continue
233279

234280
model_time = osp.split(log_txt_path)[-1].split('.')[0]
235-
model_infos.append(
236-
dict(
237-
config=used_config,
238-
results=model_performance,
239-
epochs=final_epoch,
240-
model_time=model_time,
241-
final_model=final_model,
242-
log_json_path=osp.split(log_json_path)[-1]))
281+
model_info = dict(
282+
config=used_config,
283+
results=model_performance,
284+
model_time=model_time,
285+
final_model=final_model,
286+
log_json_path=osp.split(log_json_path)[-1])
287+
model_info['epochs' if by_epoch else 'iterations'] =\
288+
final_epoch_or_iter
289+
model_infos.append(model_info)
243290

244291
# publish model for each checkpoint
245292
publish_model_infos = []

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ repos:
4040
- id: docformatter
4141
args: ["--in-place", "--wrap-descriptions", "79"]
4242
- repo: https://github.com/open-mmlab/pre-commit-hooks
43-
rev: master # Use the ref you want to point at
43+
rev: v0.2.0 # Use the ref you want to point at
4444
hooks:
4545
- id: check-algo-readme
4646
- id: check-copyright

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ This project is released under the [Apache 2.0 license](LICENSE).
7474

7575
## Changelog
7676

77-
**2.23.0** was released in 28/3/2022:
77+
**2.24.0** was released in 26/4/2022:
7878

79-
- Support [Mask2Former](configs/mask2former) and [EfficientNet](configs/efficientnet)
80-
- Support setting data root through environment variable `MMDET_DATASETS`, users don't have to modify the corresponding path in config files anymore.
81-
- Find a good recipe for fine-tuning high precision ResNet backbone pre-trained by Torchvision.
79+
- Support [Simple Copy Paste](configs/simple_copy_paste)
80+
- Support automatically scaling LR according to GPU number and samples per GPU
81+
- Support Class Aware Sampler that improves performance on OpenImages Dataset
8282

8383
Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
8484

README_zh-CN.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
7373

7474
## 更新日志
7575

76-
最新的 **2.23.0** 版本已经在 2022.03.28 发布:
76+
最新的 **2.24.0** 版本已经在 2022.03.28 发布:
7777

78-
- 支持 [Mask2Former](configs/mask2former)[Efficientnet](configs/efficientnet)
79-
- 支持通环境变量 `MMDET_DATASETS` 设置数据根目录,因此无需修改配置文件中对应的路径。
80-
- 发现一个很好的方法来微调由 Torchvision 预训练的高精度 ResNet 主干。
78+
- 支持算法 [Simple Copy Paste](configs/simple_copy_paste)
79+
- 支持训练时根据总 batch 数自动缩放学习率
80+
- 支持类别可知的采样器来提高算法在 OpenImages 数据集上的性能
8181

8282
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](docs/en/changelog.md)
8383

configs/_base_/default_runtime.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@
1919
opencv_num_threads = 0
2020
# set multi-process start method as `fork` to speed up the training
2121
mp_start_method = 'fork'
22+
23+
# Default setting for scaling LR automatically
24+
# - `enable` means enable scaling LR automatically
25+
# or not by default.
26+
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
27+
auto_scale_lr = dict(enable=False, base_batch_size=16)

configs/_base_/models/faster_rcnn_r50_caffe_c4.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@
4242
dilation=1,
4343
style='caffe',
4444
norm_cfg=norm_cfg,
45-
norm_eval=True),
45+
norm_eval=True,
46+
init_cfg=dict(
47+
type='Pretrained',
48+
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
4649
bbox_roi_extractor=dict(
4750
type='SingleRoIExtractor',
4851
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
@@ -78,7 +81,7 @@
7881
pos_fraction=0.5,
7982
neg_pos_ub=-1,
8083
add_gt_as_proposals=False),
81-
allowed_border=0,
84+
allowed_border=-1,
8285
pos_weight=-1,
8386
debug=False),
8487
rpn_proposal=dict(

configs/centernet/centernet_resnet18_dcnv2_140e_coco.py

+5
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,8 @@
120120
warmup_ratio=1.0 / 1000,
121121
step=[18, 24]) # the real step is [18*5, 24*5]
122122
runner = dict(max_epochs=28) # the real epoch is 28*5=140
123+
124+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
125+
# USER SHOULD NOT CHANGE ITS VALUES.
126+
# base_batch_size = (8 GPUs) x (16 samples per GPU)
127+
auto_scale_lr = dict(base_batch_size=128)

configs/centripetalnet/centripetalnet_hourglass104_mstest_16x6_210e_coco.py

+5
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,8 @@
103103
warmup_ratio=1.0 / 3,
104104
step=[190])
105105
runner = dict(type='EpochBasedRunner', max_epochs=210)
106+
107+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
108+
# USER SHOULD NOT CHANGE ITS VALUES.
109+
# base_batch_size = (16 GPUs) x (6 samples per GPU)
110+
auto_scale_lr = dict(base_batch_size=96)

configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py

+5
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,8 @@
3737
log_config = dict(interval=100)
3838
# For better, more stable performance initialize from COCO
3939
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # noqa
40+
41+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
42+
# USER SHOULD NOT CHANGE ITS VALUES.
43+
# base_batch_size = (8 GPUs) x (1 samples per GPU)
44+
auto_scale_lr = dict(base_batch_size=8)

configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py

+5
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,8 @@
4444
log_config = dict(interval=100)
4545
# For better, more stable performance initialize from COCO
4646
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_1x_coco/mask_rcnn_r50_fpn_1x_coco_20200205-d4b0c5d6.pth' # noqa
47+
48+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
49+
# USER SHOULD NOT CHANGE ITS VALUES.
50+
# base_batch_size = (8 GPUs) x (1 samples per GPU)
51+
auto_scale_lr = dict(base_batch_size=8)
+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
_base_ = '../_base_/default_runtime.py'
2+
# dataset settings
3+
dataset_type = 'CocoDataset'
4+
data_root = 'data/coco/'
5+
img_norm_cfg = dict(
6+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
7+
image_size = (1024, 1024)
8+
9+
file_client_args = dict(backend='disk')
10+
11+
# Standard Scale Jittering (SSJ) resizes and crops an image
12+
# with a resize range of 0.8 to 1.25 of the original image size.
13+
train_pipeline = [
14+
dict(type='LoadImageFromFile', file_client_args=file_client_args),
15+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
16+
dict(
17+
type='Resize',
18+
img_scale=image_size,
19+
ratio_range=(0.8, 1.25),
20+
multiscale_mode='range',
21+
keep_ratio=True),
22+
dict(
23+
type='RandomCrop',
24+
crop_type='absolute_range',
25+
crop_size=image_size,
26+
recompute_bbox=True,
27+
allow_negative_crop=True),
28+
dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
29+
dict(type='RandomFlip', flip_ratio=0.5),
30+
dict(type='Normalize', **img_norm_cfg),
31+
dict(type='Pad', size=image_size), # padding to image_size leads 0.5+ mAP
32+
dict(type='DefaultFormatBundle'),
33+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
34+
]
35+
test_pipeline = [
36+
dict(type='LoadImageFromFile', file_client_args=file_client_args),
37+
dict(
38+
type='MultiScaleFlipAug',
39+
img_scale=(1333, 800),
40+
flip=False,
41+
transforms=[
42+
dict(type='Resize', keep_ratio=True),
43+
dict(type='RandomFlip'),
44+
dict(type='Normalize', **img_norm_cfg),
45+
dict(type='Pad', size_divisor=32),
46+
dict(type='ImageToTensor', keys=['img']),
47+
dict(type='Collect', keys=['img']),
48+
])
49+
]
50+
51+
data = dict(
52+
samples_per_gpu=2,
53+
workers_per_gpu=2,
54+
train=dict(
55+
type=dataset_type,
56+
ann_file=data_root + 'annotations/instances_train2017.json',
57+
img_prefix=data_root + 'train2017/',
58+
pipeline=train_pipeline),
59+
val=dict(
60+
type=dataset_type,
61+
ann_file=data_root + 'annotations/instances_val2017.json',
62+
img_prefix=data_root + 'val2017/',
63+
pipeline=test_pipeline),
64+
test=dict(
65+
type=dataset_type,
66+
ann_file=data_root + 'annotations/instances_val2017.json',
67+
img_prefix=data_root + 'val2017/',
68+
pipeline=test_pipeline))
69+
70+
evaluation = dict(interval=6000, metric=['bbox', 'segm'])
71+
72+
# optimizer assumes batch_size = (32 GPUs) x (2 samples per GPU)
73+
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.00004)
74+
optimizer_config = dict(grad_clip=None)
75+
76+
# lr steps at [0.9, 0.95, 0.975] of the maximum iterations
77+
lr_config = dict(
78+
policy='step',
79+
warmup='linear',
80+
warmup_iters=1000,
81+
warmup_ratio=0.001,
82+
step=[243000, 256500, 263250])
83+
checkpoint_config = dict(interval=6000)
84+
# The model is trained by 270k iterations with batch_size 64,
85+
# which is roughly equivalent to 144 epochs.
86+
runner = dict(type='IterBasedRunner', max_iters=270000)
87+
88+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
89+
# USER SHOULD NOT CHANGE ITS VALUES.
90+
# base_batch_size = (32 GPUs) x (2 samples per GPU)
91+
auto_scale_lr = dict(base_batch_size=64)

0 commit comments

Comments
 (0)