
Commit 6572c25

Add Release 1
Date: 22/05/2017

Models:
- MLBNoAtt
- MLBAtt
- MutanNoAtt
- MutanAtt

Datasets:
- VQA-1
- COCO

Misc:
- SamplingAns logger

Fix: Add tanh to mlb model
Fix logger with visdom
Improve cli options
Add VQA eval
Add save model&optim&logger
Add mlb_noatt_torch7 pretrained
Add arg help_opt to disp options
Sync before TGV
Sync from TGV
Improve logger (remove visdom)
Improve logger (remove visdom)
Add visu matplotlib
Add params_count
Add MLB with attention
Gitignore .ipython & .html
Add MutanNoAtt and EmbeddingDropout
Gitignore data & logs
Remove external/logger
Add MutanAtt + refactoring
Improve assert message
Add MutanAtt + refactoring
Add MutanAtt + refactoring
Add MutanNoAtt ported+tested
Add mlb_att options
Fix AttModel, Add testdev, Add porting, Add EmbDropout blocmutan
Add blocmutan
Add blocmutan_noatt
Add blocmutan_noatt
remove dicts in fusion.py to make blocmutan compatible with gpu
make blocmutan_noatt compatible with variable size
Add samplingans blocmutan_att
Add samplingans to other options
Add VQA2
Add save_ckpt argument
Add fixed_emb
Fix&Improve info/model/optim save
get_sizes_list in BlocMutanFusion for variable_size
Update/Add options (fixed_emb)
Add blocm_chunks
Add info to logger
Add info to logger
Refactor dataset & Add arg blocm_dim
Add arg blocm_R
Add VisualGenome (no tested) and few fix (tested)
Add multians
Fix VGenome
Add samplingans
Add list_names to extract
Improve porting
Improve visu
Add skip-thoughts.torch as submodule
Fix
Improve (remake) ported models and gitignore
Add README
0 parents  commit 6572c25

36 files changed, +3505 -0 lines changed

.gitignore

+22
@@ -0,0 +1,22 @@
data/coco
data/porting
data/skip-thoughts
data/vqa
data/vqa2

logs/porting
logs/vqa_old
logs/vqa
logs/vqa2

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Mac
.DS_Store
._.DS_Store

.ipynb_checkpoints
*.html

.gitmodules

+6
@@ -0,0 +1,6 @@
[submodule "vqa/external/VQA"]
	path = vqa/external/VQA
	url = https://github.com/Cadene/VQA.git
[submodule "vqa/external/skip-thoughts.torch"]
	path = vqa/external/skip-thoughts.torch
	url = https://github.com/Cadene/skip-thoughts.torch.git

README.md

+266
@@ -0,0 +1,266 @@
# Visual Question Answering in pytorch

This repo was made by [Remi Cadene](http://remicadene.com) and [Hedi Ben-Younes](https://twitter.com/labegne), two PhD students working on VQA at [UPMC-LIP6](http://lip6.fr). We developed this code as part of a research paper called [MUTAN: Multimodal Tucker Fusion for VQA](https://arxiv.org/abs/1705.06676), which is (as far as we know) the current state of the art on the [VQA-1 dataset](http://visualqa.org).

The goal of this repo is twofold:

- to make it easier to reproduce our results,
- to provide an efficient and modular code base to the community for further research on other VQA datasets.

If you have any questions about our code or model, don't hesitate to contact us or to submit an issue. Pull requests are welcome!

## Introduction

### What is the task about?

The task is about training models in an end-to-end fashion on a multimodal dataset made of triplets:

- an **image** with no other information than the raw pixels,
- a **question** about the visual content of the associated image,
- a short **answer** to the question (one or a few words).

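
To fix ideas, a single training example can be pictured as a small record like the one below; the file name, question and answer are made up for illustration, and the actual datasets store this information in their own json formats.

```
# One illustrative VQA triplet (hypothetical values).
example = {
    'image': 'data/coco/raw/train2014/COCO_train2014_000000000092.jpg',
    'question': 'What color is the fire hydrant?',
    'answer': 'red',
}
```
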
As you can see in the illustration below, two different triplets (sharing the same image) from the VQA dataset are represented. The models need to learn rich multimodal representations to be able to give the right answers.

<p align="center">
<img src="https://raw.githubusercontent.com/Cadene/vqa.pytorch/master/doc/vqa_task.png" width="600"/>
</p>

The VQA task is still an active research topic. However, once it is solved, it could be very useful for improving human-to-machine interfaces (especially for the visually impaired).

### Quick insight about our method

The VQA community developed an approach based on four learnable components:

- a question model, which can be an LSTM, a GRU, or pretrained Skipthoughts,
- an image model, which can be a pretrained VGG16 or Resnet-152,
- a fusion scheme, which can be an element-wise sum, a concatenation, [MCB](https://arxiv.org/abs/1606.01847), [MLB](https://arxiv.org/abs/1610.04325), or [Mutan](https://arxiv.org/abs/1705.06676),
- optionally, an attention scheme, which may have several "glimpses".

<p align="center">
<img src="https://raw.githubusercontent.com/Cadene/vqa.pytorch/master/doc/mutan.png" width="400"/>
</p>

One of our claims is that the multimodal fusion between the image and the question representations is a critical component. Thus, our proposed model uses a Tucker decomposition of the correlation tensor to model richer multimodal interactions in order to provide proper answers (see the sketch below). Our best model is based on:

- a pretrained Skipthoughts for the question model,
- features from a pretrained Resnet-152 (with images of size 3x448x448) for the image model,
- our proposed Mutan (based on a Tucker decomposition) for the fusion scheme,
- an attention scheme with two "glimpses".

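
To give an intuition of what the fusion does, here is a minimal, self-contained sketch of a Mutan-style low-rank bilinear fusion in PyTorch. It is not the code used in this repo (see `vqa/models/fusion.py` for the real implementation); the hidden dimension, the rank `R` and the output size are illustrative.

```
import torch
import torch.nn as nn

class LowRankBilinearFusion(nn.Module):
    """Illustrative Mutan-style fusion: a rank-R approximation of a bilinear
    interaction between a question vector and an image feature vector."""

    def __init__(self, dim_q=2400, dim_v=2048, dim_hidden=510, rank=5, dim_out=1000):
        super(LowRankBilinearFusion, self).__init__()
        # One projection per rank and per modality.
        self.q_proj = nn.ModuleList([nn.Linear(dim_q, dim_hidden) for _ in range(rank)])
        self.v_proj = nn.ModuleList([nn.Linear(dim_v, dim_hidden) for _ in range(rank)])
        self.classifier = nn.Linear(dim_hidden, dim_out)

    def forward(self, q, v):
        # Sum of R rank-one multimodal interactions (element-wise products).
        z = sum(torch.tanh(pq(q)) * torch.tanh(pv(v))
                for pq, pv in zip(self.q_proj, self.v_proj))
        return self.classifier(z)

fusion = LowRankBilinearFusion()
scores = fusion(torch.randn(4, 2400), torch.randn(4, 2048))  # -> (4, 1000) answer scores
```
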
## Using this code

### Requirements

#### Installation

First install Python 3 (we don't provide support for Python 2). We advise you to install Python 3 and PyTorch with Anaconda:

- [python with anaconda](https://www.continuum.io/downloads)
- [pytorch with CUDA](http://pytorch.org)

```
conda create --name vqa python=3
source activate vqa
conda install pytorch torchvision cuda80 -c soumith
```

Then clone the repo (with the `--recursive` flag for submodules) and install the complementary requirements:

```
cd $HOME
git clone --recursive https://github.com/Cadene/vqa.pytorch.git
cd vqa.pytorch
pip install -r requirements.txt
```

#### Submodules

Our code has two external dependencies:

- [VQA](https://github.com/Cadene/VQA) is used to evaluate results files on the valset with the OpenEnded accuracy,
- [skip-thoughts.torch](https://github.com/Cadene/skip-thoughts.torch) is used to import pretrained GRUs and embeddings.

#### Data

Data will be automatically downloaded and preprocessed when needed. Links to the data are stored in `vqa/datasets/vqa.py` and `vqa/datasets/coco.py`.

### Reproducing results

#### Features

As we first developed on Lua/Torch7, we used the features of [Resnet-152 pretrained with Torch7](https://github.com/facebook/fb.resnet.torch). We will make the features available for download in a few days.

/!\ Notice that we've tried the features of [Resnet-152 pretrained with pytorch](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) and got lower results.

#### Pretrained models

We currently provide three models trained with our old Torch7 code and ported to Pytorch:

- MutanNoAtt trained on the VQA-1 trainset,
- MLBAtt trained on the VQA-1 trainvalset and VisualGenome,
- MutanAtt trained on the VQA-1 trainvalset and VisualGenome.

```
mkdir -p logs/vqa
cd logs/vqa
wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mutan_noatt_train.zip
wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mlb_att_trainval.zip
wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mutan_att_trainval.zip
```

Although we provide the results files associated with our pretrained models, you can evaluate them once again on the valset, testset and testdevset using a single command per model:

```
python train.py -e --path_opt options/vqa/mutan_noatt_train.yaml --resume ckpt
python train.py -e --path_opt options/vqa/mlb_att_trainval.yaml --resume ckpt
python train.py -e --path_opt options/vqa/mutan_att_trainval.yaml --resume ckpt
```

To obtain test and testdev results, you will need to zip your results json file (name the archive `results.zip`) and submit it to the [evaluation server](https://competitions.codalab.org/competitions/6961).

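
If you prefer doing the packaging from Python, a minimal sketch with the standard `zipfile` module is shown below; the path to the results json is a placeholder and should point to the file produced by `train.py` in your logs directory.

```
import os
import zipfile

results_json = 'path/to/your_results.json'  # placeholder: the json produced for the test split

with zipfile.ZipFile('results.zip', 'w', zipfile.ZIP_DEFLATED) as zf:
    # Store the file at the root of the archive, keeping only its base name.
    zf.write(results_json, arcname=os.path.basename(results_json))
```
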
### Documentation

#### Architecture

```
.
├── options        # default options dir containing yaml files
├── logs           # experiments dir containing directories of logs (one per experiment)
├── data           # datasets directories
|   ├── coco       # images and features
|   ├── vqa        # raw, interim and processed data
|   └── ...
├── vqa            # vqa package dir
|   ├── datasets   # datasets classes & functions dir (vqa, coco, images, features, etc.)
|   ├── external   # submodules dir (VQA, skip-thoughts.torch)
|   ├── lib        # misc classes & func dir (engine, logger, dataloader, etc.)
|   └── models     # models classes & func dir (att, fusion, noatt, seq2vec)
|
├── train.py       # train & eval models
├── eval_res.py    # eval results files with OpenEnded metric
├── extract.py     # extract features from coco with CNNs
└── visu.ipynb     # visualizing logs (under development)
```

#### Options

There are three kinds of options:

- options from the yaml options files stored in the `options` directory, which are used as defaults (path to directory, logs, model, features, etc.)
- options from the ArgumentParser in the `train.py` file which are set to None and can overwrite the default options (learning rate, batch size, etc.)
- options from the ArgumentParser in the `train.py` file which are set to default values (print frequency, number of threads, resume model, evaluate model, etc.)

You can easily add new options in your custom yaml file if needed. Also, if you want to grid search a parameter, you can add an ArgumentParser option and modify the dictionary in `train.py:L80`.

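
For illustration, here is a minimal sketch of how yaml defaults can be merged with command-line overrides. It mirrors the behaviour described above but is not the exact code from `train.py`; the option names and the flat yaml layout are assumptions.

```
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument('--path_opt', default='options/default.yaml')
parser.add_argument('--learning_rate', type=float, default=None)  # None = keep the yaml default
parser.add_argument('--batch_size', type=int, default=None)
args = parser.parse_args()

# Load the yaml defaults, then let every CLI option that was explicitly set overwrite them.
with open(args.path_opt) as f:
    options = yaml.safe_load(f)

for name in ('learning_rate', 'batch_size'):
    value = getattr(args, name)
    if value is not None:
        options[name] = value  # assumes a flat yaml layout; the real files may be nested

print(options)
```
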
#### Datasets

We currently provide three datasets:

- [COCOImages](http://mscoco.org/), currently used to extract features; it comes with three splits: trainset, valset and testset
- COCOFeatures, used by any VQA dataset
- [VQA](http://www.visualqa.org/vqa_v1_download.html), which comes with four splits: trainset, valset, testset (including test-std and test-dev) and "trainvalset" (the concatenation of the trainset and valset; see the sketch below)

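
As a side note, a trainval-style split can be obtained by simply concatenating two dataset objects. The snippet below is a generic PyTorch sketch using `torch.utils.data.ConcatDataset` with toy tensors, not the repo's own implementation.

```
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Stand-ins for the real train and val datasets (toy tensors for illustration).
trainset = TensorDataset(torch.arange(10).unsqueeze(1))
valset = TensorDataset(torch.arange(10, 15).unsqueeze(1))

# "trainvalset" is just the concatenation of the two splits.
trainvalset = ConcatDataset([trainset, valset])
loader = DataLoader(trainvalset, batch_size=4, shuffle=True)
print(len(trainvalset))  # 15
```
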
We plan to add:

- [VisualGenome](http://visualgenome.org/)
- [VQA2](http://www.visualqa.org/)
- [CLEVR](http://cs.stanford.edu/people/jcjohns/clevr/)

#### Models

We currently provide four models:

- MLBNoAtt: a strong baseline (BayesianGRU + element-wise product)
- [MLBAtt](https://arxiv.org/abs/1610.04325): the previous state of the art, which adds an attention strategy
- MutanNoAtt: our proof of concept (BayesianGRU + Mutan fusion)
- MutanAtt: the current state of the art

We plan to add several other strategies in the future.

### Quick examples

#### Extract features from COCO

The needed images will be automatically downloaded to `dir_data`, and the features will be extracted with a ResNet-152 by default.

There are three options for `mode`:

- `att`: features will be of size 2048x14x14,
- `noatt`: features will be of size 2048,
- `both`: default option.

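
To make these two shapes concrete, here is a minimal sketch (not the repo's `extract.py`) of how both kinds of features can be obtained from a torchvision ResNet-152: the `att` features are the output of the last convolutional block, and the `noatt` features are their spatial average. Note that we actually use ResNet-152 weights ported from Torch7 rather than the torchvision ones (see the Features section above).

```
import torch
import torch.nn as nn
import torchvision.models as models

resnet = models.resnet152(pretrained=True)  # downloads the torchvision weights
# Keep everything up to the last convolutional block, dropping avgpool and fc.
conv_body = nn.Sequential(*list(resnet.children())[:-2])
conv_body.eval()

image = torch.randn(1, 3, 448, 448)  # a dummy 3x448x448 image
with torch.no_grad():
    att_feat = conv_body(image)            # (1, 2048, 14, 14) -> `att` features
    noatt_feat = att_feat.mean(3).mean(2)  # (1, 2048)         -> `noatt` features
```
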
Beware, you will need some space on your SSD:

- 32GB for the images,
- 125GB for the train features,
- 123GB for the test features,
- 61GB for the val features.

```
python extract.py -h
python extract.py --dir_data data/coco --data_split train
python extract.py --dir_data data/coco --data_split val
python extract.py --dir_data data/coco --data_split test
```

Note: By default our code will share computations over all available GPUs. If you want to select only one or a few, use the following prefix:

```
CUDA_VISIBLE_DEVICES=0 python extract.py
CUDA_VISIBLE_DEVICES=1,2 python extract.py
```

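
For reference, spreading a module over all visible GPUs in PyTorch is typically done with `torch.nn.DataParallel`; a minimal sketch (assuming a CUDA machine, and not necessarily how `extract.py` does it) looks like this:

```
import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet152(pretrained=True)
if torch.cuda.is_available():
    # Each batch is split across all GPUs listed in CUDA_VISIBLE_DEVICES.
    model = nn.DataParallel(model).cuda()
```
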
#### Train models on VQA

Display the help message, display the selected options, or run with the default options. The needed data will be automatically downloaded and processed using the options in `options/default.yaml`.

```
python train.py -h
python train.py --help_opt
python train.py
```

Run a MutanNoAtt model with the default options.

```
python train.py --path_opt options/mutan_noatt.yaml --dir_logs logs/mutan_noatt
```

Run a MutanAtt model on the trainset and evaluate on the valset after each epoch.

```
python train.py --vqa_trainsplit train --path_opt options/mutan_att.yaml
```

Run a MutanAtt model on the trainset and valset (by default) and run through the testset after each epoch (producing a results file that you can submit to the evaluation server).

```
python train.py --vqa_trainsplit trainval --path_opt options/mutan_att.yaml
```

#### Restart training

Restart the model from the last checkpoint.

```
python train.py --path_opt options/mutan_noatt.yaml --dir_logs logs/mutan_noatt --resume ckpt
```

Restart the model from the best checkpoint.

```
python train.py --path_opt options/mutan_noatt.yaml --dir_logs logs/mutan_noatt --resume best
```

#### Evaluate models on VQA

Evaluate the model from the best checkpoint. If your model has been trained on the training set only (`vqa_trainsplit=train`), it will be evaluated on the valset and will run through the testset. If it was trained on the trainset + valset (`vqa_trainsplit=trainval`), it will not be evaluated on the valset.

```
python train.py --vqa_trainsplit train --path_opt options/mutan_att.yaml --dir_logs logs/mutan_att --resume best -e
```

## Acknowledgment

Special thanks to the authors of [MLB](https://arxiv.org/abs/1610.04325) for providing some [Torch7 code](https://github.com/jnhwkim/MulLowBiVQA), [MCB](https://arxiv.org/abs/1606.01847) for providing some [Caffe code](https://github.com/akirafukui/vqa-mcb), and our professors and friends from LIP6 for the perfect working atmosphere.

data/.keep

+1
@@ -0,0 +1 @@
.empty

eval_res.py

+45
@@ -0,0 +1,45 @@
# Evaluate a results file with the official VQA OpenEnded metric,
# using the helpers from the vqa/external/VQA submodule.
import argparse
import json
import random
import os
from os.path import join
import sys
#import pickle
helperDir = 'vqa/external/VQA/'
sys.path.insert(0, '%s/PythonHelperTools/vqaTools' % (helperDir))
sys.path.insert(0, '%s/PythonEvaluationTools/vqaEvaluation' % (helperDir))
from vqa import VQA
from vqaEval import VQAEval


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_vqa', type=str, default='/local/cadene/data/vqa')
    parser.add_argument('--dir_epoch', type=str, default='logs/16_12_13_20:39:55/epoch,1')
    parser.add_argument('--subtype', type=str, default='train2014')
    args = parser.parse_args()

    # Load the ground-truth annotations and questions for the chosen subtype.
    diranno = join(args.dir_vqa, 'raw', 'annotations')
    annFile = join(diranno, 'mscoco_%s_annotations.json' % (args.subtype))
    quesFile = join(diranno, 'OpenEnded_mscoco_%s_questions.json' % (args.subtype))
    vqa = VQA(annFile, quesFile)

    taskType = 'OpenEnded'
    dataType = 'mscoco'
    dataSubType = args.subtype
    resultType = 'model'
    fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']

    [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = \
        ['%s/%s_%s_%s_%s_%s.json' % (args.dir_epoch, taskType, dataType,
            dataSubType, resultType, fileType) for fileType in fileTypes]

    # Evaluate only the questions present in the results file.
    vqaRes = vqa.loadRes(resFile, quesFile)
    vqaEval = VQAEval(vqa, vqaRes, n=2)

    quesIds = [int(d['question_id']) for d in json.loads(open(resFile).read())]
    vqaEval.evaluate(quesIds=quesIds)

    # Dump the accuracy dictionary next to the results file.
    json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
    #json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
    #json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
    #json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
