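"""Evaluate a trained DVSA model on the YouCook2-BoundingBox benchmark.

Runs inference over the split(s) given by --val_split, writes the predicted
boxes to a submission JSON file (--save_to), and, when ground-truth boxes are
available (i.e. not the testing split), scores them with the official
grounding evaluator.
"""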
# general packages
import argparse
import numpy as np
import random
import os
import errno
import time
import math
from collections import defaultdict
import json
# torch
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm
import torch.distributed as dist
import torch.utils.data.distributed
import torchvision.transforms as transforms
# util
from data.yc2_test_dataset import Yc2TestDataset, yc2_test_collate_fn
from model.dvsa import DVSA
from tools.test_util import compute_pred
from tools.codalab_eval_grd_yc2bb import YC2BBGrdEval
parser = argparse.ArgumentParser()
# Data input settings
parser.add_argument('--start_from', default='', help='path to a model checkpoint to initialize model weights from; empty = do not load')
parser.add_argument('--box_file', default='./data/yc2/annotations/yc2_bb_val_annotations.json', help='annotation data used for evaluation, must match --val_split')
parser.add_argument('--val_split', default=['validation'], type=str, nargs='+', help='data split used for testing')
parser.add_argument('--num_workers', default=6, type=int)
parser.add_argument('--num_class', default=67, type=int)
parser.add_argument('--class_file', default='./data/class_file.csv', type=str)
parser.add_argument('--rpn_proposal_root', default='./data/yc2/roi_box', type=str)
parser.add_argument('--roi_pooled_feat_root', default='./data/yc2/roi_pooled_feat', type=str)
# Model settings: General
parser.add_argument('--num_proposals', default=20, type=int)
parser.add_argument('--enc_size', default=128, type=int)
parser.add_argument('--accu_thresh', default=0.5, type=float)
parser.add_argument('--num_frm', default=5, type=int)
# Model settings: Object Interaction
parser.add_argument('--hidden_size', default=256, type=int)
parser.add_argument('--n_layers', default=1, type=int)
parser.add_argument('--n_heads', default=4, type=int)
parser.add_argument('--attn_drop', default=0.2, type=float, help='dropout for the object interaction transformer layer')
# Optimization: General
parser.add_argument('--valid_batch_size', default=1, type=int)
parser.add_argument('--vis_dropout', default=0.2, type=float, help='dropout for the visual embedding layer')
parser.add_argument('--seed', default=123, type=int, help='random number generator seed to use')
parser.add_argument('--cuda', dest='cuda', action='store_true', help='use gpu')
# Data submission
parser.add_argument('--save_to', default='./res/submission_yc2_bb.json', help='Save predictions to this JSON file')
parser.set_defaults(cuda=False)
args = parser.parse_args()
# argument sanity check: the evaluation loop below processes one segment at a time
assert args.valid_batch_size == 1
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)
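
# Box annotations for the testing split are not public (see valid() below),
# so ground-truth scoring is skipped when that split is requested.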
test_mode = 'testing' in args.val_split

def get_dataset(args):
    valid_dataset = Yc2TestDataset(args.class_file, args.val_split,
                                   None, args.box_file, num_proposals=args.num_proposals,
                                   rpn_proposal_root=args.rpn_proposal_root,
                                   roi_pooled_feat_root=args.roi_pooled_feat_root,
                                   test_mode=test_mode)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.valid_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=yc2_test_collate_fn)
    return valid_loader
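
# DVSA grounding model; judging by the flags above, this is the variant with
# loss weighting and the object-interaction transformer layers.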
def get_model(args):
    model = DVSA(args.num_class, enc_size=args.enc_size, dropout=args.vis_dropout,
                 hidden_size=args.hidden_size, n_layers=args.n_layers, n_heads=args.n_heads,
                 attn_drop=args.attn_drop, num_frm=args.num_frm, has_loss_weighting=True)

    # Initialize the networks and the criterion
    if len(args.start_from) > 0:
        print("Initializing weights from {}".format(args.start_from))
        checkpoint = torch.load(args.start_from, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint)

    # Ship the model to GPU
    if args.cuda:
        model = model.cuda()
    return model

def main(args):
    print('loading dataset')
    valid_loader = get_dataset(args)

    print('building model')
    model = get_model(args)

    valid(model, valid_loader)
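
# Runs the model over every segment in the loader, turns the frame-level
# attention weights into box predictions via compute_pred(), and collects them
# into the nested {database -> video -> segments} structure expected by the
# grounding evaluator (and the CodaLab submission format).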
def valid(model, loader):
    model.eval()  # evaluation mode
    ba_score = defaultdict(list)  # box accuracy metric
    class_labels_dict = loader.dataset.get_class_labels_dict()  # dictionary for class labels - indices to strings

    json_data = {}
    database = {}

    for batch_idx, data in enumerate(loader):
        print('evaluating iter {}...'.format(batch_idx))
        (x_rpn_batch, obj_batch, box_batch, box_label_batch,
         _, rpn_batch, rpn_original_batch, vis_name) = data

        x_rpn_batch = Variable(x_rpn_batch)
        obj_batch = Variable(obj_batch)
        rpn_batch = Variable(rpn_batch)

        if args.cuda:
            x_rpn_batch = x_rpn_batch.cuda()
            obj_batch = obj_batch.cuda()
            box_batch = box_batch.cuda()
            box_label_batch = box_label_batch.cuda()
            rpn_batch = rpn_batch.cuda()  # N, num_frames, num_proposals, 4
            rpn_original_batch = rpn_original_batch.cuda()  # w/o coordinate normalization

        # divide long segment into pieces
        attn_weights = model.output_attn(x_rpn_batch, obj_batch).data

        # quantitative results
        segment_dict = compute_pred(attn_weights, rpn_original_batch, box_batch, obj_batch.data,
                                    box_label_batch, vis_name, thresh=args.accu_thresh,
                                    class_labels_dict=class_labels_dict)
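        # vis_name encodes "<split>_-_<recipe_type>_-_<video_name>_-_<segment>"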
        split, rec, video_name, segment = vis_name.split('_-_')
        if video_name not in database:
            database[video_name] = {}
            database[video_name]['recipe_type'] = rec
        if 'segments' not in database[video_name]:
            database[video_name]['segments'] = {}
        database[video_name]['segments'][int(segment)] = segment_dict

    json_data['database'] = database
    if not os.path.isdir('res'):
        os.mkdir('res')
    with open(args.save_to, 'w') as f:
        json.dump(json_data, f)
    print('Submission file saved to: {}'.format(args.save_to))

    if not test_mode:  # annotations for the testing split are not publicly available
        submit_file = args.save_to
        ref_file = args.box_file
        class_file = args.class_file
        grd_evaluator = YC2BBGrdEval(reference_file=ref_file, submission_file=submit_file,
                                     class_file=class_file, iou_thresh=0.5, verbose=False)
        grd_accu = grd_evaluator.gt_grd_eval()

if __name__ == "__main__":
    main(args)
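
# Example invocation (a sketch; the checkpoint path is an assumption, not a
# file shipped with the repo):
#   python test.py --cuda --start_from ./checkpoints/dvsa.pth \
#       --val_split validation --save_to ./res/submission_yc2_bb.json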