Commit 03ad086 (initial commit, 0 parents): 43 changed files with 3,707 additions and 0 deletions.

`.gitignore` (new file, +10 lines)

```
/X101-features
/X152-features
/m2_annotations
evaluation/spice/*
*.pyc
*.jar
/saved_transformer_models
/tensorboard_logs
/visualization
/.vscode
```

`README.md` (new file, +112 lines)

# $\mathcal{S}^2$ Transformer for Image Captioning

[Python](https://www.python.org/) · [License](https://github.com/zchoi/S2-Transformer/blob/main/LICENSE) · [PyTorch](https://pytorch.org/)

This repository contains the official code implementation for the paper [_S<sup>2</sup> Transformer for Image Captioning_ (IJCAI 2022)](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhang_RSTNet_Captioning_With_Adaptive_Attention_on_Visual_and_Non-Visual_Words_CVPR_2021_paper.pdf).

<p align="center">
  <img src="framework.png" alt="Relationship-Sensitive Transformer" width="850"/>
</p>

## Table of Contents
- [Environment setup](#environment-setup)
- [Data Preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Reference and Citation](#reference-and-citation)
- [Acknowledgements](#acknowledgements)

## Environment setup

Clone this repository and create the `m2release` conda environment using the `environment.yml` file:
```
conda env create -f environment.yml
conda activate m2release
```

Then download the spaCy data by executing the following command:
```
python -m spacy download en_core_web_md
```

**Note:** Python 3 is required to run our code. If you run into network problems, please download the ```en_core_web_md``` model from [here](https://drive.google.com/file/d/1jf6ecYDzIomaGt3HgOqO_7rEL6oiTjgN/view?usp=sharing), unzip it, and place it in ```/your/anaconda/path/envs/m2release/lib/python*/site-packages/```.
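
To confirm the manually installed model is visible to spaCy, a quick check such as the following (illustrative only, not part of the repository) can be run inside the `m2release` environment:
```
# Sanity check (illustrative): verify that en_core_web_md loads and provides vectors.
import spacy

nlp = spacy.load("en_core_web_md")
doc = nlp("a man riding a horse on the beach")
print(doc[2].text, doc[2].has_vector, doc.vector.shape)
```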

## Data Preparation

* **Annotation**. Download the annotation file [annotation.zip](https://drive.google.com/file/d/1Zc2P3-MIBg3JcHT1qKeYuQt9CnQcY5XJ/view?usp=sharing) [1]. Extract it and put it in the project root directory.
* **Feature**. Download the processed image features [ResNeXt-101](https://stduestceducn-my.sharepoint.com/:f:/g/personal/zhn_std_uestc_edu_cn/EssZY4Xdb0JErCk0A1Yx3vUBaRbXau88scRvYw4r1ZuwPg?e=f2QFGp) and [ResNeXt-152](https://stduestceducn-my.sharepoint.com/:f:/g/personal/zhn_std_uestc_edu_cn/EssZY4Xdb0JErCk0A1Yx3vUBaRbXau88scRvYw4r1ZuwPg?e=f2QFGp) [2] and put them in the project root directory; a quick way to inspect the downloaded feature files is sketched after this list.
<!-- * **Evaluation**. Download the evaluation tools [here](https://pan.baidu.com/s/1xVZO7t8k4H_l3aEyuA-KXQ). Access code: jcj6. Extract and put it in the project root directory. -->
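
As mentioned above, the snippet below shows one way to peek at a downloaded HDF5 feature file with `h5py`. The file name and internal key layout are assumptions for illustration, not documented properties of the released files:
```
# Illustrative only: list the first few datasets in a feature file to check it opens.
import h5py

path = "/path/to/X101-features/coco_grid_feats.hdf5"  # hypothetical file name
with h5py.File(path, "r") as f:
    for key in list(f.keys())[:5]:   # peek at the first few entries
        obj = f[key]
        if isinstance(obj, h5py.Dataset):
            print(key, obj.shape, obj.dtype)
        else:
            print(key, "(group)")
```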

## Training

Run `python train_transformer.py` using the following arguments:

| Argument | Possible values |
|------|------|
| `--exp_name` | Experiment name |
| `--batch_size` | Batch size (default: 50) |
| `--workers` | Number of data-loading workers; a larger value speeds up training in the XE (cross-entropy) stage |
| `--head` | Number of attention heads (default: 8) |
| `--resume_last` | If used, training resumes from the last checkpoint |
| `--resume_best` | If used, training resumes from the best checkpoint |
| `--features_path` | Path to the visual features file (h5py) |
| `--annotation_folder` | Path to the annotations |
| `--num_clusters` | Number of pseudo regions |

For example, to train the model, run the following command:
```
python train_transformer.py --exp_name S2 --batch_size 50 --m 40 --head 8 --features_path /path/to/features --num_clusters 5
```
or just run:
```
bash train.sh
```
**Note:** We use `torch.distributed` to train our model; set `worldSize` in [train_transformer.py]() to choose the number of GPUs used for training.
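
For reference, a minimal sketch of the kind of multi-process setup this implies is shown below. This is generic `torch.distributed` boilerplate, not the repository's actual launcher; the `worker` function and its training body are placeholders:
```
# Generic torch.distributed sketch (not the repo's code): spawn one process per
# device and initialize a process group, similar in spirit to setting worldSize.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    # ... build the model, wrap it in DistributedDataParallel, run the training loop ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = max(1, torch.cuda.device_count())  # analogous to worldSize
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```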

## Evaluation
### Offline Evaluation
Run `python test_transformer.py` to evaluate the model, using the following arguments:
```
python test_transformer.py --batch_size 10 --features_path /path/to/features --model_path /path/to/saved_transformer_models/ckpt --num_clusters 5
```

**Note:** We removed the ```SPICE``` evaluation metric from training because it is time-consuming. You can add it back when evaluating the model: download this [file](https://drive.google.com/file/d/1vEVsbEFjDstmSvoWhu4UdKaJjX1jJXpR/view?usp=sharing), put it in ```/path/to/evaluation/```, and then uncomment the corresponding lines in [the init code]().

We provide a pretrained model [here](https://drive.google.com/file/d/1Y133r4Wd9ediS1Jqlwc1qtL15vCK_Mik/view?usp=sharing); evaluating it should give the following results (second row):

| Model | B@1 | B@4 | M | R | C | S |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| Our Paper (ResNeXt-101) | 81.1 | 39.6 | 29.6 | 59.1 | 133.5 | 23.2 |
| Reproduced Model (ResNeXt-101) | 81.2 | 39.9 | 29.6 | 59.1 | 133.7 | 23.3 |

### Online Evaluation
We also report the performance of our model on the online COCO test server with an ensemble of four S<sup>2</sup> models. The detailed online test code can be found in this [repo](https://github.com/zhangxuying1004/RSTNet).
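
As a rough illustration of what ensembling several captioning models typically involves (a generic sketch under assumed interfaces, not the RSTNet test code), per-step probabilities from the individual models can be averaged before choosing the next token:
```
# Generic ensembling sketch (assumed interfaces, not the repo's code): average
# next-token probabilities from several decoders at each decoding step.
import torch

def ensemble_next_token_logprobs(models, step_inputs):
    # Each model is assumed to map the current step inputs to a [batch, vocab]
    # tensor of logits; averaging is done in probability space.
    probs = [torch.softmax(m(step_inputs), dim=-1) for m in models]
    return torch.log(torch.stack(probs, dim=0).mean(dim=0) + 1e-12)
```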

## Reference and Citation
### Reference
[1] Marcella Cornia, Matteo Stefanini, Lorenzo Baraldi, and Rita Cucchiara. Meshed-memory transformer for image captioning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020.

[2] Xuying Zhang, Xiaoshuai Sun, Yunpeng Luo, Jiayi Ji, Yiyi Zhou, Yongjian Wu, Feiyue Huang, and Rongrong Ji. RSTNet: Captioning with adaptive attention on visual and non-visual words. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 15465–15474, 2021.

### Citation
```
@inproceedings{S2,
  author    = {Pengpeng Zeng and
               Haonan Zhang and
               Jingkuan Song and
               Lianli Gao},
  title     = {S2 Transformer for Image Captioning},
  booktitle = {IJCAI},
  % pages   = {????--????}
  year      = {2022}
}
```
## Acknowledgements
Thanks to Zhang _et al._ for releasing the visual features (ResNeXt-101 and ResNeXt-152); our code implementation is also based on their [repo](https://github.com/zhangxuying1004/RSTNet).
Thanks also for the original annotations prepared by [M<sup>2</sup> Transformer](https://github.com/aimagelab/meshed-memory-transformer) and the effective visual representations from [grid-feats-vqa](https://github.com/facebookresearch/grid-feats-vqa).

`data/__init__.py` (new file, +7 lines)

```
from .field import RawField, Merge, ImageDetectionsField, TextField
from .dataset import COCO
from torch.utils.data import DataLoader as TorchDataLoader


class DataLoader(TorchDataLoader):
    """Thin wrapper that pulls the collate function from the dataset itself."""

    def __init__(self, dataset, *args, **kwargs):
        # Use the dataset-provided collate_fn() so callers never pass one explicitly.
        super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)
```
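
For illustration only (a toy dataset, not part of the repository), any dataset that exposes a `collate_fn()` method works with this wrapper without passing `collate_fn` explicitly:
```
# Toy usage sketch (not repo code): a minimal dataset with a collate_fn() method.
import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    def __init__(self, n=8):
        self.items = [(torch.randn(4), f"caption {i}") for i in range(n)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

    def collate_fn(self):
        # Return the callable the wrapped DataLoader will use for batching.
        def collate(batch):
            feats = torch.stack([f for f, _ in batch])
            captions = [c for _, c in batch]
            return feats, captions
        return collate

loader = DataLoader(ToyDataset(), batch_size=4)  # DataLoader as defined above
for feats, captions in loader:
    print(feats.shape, captions)
```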