1 change: 1 addition & 0 deletions README.md
@@ -98,6 +98,7 @@ Supported Methods
- [x] [FGFA](configs/vid/fgfa) (ICCV 2017)
- [x] [SELSA](configs/vid/selsa) (ICCV 2019)
- [x] [Temporal RoI Align](configs/vid/temporal_roi_align) (AAAI 2021)
- [x] [TF-Blender](configs/vid/tf_blender) (ICCV 2021)

Supported Datasets

24 changes: 24 additions & 0 deletions configs/vid/tf_blender/README.md
@@ -0,0 +1,24 @@
# TF-Blender: Temporal Feature Blender for Video Object Detection

## Abstract

<!-- [ABSTRACT] -->

Video object detection is a challenging task because isolated video frames may suffer from appearance deterioration, which introduces great confusion for detection. One popular solution is to exploit temporal information and enhance the per-frame representation by aggregating features from neighboring frames. Despite achieving improvements in detection, existing methods focus on the selection of higher-level video frames for aggregation rather than modeling lower-level temporal relations to increase the feature representation. To address this limitation, we propose a novel solution named TF-Blender, which includes three modules: 1) Temporal relation models the relations between the current frame and its neighboring frames to preserve spatial information; 2) Feature adjustment enriches the representation of every neighboring feature map; 3) Feature blender combines outputs from the first two modules and produces stronger features for the later detection tasks. For its simplicity, TF-Blender can be effortlessly plugged into any detection network to improve detection behavior. Extensive evaluations on the ImageNet VID and YouTube-VIS benchmarks indicate the performance gains of applying TF-Blender to recent state-of-the-art methods.

<!-- [IMAGE] -->

## Citation

<!-- [ALGORITHM] -->

```latex
@inproceedings{cui2021tf,
title={Tf-blender: Temporal feature blender for video object detection},
author={Cui, Yiming and Yan, Liqi and Cao, Zhiwen and Liu, Dongfang},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={8138--8147},
year={2021}
}
```
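## Usage

The aggregator is dropped into an FGFA config by replacing only the `aggregator` field. A minimal sketch, mirroring the r50 config added alongside this README:

```python
# Sketch of the config pattern used in this PR: an FGFA model whose
# aggregator is swapped for TFBlenderAggregator (values mirror the
# r50 config in this directory).
_base_ = [
    '../../_base_/models/faster_rcnn_r50_dc5.py',
    '../../_base_/datasets/imagenet_vid_fgfa_style.py',
    '../../_base_/default_runtime.py'
]
model = dict(
    type='FGFA',
    aggregator=dict(
        type='TFBlenderAggregator', num_convs=1, channels=512, kernel_size=3))
```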

@@ -0,0 +1,7 @@
_base_ = ['./fgfa_tfblender_faster_rcnn_r50_dc5_7e_imagenetvid.py']
model = dict(
detector=dict(
backbone=dict(
depth=101,
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet101'))))
@@ -0,0 +1,34 @@
_base_ = [
'../../_base_/models/faster_rcnn_r50_dc5.py',
'../../_base_/datasets/imagenet_vid_fgfa_style.py',
'../../_base_/default_runtime.py'
]
model = dict(
type='FGFA',
motion=dict(
type='FlowNetSimple',
img_scale_factor=0.5,
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmtracking/pretrained_weights/flownet_simple.pth' # noqa: E501
)),
aggregator=dict(
type='TFBlenderAggregator', num_convs=1, channels=512, kernel_size=3))

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[2, 5])

# runtime settings
total_epochs = 7
evaluation = dict(metric=['bbox'], interval=7)
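As a quick usage sketch (not part of the PR): the config above can be loaded and the detector built with the standard MMTracking entry points. The file path below is an assumption inferred from the `_base_` references in the other configs of this PR.

```python
from mmcv import Config

from mmtrack.models import build_model

# Path assumed from the `_base_` references in this PR.
cfg = Config.fromfile('configs/vid/tf_blender/'
                      'fgfa_tfblender_faster_rcnn_r50_dc5_7e_imagenetvid.py')
model = build_model(cfg.model)  # an FGFA detector using TFBlenderAggregator
```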
@@ -0,0 +1,11 @@
_base_ = ['./fgfa_tfblender_faster_rcnn_r50_dc5_7e_imagenetvid.py']
model = dict(
detector=dict(
backbone=dict(
type='ResNeXt',
depth=101,
groups=64,
base_width=4,
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://resnext101_64x4d'))))
3 changes: 2 additions & 1 deletion mmtrack/models/aggregators/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .embed_aggregator import EmbedAggregator
from .selsa_aggregator import SelsaAggregator
from .tfblender_aggregator import TFBlenderAggregator

__all__ = ['EmbedAggregator', 'SelsaAggregator']
__all__ = ['EmbedAggregator', 'SelsaAggregator', 'TFBlenderAggregator']
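With the export above, the aggregator can also be built from its config dict through the `AGGREGATORS` registry; a sketch, assuming the registry is re-exported from `mmtrack.models` as in the existing codebase:

```python
from mmtrack.models import AGGREGATORS

# Same aggregator dict as in the FGFA config of this PR.
aggregator = AGGREGATORS.build(
    dict(type='TFBlenderAggregator', num_convs=1, channels=512, kernel_size=3))
```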
11 changes: 7 additions & 4 deletions mmtrack/models/aggregators/selsa_aggregator.py
@@ -14,10 +14,13 @@ class SelsaAggregator(BaseModule):
Object Detection". `SELSA <https://arxiv.org/abs/1907.06390>`_.

Args:
in_channels (int): The number of channels of the features of
proposal.
num_attention_blocks (int): The number of attention blocks used in
selsa aggregator module. Defaults to 16.
num_convs (int): Number of embedding convs.
channels (int): Channels of embedding convs. Defaults to 256.
        kernel_size (int): Kernel size of embedding convs. Defaults to 3.
        norm_cfg (dict): Configuration of normalization method after each
conv. Defaults to None.
act_cfg (dict): Configuration of activation method after each
conv. Defaults to dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
130 changes: 130 additions & 0 deletions mmtrack/models/aggregators/tfblender_aggregator.py
@@ -0,0 +1,130 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn.bricks import ConvModule
from mmcv.runner import BaseModule

from ..builder import AGGREGATORS


@AGGREGATORS.register_module()
class TFBlenderAggregator(BaseModule):
"""TF-Blender aggregator module.

This module is proposed in "TF-Blender: Temporal Feature Blender for Video
Object Detection". `TF-Blender <https://arxiv.org/pdf/2108.05821.pdf>`_.

Args:
num_convs (int): Number of embedding convs.
channels (int): Channels of embedding convs. Defaults to 256.
        kernel_size (int): Kernel size of embedding convs. Defaults to 3.
        norm_cfg (dict): Configuration of normalization method after each
conv. Defaults to None.
act_cfg (dict): Configuration of activation method after each
conv. Defaults to dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""


def __init__(self,
num_convs=1,
channels=256,
kernel_size=3,
norm_cfg=None,
act_cfg=dict(type='ReLU')):
super(TFBlenderAggregator, self).__init__()
assert num_convs > 0, 'The number of convs must be bigger than 1.'
self.embed_convs = nn.ModuleList()
for i in range(num_convs):
if i == num_convs - 1:
new_norm_cfg = None
new_act_cfg = None
else:
new_norm_cfg = norm_cfg
new_act_cfg = act_cfg
self.embed_convs.append(
ConvModule(
in_channels=channels,
out_channels=channels,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
norm_cfg=new_norm_cfg,
act_cfg=new_act_cfg))

self.tf_blenders = nn.ModuleList()

new_norm_cfg = norm_cfg
new_act_cfg = act_cfg
self.tf_blenders.append(
ConvModule(
in_channels=channels * 8,
out_channels=channels * 4,
kernel_size=1,
padding=0,
norm_cfg=new_norm_cfg,
act_cfg=new_act_cfg))
self.tf_blenders.append(
ConvModule(
in_channels=channels * 4,
out_channels=channels * 2,
kernel_size=3,
padding=1,
norm_cfg=new_norm_cfg,
act_cfg=new_act_cfg))
self.tf_blenders.append(
ConvModule(
in_channels=channels * 2,
out_channels=channels,
kernel_size=1,
padding=0,
norm_cfg=None,
act_cfg=None))

def forward(self, x, ref_x):
"""Aggregate reference feature maps `ref_x`.

        The aggregation contains three steps:

        1. Build a stacked tensor of shape [N, C*8, H, W] from `x`,
           `x_embed`, `ref_x`, `ref_x_embed` and their pairwise differences.
        2. Compute blending weights by passing this tensor through the
           Temporal Relation, Feature Adjustment and Feature Blender convs.
        3. Normalize the weights with a softmax over the reference dimension
           and use them for a weighted sum of `ref_x`.

Args:
x (Tensor): of shape [1, C, H, W]
ref_x (Tensor): of shape [N, C, H, W]. N is the number of reference
feature maps.

Returns:
Tensor: The aggregated feature map with shape [1, C, H, W].
"""
        assert len(x.shape) == 4 and len(x) == 1, \
            "Only support 'batch_size == 1' for x"
x_embed = x
for embed_conv in self.embed_convs:
x_embed = embed_conv(x_embed)
x_embed = x_embed / x_embed.norm(p=2, dim=1, keepdim=True)

ref_x_embed = ref_x
for embed_conv in self.embed_convs:
ref_x_embed = embed_conv(ref_x_embed)
ref_x_embed = ref_x_embed / ref_x_embed.norm(p=2, dim=1, keepdim=True)

        # Stack current and reference features (embeddings, raw maps and
        # their pairwise differences) along the channel dim: [N, C*8, H, W].
        num_refs = ref_x_embed.shape[0]
        x_embed_rep = x_embed.repeat(num_refs, 1, 1, 1)
        x_rep = x.repeat(num_refs, 1, 1, 1)
        tf_weight = torch.cat(
            (x_embed_rep, ref_x_embed, x_embed_rep - ref_x_embed, x_rep,
             ref_x, x_rep - ref_x, ref_x_embed - x_embed_rep, ref_x - x_rep),
            dim=1)

for tf_blender in self.tf_blenders:
tf_weight = tf_blender(tf_weight)

        # Normalize the blended weights over the N reference maps and use
        # them to aggregate `ref_x` into a single feature map.
        ada_weights = tf_weight.softmax(dim=0)
        agg_x = torch.sum(ref_x * ada_weights, dim=0, keepdim=True)
return agg_x
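Finally, a small shape smoke test for the forward pass described in the docstring; random inputs and a small channel count are used only to keep the example cheap:

```python
import torch

from mmtrack.models.aggregators import TFBlenderAggregator

agg = TFBlenderAggregator(num_convs=1, channels=16, kernel_size=3)
x = torch.randn(1, 16, 8, 8)      # current-frame feature map
ref_x = torch.randn(5, 16, 8, 8)  # five reference feature maps

out = agg(x, ref_x)
assert out.shape == (1, 16, 8, 8)  # aggregation preserves the query shape
```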