Skip to content

Commit 504bdde

Browse files
committed
fix os.execute, fix hacks for data_parallel, fix COCO image loading, fix yaml missing size
1 parent a726a2c commit 504bdde

File tree

5 files changed

+44
-27
lines changed

5 files changed

+44
-27
lines changed

extract.py

+8-8
Original file line number | Diff line number | Diff line change
@@ -26,14 +26,14 @@
2626
help='dir dataset to download or/and load images')
2727
parser.add_argument('--data_split', default='train', type=str,
2828
help='Options: (default) train | val | test')
29-
parser.add_argument('--arch', '-a', default='resnet152',
29+
parser.add_argument('--arch', '-a', default='fbresnet152',
3030
choices=convnets.model_names,
3131
help='model architecture: ' +
3232
' | '.join(convnets.model_names) +
3333
' (default: fbresnet152)')
34-
parser.add_argument('--workers', default=4, type=int,
34+
parser.add_argument('--workers', default=4, type=int,
3535
help='number of data loading workers (default: 4)')
36-
parser.add_argument('--batch_size', '-b', default=80, type=int,
36+
parser.add_argument('--batch_size', '-b', default=80, type=int,
3737
help='mini-batch size (default: 80)')
3838
parser.add_argument('--mode', default='both', type=str,
3939
help='Options: att | noatt | (default) both')
@@ -56,7 +56,7 @@ def main():
5656
if args.dataset == 'coco':
5757
if 'coco' not in args.dir_data:
5858
raise ValueError('"coco" string not in dir_data')
59-
dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
59+
dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
6060
transform=transforms.Compose([
6161
transforms.Scale(args.size),
6262
transforms.CenterCrop(args.size),
@@ -68,7 +68,7 @@ def main():
6868
raise ValueError('train split is required for vgenome')
6969
if 'vgenome' not in args.dir_data:
7070
raise ValueError('"vgenome" string not in dir_data')
71-
dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
71+
dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
7272
transform=transforms.Compose([
7373
transforms.Scale(args.size),
7474
transforms.CenterCrop(args.size),
@@ -122,7 +122,7 @@ def extract(data_loader, model, path_file, mode):
122122

123123
nb_regions = output_att.size(2) * output_att.size(3)
124124
output_noatt = output_att.sum(3).sum(2).div(nb_regions).view(-1, 2048)
125-
125+
126126
batch_size = output_att.size(0)
127127
if mode == 'both' or mode == 'att':
128128
hdf5_att[idx:idx+batch_size] = output_att.data.cpu().numpy()
@@ -141,7 +141,7 @@ def extract(data_loader, model, path_file, mode):
141141
i, len(data_loader),
142142
batch_time=batch_time,
143143
data_time=data_time,))
144-
144+
145145
hdf5_file.close()
146146

147147
# Saving image names in the same order than extraction
@@ -154,4 +154,4 @@ def extract(data_loader, model, path_file, mode):
154154

155155

156156
if __name__ == '__main__':
157-
main()
157+
main()

options/vqa/mutan_att_trainval.yaml

+2-1
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,9 @@ vqa:
1212
samplingans: True
1313
coco:
1414
dir: data/coco
15-
arch: fbresnet152torch
15+
arch: fbresnet152
1616
mode: att
17+
size: 448
1718
model:
1819
arch: MutanAtt
1920
dim_v: 2048

vqa/datasets/coco.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -23,18 +23,21 @@ class COCOImages(AbstractImagesDataset):
2323
def __init__(self, data_split, opt, transform=None, loader=default_loader):
2424
self.split_name = split_name(data_split)
2525
super(COCOImages, self).__init__(data_split, opt, transform, loader)
26-
self.dir_split = os.path.join(self.dir_raw, self.split_name)
26+
self.dir_split = self.get_dir_data()
2727
self.dataset = ImagesFolder(self.dir_split, transform=self.transform, loader=self.loader)
2828
self.name_to_index = self._load_name_to_index()
2929

30+
def get_dir_data(self):
31+
return os.path.join(self.dir_raw, self.split_name)
32+
3033
def _raw(self):
3134
if self.data_split in ['train', 'val']:
3235
os.system('wget http://msvocds.blob.core.windows.net/coco2014/{}.zip -P {}'.format(self.split_name, self.dir_raw))
3336
elif self.data_split == 'test':
34-
os.execute('wget http://msvocds.blob.core.windows.net/coco2015/test2015.zip -P '+self.dir_raw)
37+
os.system('wget http://msvocds.blob.core.windows.net/coco2015/test2015.zip -P '+self.dir_raw)
3538
else:
3639
assert False, 'data_split {} not exists'.format(self.data_split)
37-
os.execute('unzip '+os.path.join(self.dir_raw, self.split_name+'.zip')+' -d '+self.dir_raw)
40+
os.system('unzip '+os.path.join(self.dir_raw, self.split_name+'.zip')+' -d '+self.dir_raw)
3841

3942
def _load_name_to_index(self):
4043
self.name_to_index = {name:index for index, name in enumerate(self.dataset.imgs)}

vqa/datasets/images.py

+7-4
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ def __init__(self, root, transform=None, loader=default_loader):
3838

3939
def __getitem__(self, index):
4040
item = {}
41-
item['name'] = self.imgs[index]
41+
item['name'] = self.imgs[index]
4242
item['path'] = os.path.join(self.root, item['name'])
4343
if self.loader is not None:
4444
item['visual'] = self.loader(item['path'])
@@ -57,11 +57,14 @@ def __init__(self, data_split, opt, transform=None, loader=default_loader):
5757
self.opt = opt
5858
self.transform = transform
5959
self.loader = loader
60-
6160
self.dir_raw = os.path.join(self.opt['dir'], 'raw')
62-
if not os.path.exists(self.dir_raw):
61+
62+
if not os.path.exists(self.get_dir_data()):
6363
self._raw()
6464

65+
def get_dir_data(self):
66+
return self.dir_raw
67+
6568
def get_by_name(self, image_name):
6669
index = self.name_to_index[image_name]
6770
return self[index]
@@ -73,4 +76,4 @@ def __getitem__(self, index):
7376
raise NotImplementedError
7477

7578
def __len__(self):
76-
raise NotImplementedError
79+
raise NotImplementedError

vqa/models/convnets.py

+21-11
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
import torch
33
import torch.nn as nn
44
import torchvision.models as pytorch_models
5-
65
import sys
76
sys.path.append('vqa/external/pretrained-models.pytorch')
87
import pretrainedmodels as torch7_models
@@ -21,7 +20,21 @@
2120
def factory(opt, cuda=True, data_parallel=True):
2221
opt = copy.copy(opt)
2322

24-
# forward_* will be better handle in futur release
23+
class WrapperModule(nn.Module):
24+
def __init__(self, net, forward_fn):
25+
super(WrapperModule, self).__init__()
26+
self.net = net
27+
self.forward_fn = forward_fn
28+
29+
def forward(self, x):
30+
return self.forward_fn(self.net, x)
31+
32+
def __getattr__(self, attr):
33+
try:
34+
return super(WrapperModule, self).__getattr__(attr)
35+
except AttributeError:
36+
return getattr(self.net, attr)
37+
2538
def forward_resnet(self, x):
2639
x = self.conv1(x)
2740
x = self.bn1(x)
@@ -58,26 +71,23 @@ def forward_resnext(self, x):
5871
if opt['arch'] in pytorch_resnet_names:
5972
model = pytorch_models.__dict__[opt['arch']](pretrained=True)
6073

61-
convnet = model # ugly hack in case of DataParallel wrapping
62-
model.forward = lambda x: forward_resnet(convnet, x)
74+
model = WrapperModule(model, forward_resnet) # ugly hack in case of DataParallel wrapping
6375

6476
elif opt['arch'] == 'fbresnet152':
6577
model = torch7_models.__dict__[opt['arch']](num_classes=1000,
6678
pretrained='imagenet')
6779

68-
convnet = model # ugly hack in case of DataParallel wrapping
69-
model.forward = lambda x: forward_resnet(convnet, x)
80+
model = WrapperModule(model, forward_resnet) # ugly hack in case of DataParallel wrapping
7081

7182
elif opt['arch'] in torch7_resnet_names:
7283
model = torch7_models.__dict__[opt['arch']](num_classes=1000,
7384
pretrained='imagenet')
74-
75-
convnet = model # ugly hack in case of DataParallel wrapping
76-
model.forward = lambda x: forward_resnext(convnet, x)
85+
86+
model = WrapperModule(model, forward_resnext) # ugly hack in case of DataParallel wrapping
7787

7888
else:
7989
raise ValueError
80-
90+
8191
if data_parallel:
8292
model = nn.DataParallel(model).cuda()
8393
if not cuda:
@@ -86,4 +96,4 @@ def forward_resnext(self, x):
8696
if cuda:
8797
model.cuda()
8898

89-
return model
99+
return model

0 commit comments

Comments (0)