add training code and pre-trained models
yelantf committed Sep 24, 2020
1 parent 2cab9fa commit bcb6b16
Showing 116 changed files with 2,221 additions and 326 deletions.
72 changes: 72 additions & 0 deletions GETTING_STARTED.md
@@ -0,0 +1,72 @@
## Getting Started with AlphAction

The hyper-parameters of each experiment are controlled by
a `.yaml` config file located in the `config_files` directory.
All of these configuration files assume that we are running
on 8 GPUs. We need to create a symbolic link to the `output`
directory, where the output (logs and checkpoints) will be
saved. We also recommend creating a `models` directory to
hold model weights. Both can be done with the following commands.

```shell
mkdir -p /path/to/output
ln -s /path/to/output data/output
mkdir -p /path/to/models
ln -s /path/to/models data/models
```

### Training

Download the pre-trained models from [MODEL_ZOO.md](MODEL_ZOO.md#pre-trained-models),
then place them in the `data/models` directory with the following structure:

```
models/
|_ pretrained_models/
| |_ SlowFast-ResNet50-4x16.pth
| |_ SlowFast-ResNet101-8x8.pth
```

To train on a single GPU, you only need to run the following command.
The argument `--use-tfboard` enables TensorBoard logging of the
training process. Because the config files assume 8 GPUs, the global
batch sizes `SOLVER.VIDEOS_PER_BATCH` and `TEST.VIDEOS_PER_BATCH` can
be too large for a single GPU. Therefore, in the following command, we
reduce the batch size and also adjust the learning rate and schedule
length according to the linear scaling rule.

```shell
python train_net.py --config-file "path/to/config/file.yaml" \
--transfer --no-head --use-tfboard \
SOLVER.BASE_LR 0.000125 \
SOLVER.STEPS '(560000, 720000)' \
SOLVER.MAX_ITER 880000 \
SOLVER.VIDEOS_PER_BATCH 2 \
TEST.VIDEOS_PER_BATCH 2
```
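
As a rough illustration of the linear scaling rule, the sketch below derives the
single-GPU values used above from an assumed 8-GPU reference schedule (global batch 16,
base LR 0.001, steps at 70k/90k, 110k iterations). These reference numbers are only
inferred from the command above, so check the actual config file before relying on them.

```python
# A minimal sketch of the linear scaling rule (illustration only, not AlphAction code).
# The 8-GPU reference schedule below is inferred, not copied from the config files.
REFERENCE = {
    "videos_per_batch": 16,     # assumed global batch size with 8 GPUs
    "base_lr": 0.001,
    "steps": (70000, 90000),
    "max_iter": 110000,
}

def scale_schedule(ref, new_batch):
    """Scale the LR linearly with batch size and stretch the schedule inversely."""
    factor = new_batch / ref["videos_per_batch"]
    return {
        "videos_per_batch": new_batch,
        "base_lr": ref["base_lr"] * factor,                      # -> 0.000125
        "steps": tuple(int(s / factor) for s in ref["steps"]),   # -> (560000, 720000)
        "max_iter": int(ref["max_iter"] / factor),               # -> 880000
    }

print(scale_schedule(REFERENCE, new_batch=2))
```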

We use the utility `torch.distributed.launch` to spawn multiple
processes for distributed training on multiple GPUs. `GPU_NUM` should be
replaced by the number of GPUs to use. Hyper-parameters in the config file
can still be overridden on the command line, just as in single-GPU training.

```shell
python -m torch.distributed.launch --nproc_per_node=GPU_NUM \
train_net.py --config-file "path/to/config/file.yaml" \
--transfer --no-head --use-tfboard
```

### Inference

To run inference on multiple GPUs, use the following command. Note that
our code first tries to load the `last_checkpoint` file in `OUTPUT_DIR`.
Only if no such file exists does it load the model from the path specified
in `MODEL.WEIGHT`, so to run inference with `MODEL.WEIGHT` you need to
ensure that there is no `last_checkpoint` in `OUTPUT_DIR`.
You can download the model weights from [MODEL_ZOO.md](MODEL_ZOO.md#ava-models).

```shell
python -m torch.distributed.launch --nproc_per_node=GPU_NUM \
test_net.py --config-file "path/to/config/file.yaml" \
MODEL.WEIGHT "path/to/model/weight"
```
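
The loading rule above can be summarized by the small sketch below. It is a simplified
illustration of the behaviour described in this section, not the actual checkpointer
implementation, whose details may differ.

```python
# Simplified sketch of the checkpoint-resolution rule described above
# (illustration only; the real checkpointer in this repo may differ in details).
import os

def resolve_weights(output_dir, model_weight):
    """Prefer the last training checkpoint in OUTPUT_DIR, else fall back to MODEL.WEIGHT."""
    last_ckpt = os.path.join(output_dir, "last_checkpoint")
    if os.path.exists(last_ckpt):
        # assumed to record the path of the most recent checkpoint
        with open(last_ckpt) as f:
            return f.read().strip()
    return model_weight

# To force MODEL.WEIGHT, make sure no `last_checkpoint` file exists in OUTPUT_DIR.
```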
22 changes: 9 additions & 13 deletions INSTALL.md
@@ -17,22 +17,18 @@ We recommend setting up the environment with Anaconda;
the step-by-step installation script is shown below.

```bash
conda create -n action_det python=3.7
conda activate action_det
conda create -n alphaction python=3.7
conda activate alphaction

conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
conda install av -c conda-forge
conda install tqdm

pip install yacs opencv-python tensorboardX
# install PyTorch with the same CUDA version as your environment
cuda_version=$(nvcc --version | grep -oP '(?<=release )[\d\.]*?(?=,)')
conda install pytorch torchvision cudatoolkit=$cuda_version -c pytorch

######################
# dependencies for demo
conda install Cython SciPy matplotlib
pip install cython-bbox easydict
######################
conda install av -c conda-forge
conda install cython

git clone https://github.com/MVIG-SJTU/AlphAction.git
cd AlphAction
python setup.py build develop
pip install -e .  # other dependencies will be installed here

```
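
After installation, a quick sanity check such as the one below (a suggestion, not part
of the official instructions) confirms that PyTorch sees the GPU and that the
`alphaction` package imports correctly.

```python
# Optional post-install sanity check (a suggestion, not part of the official instructions).
import torch
import alphaction  # installed above with `pip install -e .`

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("alphaction imported from:", alphaction.__file__)
```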
22 changes: 22 additions & 0 deletions MODEL_ZOO.md
@@ -0,0 +1,22 @@
## AlphAction Model Zoo

### Pre-trained Models

We provide backbone models pre-trained on the Kinetics dataset, used for further
fine-tuning on the AVA dataset. The reported accuracies are obtained with 30-view testing.

| backbone | pre-train | frame length | sample rate | top-1 | top-5 | model |
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
| SlowFast-R50 | Kinetics-700 | 4 | 16 | 66.34 | 86.66 | [[link]](https://drive.google.com/file/d/1bNcF295jxY4Zbqf0mdtsw9QifpXnvOyh/view?usp=sharing) |
| SlowFast-R101 | Kinetics-700 | 8 | 8 | 69.32 | 88.84 | [[link]](https://drive.google.com/file/d/1v1FdPUXBNRj-oKfctScT4L4qk8L1k3Gg/view?usp=sharing) |

### AVA Models

| config | backbone | IA structure | mAP | in paper | model |
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
| [resnet50_4x16f_baseline](config_files/resnet50_4x16f_baseline.yaml) | SlowFast-R50-4x16 | w/o | 26.7 | 26.5 | [[link]](https://drive.google.com/file/d/1_yAxk6R58Dn6IjBCx-WBwCQEf_582Vuv/view?usp=sharing) |
| [resnet50_4x16f_parallel](config_files/resnet50_4x16f_parallel.yaml) | SlowFast-R50-4x16 | Parallel | 29.0 | 28.9 | [[link]](https://drive.google.com/file/d/13iDNnkxjDqo8OuEhnHFe3P-fERHTbFaD/view?usp=sharing) |
| [resnet50_4x16f_serial](config_files/resnet50_4x16f_serial.yaml) | SlowFast-R50-4x16 | Serial | 29.8 | 29.6 | [[link]](https://drive.google.com/file/d/1S6NIPQ8NoZpzOKkHjzdpFVOtsU6GjqIv/view?usp=sharing) |
| [resnet50_4x16f_denseserial](config_files/resnet50_4x16f_denseserial.yaml) | SlowFast-R50-4x16 | Dense Serial | 30.0 | 29.8 | [[link]](https://drive.google.com/file/d/1OZmlA6V6XoWEA_usyijUREOYujzYL_kP/view?usp=sharing) |
| [resnet101_8x8f_baseline](config_files/resnet101_8x8f_baseline.yaml) | SlowFast-R101-8x8 | w/o | 29.3 | 29.3 | [[link]](https://drive.google.com/file/d/1GC56oNEX00oEH8aiGYdFMKENdWf2VvAY/view?usp=sharing) |
| [resnet101_8x8f_denseserial](config_files/resnet101_8x8f_denseserial.yaml) | SlowFast-R101-8x8 | Dense Serial | 32.4 | 32.3 | [[link]](https://drive.google.com/file/d/1DKHo0XoBjrTO2fHTToxbV0mAPzgmNH3x/view?usp=sharing) |
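
After downloading a model, a quick way to check the file is to load it on the CPU and
inspect its top-level keys, as in the sketch below (the key layout is an assumption and
may vary between checkpoints).

```python
# Quick checkpoint inspection (illustration only; the key layout may vary).
import torch

ckpt = torch.load("path/to/model/weight.pth", map_location="cpu")
if isinstance(ckpt, dict):
    print("top-level keys:", list(ckpt.keys()))
    state_dict = ckpt.get("model", ckpt)  # assume weights sit under "model" if present
    print("number of parameter tensors:", len(state_dict))
```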
77 changes: 16 additions & 61 deletions README.md
@@ -5,93 +5,48 @@ AlphAction aims to detect the actions of multiple persons in videos. It is
model on AVA dataset.**

This project is the official implementation of the paper
[Asynchronous Interaction Aggregation for Action Detection](https://arxiv.org/abs/2004.07485), authored
[Asynchronous Interaction Aggregation for Action Detection](https://arxiv.org/abs/2004.07485) (**ECCV 2020**), authored
by Jiajun Tang*, Jin Xia* (equal contribution), Xinzhi Mu, [Bo Pang](https://bopang1996.github.io/),
[Cewu Lu](http://mvig.sjtu.edu.cn/) (corresponding author). It is now accepted by **ECCV 2020**!
[Cewu Lu](http://mvig.sjtu.edu.cn/) (corresponding author).

<br/>
<div align="center">
<img src="gifs/demo1.gif" height=320 alt="demo1">
<img src="gifs/demo2.gif" height=320 alt="demo2">
<img src="https://user-images.githubusercontent.com/22748802/94115535-71fc9580-fe7c-11ea-98af-d8e9a8a2de82.gif" width=416 alt="demo1">
<img src="https://user-images.githubusercontent.com/22748802/94115605-8ccf0a00-fe7c-11ea-8855-ab84232612a0.gif" width=416 alt="demo2">
</div>
<div align="center">
<img src="gifs/demo3.gif" width=836 alt="demo3">
<img src="https://user-images.githubusercontent.com/22748802/94115715-b12ae680-fe7c-11ea-8180-8e3d7f57a4bb.gif" width=836 alt="demo3">
</div>
<br/>

## Demo Video

[![AlphAction demo video](https://user-images.githubusercontent.com/22748802/94115680-a83a1500-fe7c-11ea-878c-536db277fba7.jpg)](https://www.youtube.com/watch?v=TdGmbOJ9hoE "AlphAction demo video")
[[YouTube]](https://www.youtube.com/watch?v=TdGmbOJ9hoE) [[BiliBili]](https://www.bilibili.com/video/BV14A411J7Xv)

## Installation

You first need to install this project; please check [INSTALL.md](INSTALL.md)


## Data Preparation

To do training or inference on the AVA dataset, please check [DATA.md](DATA.md)
for data preparation instructions.

## Model Zoo

| config | backbone | structure | mAP | in paper | model |
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
| [resnet50_4x16f_parallel](config_files/resnet50_4x16f_parallel.yaml) | ResNet-50 | Parallel | 29.0 | 28.9 | [[link]](https://drive.google.com/open?id=13iDNnkxjDqo8OuEhnHFe3P-fERHTbFaD) |
| [resnet50_4x16f_serial](config_files/resnet50_4x16f_serial.yaml) | ResNet-50 | Serial | 29.8 | 29.6 | [[link]](https://drive.google.com/open?id=1S6NIPQ8NoZpzOKkHjzdpFVOtsU6GjqIv) |
| [resnet50_4x16f_denseserial](config_files/resnet50_4x16f_denseserial.yaml) | ResNet-50 | Dense Serial | 30.0 | 29.8 | [[link]](https://drive.google.com/open?id=1OZmlA6V6XoWEA_usyijUREOYujzYL_kP) |
| [resnet101_8x8f_denseserial](config_files/resnet101_8x8f_denseserial.yaml) | ResNet-101 | Dense Serial | 32.4 | 32.3 | [[link]](https://drive.google.com/open?id=1DKHo0XoBjrTO2fHTToxbV0mAPzgmNH3x) |
Please see [MODEL_ZOO.md](MODEL_ZOO.md) for downloading models.

## Training and Inference

To do training or inference with AlphAction, please refer to [GETTING_STARTED.md](GETTING_STARTED.md).

## Visual Demo
## Demo Program

To run the demo program on video or webcam, please check the folder [demo](demo).
We select 15 common categories from the 80 AVA action categories, and
provide a practical model that achieves high accuracy (about 70 mAP) on these categories.

## Training and Inference

The hyper-parameters of each experiment are controlled by
a `.yaml` config file located in the `config_files` directory.
All of these configuration files assume that we are running
on 8 GPUs. We need to create a symbolic link to the `output`
directory, where the output (logs and checkpoints) will be
saved. We also recommend creating a `models` directory to
hold model weights. Both can be done with the following commands.

```shell
mkdir -p /path/to/output
ln -s /path/to/output data/output
mkdir -p /path/to/models
ln -s /path/to/models data/models
```

### Training

The pre-trained model weights and the training code will be publicly
available later. :wink:

### Inference

First, you need to download the model weights from [Model Zoo](#model-zoo).

To do inference on a single GPU, you only need to run the following command.
It will load the model from the path specified in `MODEL.WEIGHT`.
Note that `VIDEOS_PER_BATCH` is a global config; if you face an
OOM error, you can override it on the command line as we do
in the command below.
```shell
python test_net.py --config-file "path/to/config/file.yaml" \
MODEL.WEIGHT "path/to/model/weight" \
TEST.VIDEOS_PER_BATCH 4
```

We use the utility `torch.distributed.launch` to spawn multiple
processes for inference on multiple GPUs. `GPU_NUM` should be
replaced by the number of GPUs to use. Hyper-parameters in the config file
can still be overridden on the command line, just as in single-GPU inference.

```shell
python -m torch.distributed.launch --nproc_per_node=GPU_NUM \
test_net.py --config-file "path/to/config/file.yaml" \
MODEL.WEIGHT "path/to/model/weight"
```

## Acknowledgement
We gratefully acknowledge the computing resource support of Huawei Corporation
for this project.
File renamed without changes.
File renamed without changes.
93 changes: 73 additions & 20 deletions AlphAction/config/defaults.py → alphaction/config/defaults.py
@@ -40,6 +40,15 @@
_C.MODEL.ROI_ACTION_HEAD.NUM_OBJECT_MANIPULATION_CLASSES = 49
_C.MODEL.ROI_ACTION_HEAD.NUM_PERSON_INTERACTION_CLASSES = 17

_C.MODEL.ROI_ACTION_HEAD.POSE_LOSS_WEIGHT = 1.2
_C.MODEL.ROI_ACTION_HEAD.OBJECT_LOSS_WEIGHT = float(_C.MODEL.ROI_ACTION_HEAD.NUM_OBJECT_MANIPULATION_CLASSES)
_C.MODEL.ROI_ACTION_HEAD.PERSON_LOSS_WEIGHT = float(_C.MODEL.ROI_ACTION_HEAD.NUM_PERSON_INTERACTION_CLASSES)

# Focal loss config.
_C.MODEL.ROI_ACTION_HEAD.FOCAL_LOSS = CN()
_C.MODEL.ROI_ACTION_HEAD.FOCAL_LOSS.GAMMA = 2.0
_C.MODEL.ROI_ACTION_HEAD.FOCAL_LOSS.ALPHA = -1.

# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
@@ -63,6 +72,12 @@
_C.INPUT.FRAME_SAMPLE_RATE = 2
_C.INPUT.TAU = 8
_C.INPUT.ALPHA = 8
_C.INPUT.SLOW_JITTER = False

_C.INPUT.COLOR_JITTER = False
_C.INPUT.HUE_JITTER = 20.0  # in degrees; hue ranges over 0~360
_C.INPUT.SAT_JITTER = 0.1
_C.INPUT.VAL_JITTER = 0.1

# -----------------------------------------------------------------------------
# Dataset
@@ -148,26 +163,64 @@
_C.MODEL.NONLOCAL.BN_INIT_GAMMA = 0.0


_C.IA_STRUCTURE = CN()
_C.IA_STRUCTURE.ACTIVE = False
_C.IA_STRUCTURE.STRUCTURE = "serial"
_C.IA_STRUCTURE.MAX_PER_SEC = 5
_C.IA_STRUCTURE.MAX_PERSON = 25
_C.IA_STRUCTURE.DIM_IN = 2304
_C.IA_STRUCTURE.DIM_INNER = 512
_C.IA_STRUCTURE.DIM_OUT = 512
_C.IA_STRUCTURE.LENGTH = (30, 30)
_C.IA_STRUCTURE.MEMORY_RATE = 1
_C.IA_STRUCTURE.FUSION = "concat"
_C.IA_STRUCTURE.CONV_INIT_STD = 0.01
_C.IA_STRUCTURE.DROPOUT = 0.
_C.IA_STRUCTURE.NO_BIAS = False
_C.IA_STRUCTURE.I_BLOCK_LIST = ['P', 'O', 'M', 'P', 'O', 'M']
_C.IA_STRUCTURE.LAYER_NORM = False
_C.IA_STRUCTURE.TEMPORAL_POSITION = True
_C.IA_STRUCTURE.ROI_DIM_REDUCE = True
_C.IA_STRUCTURE.USE_ZERO_INIT_CONV = True
_C.IA_STRUCTURE.MAX_OBJECT = 0
_C.MODEL.IA_STRUCTURE = CN()
_C.MODEL.IA_STRUCTURE.ACTIVE = False
_C.MODEL.IA_STRUCTURE.STRUCTURE = "serial"
_C.MODEL.IA_STRUCTURE.MAX_PER_SEC = 5
_C.MODEL.IA_STRUCTURE.MAX_PERSON = 25
_C.MODEL.IA_STRUCTURE.DIM_IN = 2304
_C.MODEL.IA_STRUCTURE.DIM_INNER = 512
_C.MODEL.IA_STRUCTURE.DIM_OUT = 512
_C.MODEL.IA_STRUCTURE.PENALTY = True
_C.MODEL.IA_STRUCTURE.LENGTH = (30, 30)
_C.MODEL.IA_STRUCTURE.MEMORY_RATE = 1
_C.MODEL.IA_STRUCTURE.FUSION = "concat"
_C.MODEL.IA_STRUCTURE.CONV_INIT_STD = 0.01
_C.MODEL.IA_STRUCTURE.DROPOUT = 0.
_C.MODEL.IA_STRUCTURE.NO_BIAS = False
_C.MODEL.IA_STRUCTURE.I_BLOCK_LIST = ['P', 'O', 'M', 'P', 'O', 'M']
_C.MODEL.IA_STRUCTURE.LAYER_NORM = False
_C.MODEL.IA_STRUCTURE.TEMPORAL_POSITION = True
_C.MODEL.IA_STRUCTURE.ROI_DIM_REDUCE = True
_C.MODEL.IA_STRUCTURE.USE_ZERO_INIT_CONV = True
_C.MODEL.IA_STRUCTURE.MAX_OBJECT = 0

# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
_C.SOLVER.MAX_ITER = 75000

_C.SOLVER.BASE_LR = 0.02
_C.SOLVER.BIAS_LR_FACTOR = 2
_C.SOLVER.IA_LR_FACTOR = 1.0

_C.SOLVER.MOMENTUM = 0.9

_C.SOLVER.WEIGHT_DECAY = 0.0001
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0

# Weight decay used for BN parameters
_C.SOLVER.WEIGHT_DECAY_BN = 0.0

_C.SOLVER.SCHEDULER = "warmup_multi_step"

_C.SOLVER.GAMMA = 0.1
_C.SOLVER.STEPS = (33750, 67500)

_C.SOLVER.WARMUP_ON = True
_C.SOLVER.WARMUP_FACTOR = 1.0 / 3
_C.SOLVER.WARMUP_ITERS = 500
_C.SOLVER.WARMUP_METHOD = "linear"

_C.SOLVER.CHECKPOINT_PERIOD = 5000
_C.SOLVER.EVAL_PERIOD = 5000

# Number of video clips per batch
# This is global, so if we have 8 GPUs and VIDEOS_PER_BATCH = 16, each GPU will
# see 2 clips per batch
_C.SOLVER.VIDEOS_PER_BATCH = 32


# ---------------------------------------------------------------------------- #
# Specific test options
@@ -10,7 +10,11 @@ class DatasetCatalog(object):
"video_root": "AVA/clips/trainval",
"ann_file": "AVA/annotations/ava_train_v2.2_min.json",
"box_file": "",
"eval_file_paths": {},
"eval_file_paths": {
"csv_gt_file": "AVA/annotations/ava_train_v2.2.csv",
"labelmap_file": "AVA/annotations/ava_action_list_v2.2_for_activitynet_2019.pbtxt",
"exclusion_file": "AVA/annotations/ava_train_excluded_timestamps_v2.2.csv",
},
"object_file": "AVA/boxes/ava_train_det_object_bbox.json",
},
"ava_video_val_v2.2": {
4 changes: 2 additions & 2 deletions csrc/ROIAlign3d.h → alphaction/csrc/ROIAlign3d.h
@@ -13,7 +13,7 @@ at::Tensor ROIAlign3d_forward(const at::Tensor& input,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
if (input.type().is_cuda()) {
if (input.is_cuda()) {
#ifdef WITH_CUDA
return ROIAlign3d_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
#else
@@ -34,7 +34,7 @@ at::Tensor ROIAlign3d_backward(const at::Tensor& grad,
const int height,
const int width,
const int sampling_ratio) {
if (grad.type().is_cuda()) {
if (grad.is_cuda()) {
#ifdef WITH_CUDA
return ROIAlign3d_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, length, height, width, sampling_ratio);
#else