Skip to content

Commit b4ac5b0

Browse files
committed
新增cornernet注释
1 parent f4cbce9 commit b4ac5b0

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ mmdetection无疑是非常优异的目标检测框架,但是其整个框架代
5757
- [x] sabl
5858
- [x] reppoints
5959
- [x] reppointsv2
60-
- [ ] cornernet
60+
- [x] cornernet
6161

6262

6363
## 4 模型仓库

docs/changelog.md

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# ChangeLog
22

33
## V0.0.4
4+
### 2020.11.10
5+
**(1) 新特性**
6+
- 新增cornernet代码和注释
7+
- 新增cornernet文档
8+
49
### 2020.11.9
510
**(1) 新特性**
611
- 新增reppointsv2代码和注释

mmdet/models/dense_heads/corner_head.py

+6
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(self,
125125
super(CornerHead, self).__init__()
126126
self.num_classes = num_classes
127127
self.in_channels = in_channels
128+
# The corner embedding only needs to be a single number (one channel) per corner
128129
self.corner_emb_channels = corner_emb_channels
129130
self.with_corner_emb = self.corner_emb_channels > 0
130131
self.corner_offset_channels = 2
@@ -405,6 +406,7 @@ def get_targets(self,
405406
label = gt_labels[batch_id][box_id]
406407

407408
# Use coords in the feature level to generate ground truth
409+
# Floating-point box coordinates rescaled to the feature-map resolution
408410
scale_left = left * width_ratio
409411
scale_right = right * width_ratio
410412
scale_top = top * height_ratio
@@ -413,6 +415,7 @@ def get_targets(self,
413415
scale_center_y = center_y * height_ratio
414416

415417
# Int coords on feature map/ground truth tensor
418+
# 取整操作
416419
left_idx = int(min(scale_left, width - 1))
417420
right_idx = int(min(scale_right, width - 1))
418421
top_idx = int(min(scale_top, height - 1))
@@ -432,6 +435,7 @@ def get_targets(self,
432435
radius)
433436

434437
# Generate corner offset
438+
# 直接算偏移即可,特征图尺度
435439
left_offset = scale_left - left_idx
436440
top_offset = scale_top - top_idx
437441
right_offset = scale_right - right_idx
@@ -444,6 +448,7 @@ def get_targets(self,
444448

445449
# Generate corner embedding
446450
if with_corner_emb:
451+
# 每一行代表当前gt bbox的两个关键点在特征图上面的坐标
447452
corner_match.append([[top_idx, left_idx],
448453
[bottom_idx, right_idx]])
449454
# Generate guiding shift
@@ -615,6 +620,7 @@ def loss_single(self, tl_hmp, br_hmp, tl_emb, br_emb, tl_off, br_off,
615620
# The value of real corner would be 1 in heatmap ground truth.
616621
# The mask is computed in class agnostic mode and its shape is
617622
# batch * 1 * height * width.
623+
# mask是作为权重计算的,只有正样本位置才是1,其余位置全部是0
618624
tl_off_mask = gt_tl_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
619625
gt_tl_hmp)
620626
br_off_mask = gt_br_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(

mmdet/models/losses/ae_loss.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,14 @@ def ae_loss_per_image(tl_preds, br_preds, match):
3434
push_loss = tl_preds.sum() * 0.
3535
else:
3636
for m in match:
37+
# 同一组
3738
[tl_y, tl_x], [br_y, br_x] = m
39+
# 同一组预测值
3840
tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
3941
br_e = br_preds[:, br_y, br_x].view(-1, 1)
4042
tl_list.append(tl_e)
4143
br_list.append(br_e)
44+
# 同一组预测的平均值
4245
me_list.append((tl_e + br_e) / 2.0)
4346

4447
tl_list = torch.cat(tl_list)
@@ -49,7 +52,7 @@ def ae_loss_per_image(tl_preds, br_preds, match):
4952

5053
# N is object number in image, M is dimension of embedding vector
5154
N, M = tl_list.size()
52-
55+
# Pull loss: draw the top-left and bottom-right embeddings of each object toward their mean
5356
pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
5457
pull_loss = pull_loss.sum() / N
5558

@@ -58,9 +61,11 @@ def ae_loss_per_image(tl_preds, br_preds, match):
5861
# confusion matrix of push loss
5962
conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
6063
conf_weight = 1 - torch.eye(N).type_as(me_list)
64+
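# Margin minus the absolute difference between the mean embeddings of every pair of groups; conf_weight zeroes the diagonal (a group paired with itself)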
6165
conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())
6266

6367
if N > 1: # more than one object in current image
68+
# Push loss: penalize pairs of different groups whose mean embeddings are closer than the margin
6469
push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
6570
else:
6671
push_loss = tl_preds.sum() * 0.
@@ -91,7 +96,9 @@ def forward(self, pred, target, match):
9196
"""Forward function."""
9297
batch = pred.size(0)
9398
pull_all, push_all = 0.0, 0.0
99+
# 单张图片处理
94100
for i in range(batch):
101+
# match是利用label算出来的分组关系,match[i]里面的list每行代表同一组关键点坐标
95102
pull, push = ae_loss_per_image(pred[i], target[i], match[i])
96103

97104
pull_all += self.pull_weight * pull

0 commit comments

Comments
 (0)