From 4a116b3012e528034e08e76b87fd9bcf92fcd5a4 Mon Sep 17 00:00:00 2001 From: zhn <775497320@qq.com> Date: Wed, 22 Jun 2022 12:29:38 +0000 Subject: [PATCH] first commit --- Dataloder_iterative.py | 328 ++++++++++++++++++++ config.py | 111 +++++++ configs/msrvtt_qa.yml | 37 +++ configs/msvd_qa.yml | 37 +++ configs/tgif_qa_action.yml | 34 +++ configs/tgif_qa_count.yml | 34 +++ configs/tgif_qa_frameqa.yml | 34 +++ configs/tgif_qa_transition.yml | 33 ++ init_glove.py | 160 ++++++++++ model/PKOL.py | 499 ++++++++++++++++++++++++++++++ model/retrieve_model.py | 359 ++++++++++++++++++++++ model/utils.py | 31 ++ readme.md | 117 +++++++ requirements.txt | 87 ++++++ train_iterative.py | 540 +++++++++++++++++++++++++++++++++ utils.py | 8 + validate_iterative.py | 440 +++++++++++++++++++++++++++ 17 files changed, 2889 insertions(+) create mode 100644 Dataloder_iterative.py create mode 100644 config.py create mode 100644 configs/msrvtt_qa.yml create mode 100644 configs/msvd_qa.yml create mode 100644 configs/tgif_qa_action.yml create mode 100644 configs/tgif_qa_count.yml create mode 100644 configs/tgif_qa_frameqa.yml create mode 100644 configs/tgif_qa_transition.yml create mode 100644 init_glove.py create mode 100644 model/PKOL.py create mode 100644 model/retrieve_model.py create mode 100644 model/utils.py create mode 100644 readme.md create mode 100644 requirements.txt create mode 100644 train_iterative.py create mode 100644 utils.py create mode 100644 validate_iterative.py diff --git a/Dataloder_iterative.py b/Dataloder_iterative.py new file mode 100644 index 0000000..2972486 --- /dev/null +++ b/Dataloder_iterative.py @@ -0,0 +1,328 @@ +# DISTRIBUTION STATEMENT A. Approved for public release: distribution unlimited. +# +# This material is based upon work supported by the Assistant Secretary of Defense for Research and +# Engineering under Air Force Contract No. FA8721-05-C-0002 and/or FA8702-15-D-0001. Any opinions, +# findings, conclusions or recommendations expressed in this material are those of the author(s) and +# do not necessarily reflect the views of the Assistant Secretary of Defense for Research and +# Engineering. +# +# © 2017 Massachusetts Institute of Technology. +# +# MIT Proprietary, Subject to FAR52.227-11 Patent Rights - Ownership by the contractor (May 2014) +# +# The software/firmware is provided to you on an As-Is basis +# +# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 or +# 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work are +# defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other than +# as specifically authorized by the U.S. Government may violate any copyrights that exist in this +# work. 
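+#
+# ---------------------------------------------------------------------------
+# Dataloder_iterative.py: dataset and dataloader utilities for PKOL video QA.
+# VideoQADataset pairs each question with appearance (ResNet), motion
+# (ResNeXt) and object features read from HDF5, together with a pool of video
+# caption features used for retrieval; VideoQADataLoader loads the question
+# .pt file, the vocabulary and the GloVe matrix, and wraps the dataset for
+# the train / val / test splits.
+# ---------------------------------------------------------------------------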
+ +import logging +import numpy as np +import json +import pickle +import torch +import math +import h5py +import random +from random import choice +# from torch._C import dtype, float32 +from torch.utils.data import Dataset, DataLoader +import torch.nn.functional as F + + +def invert_dict(d): + return {v: k for k, v in d.items()} + +def load_vocab(path): + with open(path, 'r') as f: + vocab = json.load(f) + vocab['question_idx_to_token'] = invert_dict(vocab['question_token_to_idx']) + vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) + vocab['question_answer_idx_to_token'] = invert_dict(vocab['question_answer_token_to_idx']) + return vocab + +class VideoQADataset(Dataset): + + def __init__(self, answers, ans_candidates, ans_candidates_len, questions, questions_len, video_ids, q_ids, + app_feature_h5, app_feat_id_to_index, motion_feature_h5, motion_feat_id_to_index, object_feature_h5,object_feat_id_to_index, + caption_path = None, caption_max_num = None, split = None, video_names = None, question_type = None): + # convert data to tensor + self.all_answers = answers + self.all_questions = torch.LongTensor(np.asarray(questions)) + self.all_questions_len = torch.LongTensor(np.asarray(questions_len)) + self.all_video_ids = torch.LongTensor(np.asarray(video_ids)) + self.all_video_names = video_names # dtype: str + self.all_q_ids = q_ids + self.app_feature_h5 = app_feature_h5 + self.motion_feature_h5 = motion_feature_h5 + self.object_feature_h5 = object_feature_h5 + self.app_feat_id_to_index = app_feat_id_to_index + self.motion_feat_id_to_index = motion_feat_id_to_index + self.object_feat_id_to_index = object_feat_id_to_index + self.caption_path = caption_path + self.caption_max_num = caption_max_num + self.caption_pool = [] + self.caption_pool_len = [] + self.video_idx2_cap_gt = {} + self.split = split + self.sample_caption = {} + self.visualization = [] + self.max_word = 40 + self.different_dataset = question_type + count = 0 + + with open('data/msrvtt-qa/msrvtt-qa_val_questions.pt', 'rb') as f: + obj = pickle.load(f) + val_ids = obj['video_ids'] + val_id = torch.LongTensor(np.asarray(val_ids)) + + + if self.different_dataset == 'none': + # matching file : question_ID video_ID caption_ID + with open('data/msrvtt-qa/captions_pkl/full_caption_len.pkl','rb') as g: + cap_len = pickle.load(g) + + with open(caption_path,'rb') as f: + self.caption = pickle.load(f) # 10000 {'video4118':[[],[]]} + + for vid,feat in self.caption.items(): + if 'msrvtt' in self.caption_path: + video_idx = int(vid) + elif 'msvd' in self.caption_path: + video_idx = int(vid[3:]) + # if video_idx in self.all_video_ids or video_idx in val_id: + # continue + if video_idx not in self.all_video_ids: + continue + gt = [] + max_word = self.max_word + for k, cap in enumerate(feat) : + self.visualization.append((video_idx, k)) + if 'msvd' in self.caption_path: + if self.split == 'train': + self.sample_caption.setdefault(video_idx, []).append((cap, torch.clamp(torch.tensor(cap.shape[0]),max=max_word).data)) + + gt.append(count) + count += 1 + padding = torch.zeros(max_word,cap.shape[1]) + + self.caption_pool_len.append(torch.clamp(torch.tensor(cap.shape[0]),max=max_word).data) + + padding[:cap.shape[0],:] = torch.from_numpy(cap)[:max_word,:] + self.caption_pool.append(padding.unsqueeze(0)) + + else: + self.sample_caption.setdefault(video_idx, []).append((cap, cap_len[vid][k])) + + gt.append(count) + count += 1 + padding = torch.zeros(max_word,cap.shape[1]) + + self.caption_pool_len.append(cap_len[vid][k]) + + 
padding[:cap.shape[0],:] = torch.from_numpy(cap)[:max_word,:] + self.caption_pool.append(padding.unsqueeze(0)) + self.video_idx2_cap_gt[str(video_idx)] = gt + + self.caption_pool = torch.cat(self.caption_pool,dim=0) # num_cap 61 768/300 + self.caption_pool_len = torch.tensor(self.caption_pool_len) + else: # T-gif + with open(caption_path,'r') as f: + self.caption = json.load(f) # {'video_name': [description,[sentence_index]]} + for vid, cap_index in self.caption.items(): + if int(vid) not in self.all_video_ids: + continue + padding = F.pad(torch.tensor(cap_index), pad=(0,self.max_word-len(cap_index))) + + if self.split == 'train': + self.sample_caption.setdefault(vid, []).append((padding, torch.clamp(torch.tensor(len(cap_index)),max=self.max_word).data)) + + self.caption_pool_len.append(torch.clamp(torch.tensor(len(cap_index)),max=self.max_word).data) + self.caption_pool.append(padding.unsqueeze(0)) + + self.video_idx2_cap_gt.setdefault(vid,[]).append(count) + count += 1 + + self.caption_pool = torch.cat(self.caption_pool, dim=0) + self.caption_pool_len = torch.tensor(self.caption_pool_len) + logging.info("length of caption pool:{}".format(self.caption_pool.size())) + logging.info("length of caption pool:{}".format(len(self.caption_pool))) + + if not np.any(ans_candidates): # [0,0,0,0,0] -> False + self.question_type = 'openended' + else: + self.question_type = 'mulchoices' + self.all_ans_candidates = torch.LongTensor(np.asarray(ans_candidates)) + self.all_ans_candidates_len = torch.LongTensor(np.asarray(ans_candidates_len)) + + def __getitem__(self, index): + answer = self.all_answers[index] if self.all_answers is not None else None + ans_candidates = torch.zeros(5) + ans_candidates_len = torch.zeros(5) + if self.question_type == 'mulchoices': + ans_candidates = self.all_ans_candidates[index] + ans_candidates_len = self.all_ans_candidates_len[index] + question = self.all_questions[index] + question_len = self.all_questions_len[index] + video_idx = self.all_video_ids[index].item() + video_name = self.all_video_names[index] + question_idx = self.all_q_ids[index] + ##### random sample captions + + if self.split == 'train': + if self.different_dataset == 'none': + sample_list = self.sample_caption[video_idx] #[(cap1,caplen1),(cap2,caplen2)] + + sample_cap, sample_cap_len = random.sample(sample_list, 1)[0] + + caption = torch.zeros(self.max_word,sample_cap.shape[1]) + + caption[:sample_cap.shape[0],:] = torch.from_numpy(sample_cap)[:self.max_word,:] + + caption_len = torch.as_tensor(sample_cap_len) + else: + sample_cap, sample_cap_len = self.sample_caption[str(video_idx)][0] # ([index1,index2,...],length) + + caption = sample_cap + + caption_len = torch.as_tensor(sample_cap_len) + + ##### random sample captions + app_index = self.app_feat_id_to_index[str(video_idx)] + motion_index = self.motion_feat_id_to_index[str(video_idx)] + object_index = self.object_feat_id_to_index[str(video_idx)] + with h5py.File(self.app_feature_h5, 'r') as f_app: + appearance_feat = f_app['resnet_features'][app_index] # (8, 16, 2048) + # if 'msrvtt' in self.app_feature_h5: + # Subtraction_frame = np.linspace(0, 16, num=8, endpoint=False, dtype=int) + # appearance_feat = appearance_feat[:, Subtraction_frame, :] + with h5py.File(self.motion_feature_h5, 'r') as f_motion: + motion_feat = f_motion['resnext_features'][motion_index] # (8, 2048) + with h5py.File(self.object_feature_h5,'r') as f_object: + object_feat = f_object['feat'][object_index] # (128,10,2048) + appearance_feat = torch.from_numpy(appearance_feat) + 
motion_feat = torch.from_numpy(motion_feat) + object_feat = torch.from_numpy(object_feat).to(torch.float32) + + if self.split == 'train': + return ( + video_idx, question_idx, answer, ans_candidates, ans_candidates_len, appearance_feat, motion_feat, object_feat, question, + question_len, caption, caption_len) + else: + return ( + video_idx, question_idx, answer, ans_candidates, ans_candidates_len, appearance_feat, motion_feat, object_feat, question, + question_len) + + def __len__(self): + return len(self.all_questions) + + +class VideoQADataLoader(DataLoader): + + def __init__(self, **kwargs): + vocab_json_path = str(kwargs.pop('vocab_json')) + print('loading vocab from %s' % (vocab_json_path)) + vocab = load_vocab(vocab_json_path) + ##################### load caption features ##################### + caption_path = None + dataset_name = kwargs.pop('name') + split = kwargs.pop('split') + caption_max_num = kwargs.pop('caption_max_num') + #if split == 'train': + if dataset_name == 'msrvtt-qa': + caption_path = 'data/msrvtt-qa/captions_pkl/full_caption_features.pkl' + #caption_path = 'data/msrvtt-qa/data/MSRVTT/structured-symlinks/aggregated_text_feats/w2v_MSRVTT.pickle' + if dataset_name == 'msvd-qa': + caption_path = 'data/msvd-qa/data/MSVD/structured-symlinks/aggregated_text_feats/openai-caption-full.pkl' + if dataset_name == 'tgif-qa': + caption_path = 'data/tgif-qa/tgif-caption/tgif_video_cap_ids.json' + + question_pt_path = str(kwargs.pop('question_pt')) + print('loading questions from %s' % (question_pt_path)) + question_type = kwargs.pop('question_type') + with open(question_pt_path, 'rb') as f: + obj = pickle.load(f) + questions = obj['questions'] + questions_len = obj['questions_len'] + video_ids = obj['video_ids'] + video_names = obj['video_names'] + q_ids = obj['question_id'] + answers = obj['answers'] + glove_matrix = obj['glove'] + ans_candidates = np.zeros(5) + ans_candidates_len = np.zeros(5) + if question_type in ['action', 'transition']: + ans_candidates = obj['ans_candidates'] + ans_candidates_len = obj['ans_candidates_len'] + + if 'train_num' in kwargs: + trained_num = kwargs.pop('train_num') + if trained_num > 0: + questions = questions[:trained_num] + questions_len = questions_len[:trained_num] + video_ids = video_ids[:trained_num] + q_ids = q_ids[:trained_num] + answers = answers[:trained_num] + if question_type in ['action', 'transition']: + ans_candidates = ans_candidates[:trained_num] + ans_candidates_len = ans_candidates_len[:trained_num] + if 'val_num' in kwargs: + val_num = kwargs.pop('val_num') + if val_num > 0: + questions = questions[:val_num] + questions_len = questions_len[:val_num] + video_ids = video_ids[:val_num] + q_ids = q_ids[:val_num] + answers = answers[:val_num] + if question_type in ['action', 'transition']: + ans_candidates = ans_candidates[:val_num] + ans_candidates_len = ans_candidates_len[:val_num] + if 'test_num' in kwargs: + test_num = kwargs.pop('test_num') + if test_num > 0: + questions = questions[:test_num] + questions_len = questions_len[:test_num] + video_ids = video_ids[:test_num] + q_ids = q_ids[:test_num] + answers = answers[:test_num] + if question_type in ['action', 'transition']: + ans_candidates = ans_candidates[:test_num] + ans_candidates_len = ans_candidates_len[:test_num] + + print('loading appearance feature from %s' % (kwargs['appearance_feat'])) + with h5py.File(kwargs['appearance_feat'], 'r') as app_features_file: + app_video_ids = app_features_file['ids'][()] + app_feat_id_to_index = {str(id): i for i, id in 
enumerate(app_video_ids)} + print('loading motion feature from %s' % (kwargs['motion_feat'])) + with h5py.File(kwargs['motion_feat'], 'r') as motion_features_file: + motion_video_ids = motion_features_file['ids'][()] + motion_feat_id_to_index = {str(id): i for i, id in enumerate(motion_video_ids)} + print('loading object feature from %s' % (kwargs['object_feat'])) + with h5py.File(kwargs['object_feat'], 'r') as object_features_file: + object_video_ids = object_features_file['video_ids'][()] + object_feat_id_to_index = {str(id): i for i, id in enumerate(object_video_ids)} + + self.app_feature_h5 = kwargs.pop('appearance_feat') + self.motion_feature_h5 = kwargs.pop('motion_feat') + self.object_feature_h5 = kwargs.pop('object_feat') + + self.dataset = VideoQADataset(answers, ans_candidates, ans_candidates_len, + questions, questions_len,video_ids, q_ids, + self.app_feature_h5, app_feat_id_to_index, + self.motion_feature_h5, motion_feat_id_to_index, + self.object_feature_h5,object_feat_id_to_index, + caption_path, caption_max_num, split = split, + video_names = video_names, + question_type = question_type, + ) + + self.vocab = vocab + self.batch_size = kwargs['batch_size'] + self.glove_matrix = glove_matrix + + super().__init__(self.dataset, **kwargs) + + def __len__(self): + return math.ceil(len(self.dataset) / self.batch_size) diff --git a/config.py b/config.py new file mode 100644 index 0000000..9541208 --- /dev/null +++ b/config.py @@ -0,0 +1,111 @@ +from __future__ import division +from __future__ import print_function + +import numpy as np +from easydict import EasyDict as edict + +__C = edict() +cfg = __C + +__C.gpu_id = 0 +__C.num_workers = 4 +__C.multi_gpus = False +__C.seed = 666 +# training options +__C.train = edict() +__C.train.restore = False +__C.train.lr = 0.0001 +__C.train.batch_size = 32 +__C.train.max_epochs = 25 +__C.train.vision_dim = 2048 +__C.train.word_dim = 300 +__C.train.module_dim = 512 +__C.train.train_num = 0 # Default 0 for full train set +__C.train.restore = False +__C.train.glove = True +__C.train.k_max_frame_level = 16 +__C.train.k_max_clip_level = 8 +__C.train.spl_resolution = 1 +__C.train.caption_dim = 300 +__C.train = dict(__C.train) +__C.train.pretrained_retrieve_path = None +__C.train.joint = False +__C.train.patch_number = 20000 +__C.train.topk = 3 + +# validation +__C.val = edict() +__C.val.flag = True +__C.val.val_num = 0 # Default 0 for full val set +__C.val.topk = 10 +__C.val = dict(__C.val) +# test +__C.test = edict() +__C.test.test_num = 0 # Default 0 for full test set +__C.test.write_preds = False +__C.test.visualization = False +__C.test = dict(__C.test) +# dataset options +__C.dataset = edict() +__C.dataset.name = 'tgif-qa' # ['tgif-qa', 'msrvtt-qa', 'msvd-qa'] +__C.dataset.question_type = 'none' #['frameqa', 'count', 'transition', 'action', 'none'] +__C.dataset.data_dir = '' +__C.dataset.appearance_feat = '{}_{}_appearance_feat.h5' +__C.dataset.motion_feat = '{}_{}_motion_feat.h5' +__C.dataset.object_feat = '{}_{}_object_feat.h5' +__C.dataset.vocab_json = '{}_{}_vocab.json' +__C.dataset.train_question_pt = '{}_{}_train_questions.pt' +__C.dataset.val_question_pt = '{}_{}_val_questions.pt' +__C.dataset.test_question_pt = '{}_{}_test_questions.pt' +__C.dataset.save_dir = '' +__C.dataset.topk = 10 +__C.dataset.pretrained = '' +__C.dataset.max_cap_num = 15 +__C.dataset = dict(__C.dataset) + +# experiment name +__C.exp_name = 'defaultExp' + +# credit https://github.com/tohinz/pytorch-mac-network/blob/master/code/config.py +def merge_cfg(yaml_cfg, 
cfg): + if type(yaml_cfg) is not edict: + return + + for k, v in yaml_cfg.items(): + if not k in cfg: + raise KeyError('{} is not a valid config key'.format(k)) + + old_type = type(cfg[k]) + if old_type is not type(v): + if isinstance(cfg[k], np.ndarray): + v = np.array(v, dtype=cfg[k].dtype) + elif isinstance(cfg[k], list): + v = v.split(",") + v = [int(_v) for _v in v] + elif cfg[k] is None: + if v == "None": + continue + else: + v = v + else: + raise ValueError(('Type mismatch ({} vs. {}) ' + 'for config key: {}').format(type(cfg[k]), + type(v), k)) + # recursively merge dicts + if type(v) is edict: + try: + merge_cfg(yaml_cfg[k], cfg[k]) + except: + print('Error under config key: {}'.format(k)) + raise + else: + cfg[k] = v + + + +def cfg_from_file(file_name): + import yaml + with open(file_name, 'r') as f: + yaml_cfg = edict(yaml.load(f)) + + merge_cfg(yaml_cfg, __C) \ No newline at end of file diff --git a/configs/msrvtt_qa.yml b/configs/msrvtt_qa.yml new file mode 100644 index 0000000..3e8b8c7 --- /dev/null +++ b/configs/msrvtt_qa.yml @@ -0,0 +1,37 @@ +gpu_id: 3 +multi_gpus: False +num_workers: 8 +seed: 666 +exp_name: 'expMSRVTT-QA' + +train: + lr: 0.0001 + batch_size: 16 + restore: False + max_epochs: 25 + word_dim: 300 + module_dim: 512 + glove: True + k_max_frame_level: 16 + k_max_clip_level: 8 + spl_resolution: 1 + caption_dim : 768 + joint : True + patch_number : 10000 # max :50000 + topk : 3 + +val: + flag: True + topk : 3 + +test: + test_num: 0 + write_preds: False + visualization: False + +dataset: + name: 'msrvtt-qa' + question_type: 'none' + data_dir: 'data/msrvtt-qa' + save_dir: 'results/' + max_cap_num: 19 # max: 19 \ No newline at end of file diff --git a/configs/msvd_qa.yml b/configs/msvd_qa.yml new file mode 100644 index 0000000..bbaf3f6 --- /dev/null +++ b/configs/msvd_qa.yml @@ -0,0 +1,37 @@ +gpu_id: 0 +multi_gpus: False +num_workers: 12 +seed: 666 +exp_name: 'expMSVD-QA' + +train: + lr: 0.0001 + batch_size: 128 + restore: False + max_epochs: 25 + word_dim: 300 + module_dim: 512 + glove: True + k_max_frame_level: 16 + k_max_clip_level: 8 + spl_resolution: 1 + caption_dim : 768 + joint : True + patch_number : 5000 # max :25000 + topk : 3 + +val: + flag: True + topk : 3 + +test: + test_num: 0 + write_preds: False + visualization: False + +dataset: + name: 'msvd-qa' + question_type: 'none' + data_dir: 'data/msvd-qa/' + save_dir: 'results/' + max_cap_num: 18 # max:18 \ No newline at end of file diff --git a/configs/tgif_qa_action.yml b/configs/tgif_qa_action.yml new file mode 100644 index 0000000..5600b62 --- /dev/null +++ b/configs/tgif_qa_action.yml @@ -0,0 +1,34 @@ +gpu_id: 0 +multi_gpus: False +num_workers: 8 +seed: 666 +exp_name: 'expTGIF-QAAction' + +train: + lr: 0.0001 + batch_size: 64 + restore: False + max_epochs: 30 + word_dim: 300 + module_dim: 512 + glove: True + k_max_frame_level: 16 + k_max_clip_level: 8 + spl_resolution: 1 + caption_dim : 300 + joint : True + patch_number : 25000 # max :25000 + topk : 3 +val: + flag: True + topk : 5 + +test: + test_num: 0 + write_preds: False + +dataset: + name: 'tgif-qa' + question_type: 'action' + data_dir: 'data/tgif-qa/action' + save_dir: 'results/' \ No newline at end of file diff --git a/configs/tgif_qa_count.yml b/configs/tgif_qa_count.yml new file mode 100644 index 0000000..3aa4afc --- /dev/null +++ b/configs/tgif_qa_count.yml @@ -0,0 +1,34 @@ +gpu_id: 0 +multi_gpus: False +num_workers: 8 +seed: 666 +exp_name: 'expTGIF-QACount' + +train: + lr: 0.0001 + batch_size: 64 + restore: False + max_epochs: 30 + word_dim: 
300 + module_dim: 512 + glove: True + k_max_frame_level: 16 + k_max_clip_level: 8 + spl_resolution: 1 + caption_dim : 300 + joint : True + patch_number : 25000 # max :25000 + topk : 3 +val: + flag: True + topk : 5 + +test: + test_num: 0 + write_preds: False + +dataset: + name: 'tgif-qa' + question_type: 'count' + data_dir: 'data/tgif-qa/count' + save_dir: 'results/' \ No newline at end of file diff --git a/configs/tgif_qa_frameqa.yml b/configs/tgif_qa_frameqa.yml new file mode 100644 index 0000000..37fe4fe --- /dev/null +++ b/configs/tgif_qa_frameqa.yml @@ -0,0 +1,34 @@ +gpu_id: 0 +multi_gpus: False +num_workers: 8 +seed: 666 +exp_name: 'expTGIF-QAFrameQA' + +train: + lr: 0.0001 + batch_size: 64 + restore: False + max_epochs: 30 + word_dim: 300 + module_dim: 512 + glove: True + k_max_frame_level: 16 + k_max_clip_level: 8 + spl_resolution: 1 + caption_dim : 300 + joint : True + patch_number : 25000 # max :25000 + topk : 3 + +val: + flag: True + topk : 5 +test: + test_num: 0 + write_preds: False + +dataset: + name: 'tgif-qa' + question_type: 'frameqa' + data_dir: 'data/tgif-qa/frameqa' + save_dir: 'results/' \ No newline at end of file diff --git a/configs/tgif_qa_transition.yml b/configs/tgif_qa_transition.yml new file mode 100644 index 0000000..63b5cb6 --- /dev/null +++ b/configs/tgif_qa_transition.yml @@ -0,0 +1,33 @@ +gpu_id: 0 +multi_gpus: False +num_workers: 8 +seed: 666 +exp_name: 'expTGIF-QATransition' + +train: + lr: 0.0001 + batch_size: 64 + restore: False + max_epochs: 30 + word_dim: 300 + module_dim: 512 + glove: True + k_max_frame_level: 16 + k_max_clip_level: 8 + spl_resolution: 1 + caption_dim : 300 + joint : True + patch_number : 25000 # max :25000 + topk : 3 +val: + flag: True + topk : 3 +test: + test_num: 0 + write_preds: False + +dataset: + name: 'tgif-qa' + question_type: 'transition' + data_dir: 'data/tgif-qa/transition' + save_dir: 'results/' \ No newline at end of file diff --git a/init_glove.py b/init_glove.py new file mode 100644 index 0000000..5070919 --- /dev/null +++ b/init_glove.py @@ -0,0 +1,160 @@ +from genericpath import exists +from json.decoder import JSONDecodeError +import os +import sys +import numpy as np +import collections + +import json + +from torchtext.data.utils import get_tokenizer + +class Caption_vocabulary(object): + """ + A simple Vocabulary class which maintains a mapping between words and integer tokens. Can be + initialized either by word counts from the tgif-qa dataset, or a pre-saved vocabulary mapping. + + Parameters + ---------- + word_counts_path: str + Path to a json file containing counts of each word across captions, questions and answers + of the VisDial v1.0 train dataset. + min_count : int, optional (default=0) + When initializing the vocabulary from word counts, you can specify a minimum count, and + every token with a count less than this will be excluded from vocabulary. 
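+
+    Example
+    -------
+    A minimal, illustrative usage (the word-counts file is the one written by
+    ``_tokenizer`` below)::
+
+        vocab = Caption_vocabulary('data/tgif-qa/tgif-caption/tgif-word_counts.json')
+        idx = vocab.word2index.get('cat', Caption_vocabulary.UNK_index)
+        word = vocab.index2word[idx]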
+ """ + PAD = '' + UNK = '' + SOS = '' + EOS = '' + + PAD_index = 0 + UNK_index = 1 + SOS_index = 2 + EOS_index = 3 + + def __init__(self, word_counts_json, file_exist = '') -> None: + if not exists(word_counts_json): + raise FileNotFoundError(f'Word counts do not exist at {word_counts_json}') + if file_exist == '': + with open(word_counts_json,'r') as word_counts_file: + word_counts = json.load(word_counts_file) + word_counts = [ + (word, count) for word,count in word_counts.items() + ] + words = [w[0] for w in word_counts] + + self.word2index = {} + self.word2index.setdefault(self.PAD,0) + self.word2index.setdefault(self.UNK,1) + self.word2index.setdefault(self.SOS,2) + self.word2index.setdefault(self.EOS,3) + + for index, word in enumerate(words): + self.word2index.setdefault(word, index + 4) + else: + with open(file_exist,'r') as F: + f = json.load(F) + self.word2index = f + + self.index2word = {index: word for word,index in self.word2index.items()} + + # json.dump(self.word2index, open('data/tgif-qa/tgif-caption/tgif_word2index.json', 'w')) + + + @staticmethod + def _tokenizer(raw_file): + """ + A simple tokenizer to convert caption sentences to tokens + + Parameters + ---------- + raw_file: str + Path to a json file containing name of each video file and its corresponding description. + + """ + + tokenizer=get_tokenizer('basic_english') + with open('data/tgif-qa/tgif-caption/tgif-caption.json','r') as f: + cap_file = json.load(f) # {gif_name: description} + + word_counts = collections.defaultdict(lambda :0) + for _, des in cap_file.items(): + token=tokenizer(des) + for tok in token: + word_counts[tok] += 1 + + json.dump(word_counts,open('data/tgif-qa/tgif-caption/tgif-word_counts.json','w')) + + +class GloveProcessor(object): + def __init__(self, glove_path): + self.glove_path = glove_path + + def _load_glove_model(self): + print("Loading pretrained word vectors...") + with open(self.glove_path, 'r') as f: + model = {} + for line in f: + splitLine = line.split() + word = splitLine[0] + embedding = np.array([float(val) for val in splitLine[1:]]) # e.g., 300 dimension + model[word] = embedding + + print("Done.", len(model), " words loaded from %s" % self.glove_path) + + return model + + def save_glove_vectors(self, vocabulary, glove_npy_path, dim=300): + """ + Saves glove vectors in numpy array + Args: + vocab: dictionary vocab[word] = index + glove_filename: a path to a glove file + trimmed_filename: a path where to store a matrix in npy + dim: (int) dimension of embeddings + """ + # vocabulary index2word + vocab_size = len(vocabulary.index2word) + glove_embeddings = self._load_glove_model() + embeddings = np.zeros(shape=[vocab_size, 300], dtype=np.float32) + + vocab_in_glove = 0 + for i in range(0, vocab_size): + word = vocabulary.index2word[i] + if word in ['', '', '']: + continue + if word in glove_embeddings: + embeddings[i] = glove_embeddings[word] + vocab_in_glove += 1 + else: + embeddings[i] = glove_embeddings['unk'] + + print("Vocabulary in GLoVE : %d / %d" % (vocab_in_glove, vocab_size)) + np.save(glove_npy_path, embeddings) + + +if __name__ == '__main__': + # vocabulary = Caption_vocabulary('data/tgif-qa/tgif-caption/tgif_word_counts.json', file_exist='data/tgif-qa/tgif-caption/tgif_word2index.json') + # glove_vocab = GloveProcessor('/mnt/hdd1/zhanghaonan/code/MVAN-VisDial-master/glove.6B.300d.txt') + # glove_vocab.save_glove_vectors(vocabulary, 'data/tgif-qa/tgif-caption/glove.npy') + + with open('data/tgif-qa/tgif-caption/tgif_word2index.json', 'r') as mapping: + map = 
json.load(mapping) + + + + tokenizer=get_tokenizer('basic_english') + with open('data/tgif-qa/tgif-caption/tgif_caption.json','r') as f: + cap_file = json.load(f) # {gif_name: description} + + res = collections.defaultdict(lambda :[]) + + for vid, des in cap_file.items(): + token=tokenizer(des) + res[vid].append(des) + w = [] + for tok in token: + w.append(map.get(tok,1)) + res[vid].append(w) + json.dump(dict(res),open('data/tgif-qa/tgif-caption/tgif_cap_index.json','w')) \ No newline at end of file diff --git a/model/PKOL.py b/model/PKOL.py new file mode 100644 index 0000000..f46814c --- /dev/null +++ b/model/PKOL.py @@ -0,0 +1,499 @@ +import numpy as np +from torch.nn import functional as F +from .utils import * + +class FeatureAggregation(nn.Module): + def __init__(self, module_dim=512): + super(FeatureAggregation, self).__init__() + self.module_dim = module_dim + + self.q_proj = nn.Linear(module_dim, module_dim, bias=False) + self.v_proj = nn.Linear(module_dim*2, module_dim, bias=False) + + self.cat = nn.Linear(3 * module_dim, module_dim*2) + + self.activation = nn.ELU() + self.dropout = nn.Dropout(0.15) + + def forward(self, question_rep, visual_feat): + + visual_feat = self.dropout(visual_feat) + q_proj = self.q_proj(question_rep) + v_proj = self.v_proj(visual_feat) + + v_q_cat = q_proj * v_proj + + v_distill = torch.cat([v_q_cat, visual_feat],dim=-1) + + v_distill = self.activation(self.cat(v_distill)) + + return v_distill + + +class Global_FeatureAggregation(nn.Module): + def __init__(self, module_dim=512): + super(Global_FeatureAggregation, self).__init__() + self.module_dim = module_dim + + self.m_proj = nn.Linear(module_dim*4, module_dim, bias=True) + self.o_proj = nn.Linear(module_dim*4, module_dim, bias=True) + + self.global_query = nn.Linear(module_dim*2, module_dim, bias=True) + + self.activation = nn.ELU() + self.dropout = nn.Dropout(0.15) + + self.global_att = nn.Linear(module_dim, 1) + self.obj_att = nn.Linear(module_dim, 1) + self.final_att = nn.Linear(module_dim,1) + + + def forward(self, motion_feat, obj_feat, query): + ''' + motion_feat: batch_size num_clip 2048 + obj_feat: batch_size num_frame num_obj 2048 + query: batch_size module + + ''' + bs, num_frame, num_obj, visual_dim = obj_feat.size() + _, num_clip, _ = motion_feat.size() + obj_feat = self.dropout(obj_feat) + mot_feat = self.dropout(motion_feat) + + m_proj = self.activation(self.m_proj(mot_feat)) # batch_size num_clip module + o_proj = self.activation(self.o_proj(obj_feat)) # batch_size num_frame num_obj module + query = self.activation(self.global_query(query)) # batch_size module + + m_att = self.global_att(m_proj*query.unsqueeze(1)) # batch_size num_clip 1 + m_score = F.softmax(m_att, dim=1) # batch_size num_clip 1 + + m = (m_score * m_proj).sum(1) # batch_size module + # print((m_score * m_proj).unsqueeze(-2).repeat(1,1,num_frame//num_clip,1).reshape(bs,-1,self.module_dim).unsqueeze(-2).size()) + o_att = self.obj_att(o_proj*query.unsqueeze(1).repeat(1, num_frame, 1).unsqueeze(2)) + # o_att = self.obj_att(o_proj*(m_score * m_proj).unsqueeze(-2).repeat(1,1,num_frame//num_clip,1).reshape(bs,-1,self.module_dim).unsqueeze(-2)) + + o_score = F.softmax(o_att, dim=1) # batch_size num_frame num_obj 1 + + o = (o_score * o_proj).sum(1) # batch_size num_obj module + + global_o = m.unsqueeze(1)*o + # global_o = o * query.unsqueeze(1) + + final_att = self.final_att(global_o) + final_score = F.softmax(final_att, dim=1) + final = (final_score * global_o).sum(1) # batch_size module + + # print(final.size(),m.size()) + g 
= torch.cat([final, m],dim=-1) # batch_size module*2 + + return g + + +class Prospect_Background_aggregation(nn.Module): + def __init__(self, module_dim = 512): + super(Prospect_Background_aggregation,self).__init__() + self.module_dim = module_dim + + self.a_proj = nn.Linear(module_dim*4, module_dim, bias=True) + self.o_proj = nn.Linear(module_dim*4, module_dim, bias=True) + self.global_query = nn.Linear(module_dim*2, module_dim, bias=True) + + self.activation = nn.ELU() + self.dropout = nn.Dropout(0.15) + + self.global_att = nn.Linear(module_dim, 1) + self.obj_att = nn.Linear(module_dim, 1) + self.final_att = nn.Linear(module_dim,1) + + def forward(self, obj_feat, app_feat, query): + ''' + obj_feat: batch_size num_frame num_obj 2048 + app_feat: batch_size num_clip num_frame 2048 + query: batch_size module + ''' + obj_feat = self.dropout(obj_feat) + app_feat = self.dropout(app_feat) + + obj_feat = self.activation(self.o_proj(obj_feat)) + app_feat = self.activation(self.a_proj(app_feat)) + query =self.activation(self.global_query(query)) + + bs, num_frame, num_obj, visual_dim = obj_feat.size() + + bs, num_clip, num_clip_frame, vis_dim = app_feat.size() + + a_proj = app_feat.reshape(bs, -1, vis_dim) # bs num_frame 2048 + + a_att = self.global_att(a_proj*query.unsqueeze(1)) # batch_size n 1 + a_score = F.softmax(a_att, dim=1) # batch_size num_frame 1 + + a = (a_score * a_proj).sum(1) # batch_size module + + o_att = self.obj_att(obj_feat*query.unsqueeze(1).repeat(1, num_frame, 1).unsqueeze(2)) + # o_att = self.obj_att(obj_feat*(a_score * a_proj).unsqueeze(-2)) + o_score = F.softmax(o_att, dim=1) # batch_size num_frame num_obj 1 + + o = (o_score * obj_feat).sum(1) # batch_size num_obj module + + global_o = a.unsqueeze(1)*o + + # global_o = o * query.unsqueeze(1) + + final_att = self.final_att(global_o) + final_score = F.softmax(final_att, dim=1) + final = (final_score * global_o).sum(1) # batch_size module + # print(final.size(),m.size()) + app = torch.cat([final, a],dim=-1) # batch_size module*2 + + return app + +class InputUnitLinguistic(nn.Module): + def __init__(self, vocab_size, wordvec_dim=300, rnn_dim=512, module_dim=512, bidirectional=True): + super(InputUnitLinguistic, self).__init__() + + self.dim = module_dim + + self.bidirectional = bidirectional + if bidirectional: + rnn_dim = rnn_dim // 2 + + self.encoder_embed = nn.Embedding(vocab_size, wordvec_dim) + self.tanh = nn.Tanh() + self.encoder = nn.LSTM(wordvec_dim, rnn_dim, batch_first=True, bidirectional=bidirectional) + self.embedding_dropout = nn.Dropout(p=0.15) + self.question_dropout = nn.Dropout(p=0.18) + + self.module_dim = module_dim + + def forward(self, questions, question_len): + """ + Args: + question: [Tensor] (batch_size, max_question_length) + question_len: [Tensor] (batch_size) + return: + question representation [Tensor] (batch_size, module_dim) + """ + questions_embedding = self.encoder_embed(questions) # (batch_size, seq_len, dim_word) + embed = self.tanh(self.embedding_dropout(questions_embedding)) + + embed = nn.utils.rnn.pack_padded_sequence(embed, question_len.cpu(), batch_first=True, + enforce_sorted=False) + + self.encoder.flatten_parameters() + _, (question_embedding, _) = self.encoder(embed) + if self.bidirectional: + question_embedding = torch.cat([question_embedding[0], question_embedding[1]], -1) + question_embedding = self.question_dropout(question_embedding) + + return question_embedding + +class captionLinguistic(nn.Module): + def __init__(self, caption_dim, rnn_dim=512, module_dim=512, 
bidirectional=True): + super(captionLinguistic, self).__init__() + + self.dim = module_dim + + self.bidirectional = bidirectional + if bidirectional: + rnn_dim = rnn_dim // 2 + + self.encoder_embed = nn.Linear(caption_dim, rnn_dim) + self.tanh = nn.Tanh() + self.encoder = nn.LSTM(caption_dim, rnn_dim, batch_first=True, bidirectional=bidirectional) + self.embedding_dropout = nn.Dropout(p=0.15) + self.question_dropout = nn.Dropout(p=0.18) + + self.module_dim = module_dim + + def forward(self, caption, caption_len): + """ + Args: + caption: [Tensor] (batch_size, num_cap, max_question_length, cap_dim) + caption_len: [Tensor] (batch_size, num_cap) + return: + caption representation [Tensor] (batch_size, num_cap, module_dim) + """ + + caption_embedding = caption + embed = self.tanh(self.embedding_dropout(caption_embedding)) + + + embed_candi = nn.utils.rnn.pack_padded_sequence(embed, caption_len.cpu(), batch_first=True, + enforce_sorted=False) + self.encoder.flatten_parameters() + _, (caption_embedding, _) = self.encoder(embed_candi) + if self.bidirectional: + caption_embedding = torch.cat([caption_embedding[0], caption_embedding[1]], -1) + caption_embedding = self.question_dropout(caption_embedding) + + return caption_embedding + + +class OutputUnitOpenEnded(nn.Module): + def __init__(self, module_dim=512, num_answers=1000): + super(OutputUnitOpenEnded, self).__init__() + + self.question_proj = nn.Linear(module_dim, module_dim) + + self.classifier = nn.Sequential(nn.Dropout(0.15), + nn.Linear(module_dim * 2, module_dim), + nn.ELU(), + nn.BatchNorm1d(module_dim), + nn.Dropout(0.15), + nn.Linear(module_dim, num_answers)) + + def forward(self, question_embedding, visual_embedding): + question_embedding = self.question_proj(question_embedding) + out = torch.cat([visual_embedding, question_embedding], 1) + #out = self.classifier(out) + + return out + + +class OutputUnitMultiChoices(nn.Module): + def __init__(self, module_dim=512, caption_flag = False): + super(OutputUnitMultiChoices, self).__init__() + + self.question_proj = nn.Linear(module_dim, module_dim) + + self.ans_candidates_proj = nn.Linear(module_dim, module_dim) + + self.caption_flag = caption_flag + + if not self.caption_flag: + self.classifier = nn.Sequential(nn.Dropout(0.15), + nn.Linear(module_dim * 7, module_dim), + nn.ELU(), + nn.BatchNorm1d(module_dim), + nn.Dropout(0.15), + nn.Linear(module_dim, 1)) + else: + self.classifier = nn.Sequential(nn.Dropout(0.15), + nn.Linear(module_dim * 7, module_dim), + nn.ELU(), + nn.BatchNorm1d(module_dim), + nn.Dropout(0.15), + nn.Linear(module_dim, 1)) + + def forward(self, question_embedding, q_visual_embedding, ans_candidates_embedding, + a_visual_embedding, caption_embedding): + question_embedding = self.question_proj(question_embedding) + ans_candidates_embedding = self.ans_candidates_proj(ans_candidates_embedding) + + out = torch.cat([q_visual_embedding, question_embedding, a_visual_embedding, + ans_candidates_embedding], 1) + if self.caption_flag: + out = torch.cat([out,caption_embedding], 1) + out = self.classifier(out) + + return out + + +class OutputUnitCount(nn.Module): + def __init__(self, module_dim=512, caption_flag=True): + super(OutputUnitCount, self).__init__() + + self.question_proj = nn.Linear(module_dim, module_dim) + self.caption_flag = caption_flag + if not caption_flag: + self.regression = nn.Sequential(nn.Dropout(0.15), + nn.Linear(module_dim * 2, module_dim), + nn.ELU(), + nn.BatchNorm1d(module_dim), + nn.Dropout(0.15), + nn.Linear(module_dim, 1)) + else: + self.regression 
= nn.Sequential(nn.Dropout(0.15), + nn.Linear(module_dim * 6, module_dim), + nn.ELU(), + nn.BatchNorm1d(module_dim), + nn.Dropout(0.15), + nn.Linear(module_dim, 1)) + + def forward(self, question_embedding, visual_embedding, caption_embedding = None): + question_embedding = self.question_proj(question_embedding) + if self.caption_flag: + out = torch.cat([visual_embedding, question_embedding,caption_embedding], 1) + out = self.regression(out) + else: + out = torch.cat([visual_embedding, question_embedding], 1) + out = self.regression(out) + + return out + + +class PKOL_Net(nn.Module): + def __init__(self, vision_dim, module_dim, word_dim, k_max_frame_level, k_max_clip_level, spl_resolution, vocab, question_type, caption_dim, topk, corpus, corpus_len, patch_number, cap_vocab = None, visualization= False): + super(PKOL_Net, self).__init__() + + self.visualization = visualization + self.topk = topk + self.patch_number = patch_number + + self.corpus = corpus + self.corpus_len = corpus_len + self.question_type = question_type + + self.feature_aggregation_global = FeatureAggregation(module_dim) + self.feature_aggregation_PB = FeatureAggregation(module_dim) + self.Global_FeatureAggregation = Global_FeatureAggregation(module_dim) + self.Prospect_Background_aggregation = Prospect_Background_aggregation(module_dim) + + if self.question_type in ['action', 'transition']: + encoder_vocab_size = len(vocab['question_answer_token_to_idx']) + self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, + module_dim=module_dim, rnn_dim=module_dim) + self.output_unit = OutputUnitMultiChoices(module_dim=module_dim, caption_flag=True) + self.ffc = nn.Sequential( + nn.Linear(module_dim*4,module_dim*2), + nn.Dropout(p=0.2), + nn.Tanh() + ) + elif self.question_type == 'count': + encoder_vocab_size = len(vocab['question_token_to_idx']) + self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, + module_dim=module_dim, rnn_dim=module_dim) + self.output_unit = OutputUnitCount(module_dim=module_dim, caption_flag=True) + self.classifier = nn.Sequential(nn.Dropout(0.15), + nn.Linear(module_dim * 2, module_dim), + nn.ELU(), + nn.BatchNorm1d(module_dim), + nn.Dropout(0.15), + nn.Linear(module_dim, 1)) + else: + encoder_vocab_size = len(vocab['question_token_to_idx']) + self.num_classes = len(vocab['answer_token_to_idx']) + self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, + module_dim=module_dim, rnn_dim=module_dim) + self.caption_unit = captionLinguistic(caption_dim=caption_dim) + self.output_unit = OutputUnitOpenEnded(num_answers=self.num_classes) + + ######################################### + self.classifier = nn.Sequential(nn.Dropout(0.15), + nn.Linear(module_dim * 2, module_dim), + nn.ELU(), + nn.BatchNorm1d(module_dim), + nn.Dropout(0.15), + nn.Linear(module_dim, self.num_classes)) + self.att = nn.Sequential( + nn.Linear(module_dim,1), + nn.Dropout(p=0.2) + ) + self.fusion = nn.Sequential( + nn.Linear(module_dim*6,module_dim*2), + nn.Dropout(p=0.2), + nn.Tanh() + ) + + self.softmax = nn.Softmax(dim=-1) + + init_modules(self.modules(), w_init="xavier_uniform") + nn.init.uniform_(self.linguistic_input_unit.encoder_embed.weight, -1.0, 1.0) + + def forward(self, ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, video_object_feat, question, + question_len, similarity=None, corpus=None, corpus_len=None): + """ + Args: + ans_candidates: [Tensor] (batch_size, 
5, max_ans_candidates_length) + ans_candidates_len: [Tensor] (batch_size, 5) + video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) + video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) + video_object_feat: [Tensor] (batch_size, num_objs, obj_dim) + question: [Tensor] (batch_size, max_question_length) + question_len: [Tensor] (batch_size) + simility: [Tensor] (batch_size, num_corpus) + corpus: [Tensor] train: (num_cap, num_word cap_dim) / val: (None) | [tgif] train: (num_cap, num_word) / val: (None) + corpus_len: [Tensor] train: (num_cap, ) / val: (None) | [tgif] train: (num_cap, ) / val: (None, ) + return: + logits. + """ + batch_size = question.size(0) + if self.question_type in ['frameqa', 'count', 'none']: + # get image, word, and sentence embeddings + question_embedding = self.linguistic_input_unit(question, question_len) # batch_size module_dim + + # corpus retrieve + + caption_tensor = corpus + sort_, index_ = similarity.sort(1,descending=True) # batch_size num_cap + + index_ = index_[:,:self.topk].contiguous().view(-1, 1) # batch_size*topk + + caption_tensor_list = caption_tensor[index_].contiguous().view(batch_size, self.topk, -1) # batch_size topk module_dim + + #####################caption&question-attention##################### + caption_awear_q = caption_tensor_list*question_embedding.unsqueeze(1).repeat(1,caption_tensor_list.size(1),1) + + caption_att = self.softmax(self.att(caption_awear_q).squeeze(-1)) # batch_size num_cap + + caption_feat = torch.sum(caption_att.unsqueeze(-1) * caption_tensor_list, dim=-2) # batch_size module_dim + + v = question_embedding + + contextual_content = torch.cat([v,caption_feat],dim=-1) # batch_size module_dim*2 + + #####################global-aggregation##################### + + global_embedding = self.Global_FeatureAggregation(video_motion_feat, video_object_feat, contextual_content) # (batch_size module*2) + prospect_Background_embedding = self.Prospect_Background_aggregation(video_object_feat, video_appearance_feat, contextual_content) # (batch_size module*2) + + g_embedding = self.feature_aggregation_global(question_embedding, global_embedding) + PB_embedding = self.feature_aggregation_PB(question_embedding, prospect_Background_embedding) + + + out = self.fusion(torch.cat([g_embedding, PB_embedding, contextual_content], dim=-1)) + out = self.classifier(out) + + else: + + question_embedding = self.linguistic_input_unit(question, question_len) # batch_size module_dim + + # # corpus retrieve + + caption_tensor = corpus + sort_, index_ = similarity.sort(1,descending=True) # batch_size num_cap + + index_ = index_[:,:self.topk].contiguous().view(-1, 1) # batch_size*topk + + caption_tensor_list = caption_tensor[index_].contiguous().view(batch_size, self.topk, -1) # batch_size topk module_dim + + #####################caption&question-attention##################### + caption_awear_q = caption_tensor_list*question_embedding.unsqueeze(1).repeat(1,caption_tensor_list.size(1),1) + caption_att = self.softmax(self.att(caption_awear_q).squeeze(-1)) # batch_size num_cap + + caption_feat = torch.sum(caption_att.unsqueeze(-1) * caption_tensor_list, dim=-2) # batch_size module_dim + + contextual_content = torch.cat([question_embedding,caption_feat], dim=-1) # batch_size module_dim*2 + + #####################global-aggregation##################### + + global_embedding = self.Global_FeatureAggregation(video_motion_feat, video_object_feat, contextual_content) # (batch_size module*2) + prospect_Background_embedding = 
self.Prospect_Background_aggregation(video_object_feat, video_appearance_feat, contextual_content) # (batch_size module*2) + + g_embedding = self.feature_aggregation_global(question_embedding, global_embedding) + PB_embedding = self.feature_aggregation_PB(question_embedding, prospect_Background_embedding) + + out = self.fusion(torch.cat([g_embedding, PB_embedding, contextual_content], dim=-1)) + + ans_candidates_agg = ans_candidates.view(-1, ans_candidates.size(2)) + ans_candidates_len_agg = ans_candidates_len.view(-1) + + batch_agg = np.reshape( + np.tile(np.expand_dims(np.arange(batch_size), axis=1), [1, 5]), [-1]) + + ans_candidates_embedding = self.linguistic_input_unit(ans_candidates_agg, ans_candidates_len_agg) + + a_global_embedding = self.feature_aggregation_global(ans_candidates_embedding, global_embedding[batch_agg]) + a_PB_embedding = self.feature_aggregation_PB(ans_candidates_embedding, prospect_Background_embedding[batch_agg]) + + a_visual_embedding = self.ffc(torch.cat([a_global_embedding, a_PB_embedding],dim=-1)) + + out = self.output_unit(question_embedding[batch_agg], out[batch_agg], + ans_candidates_embedding, + a_visual_embedding, caption_feat[batch_agg]) + if self.visualization: + return out, index_.reshape(batch_size,-1), caption_att + + return out + + diff --git a/model/retrieve_model.py b/model/retrieve_model.py new file mode 100644 index 0000000..2708d6a --- /dev/null +++ b/model/retrieve_model.py @@ -0,0 +1,359 @@ +import numpy as np +from torch.nn import functional as F +from torch.nn.modules import module +import torch +from torch import nn +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +from .utils import * + +class DynamicRNN(nn.Module): + def __init__(self, rnn_model): + super().__init__() + self.rnn_model = rnn_model + + def forward(self, seq_input, seq_lens, initial_state=None): + """A wrapper over pytorch's rnn to handle sequences of variable length. + + Arguments + --------- + seq_input : torch.Tensor + Input sequence tensor (padded) for RNN model. + Shape: (batch_size, max_sequence_length, embed_size) + seq_lens : torch.LongTensor + Length of sequences (b, ) + initial_state : torch.Tensor + Initial (hidden, cell) states of RNN model. + + Returns + ------- + Single tensor of shape (batch_size, rnn_hidden_size) corresponding + to the outputs of the RNN model at the last time step of each input + sequence. 
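+
+        Example
+        -------
+        A minimal sketch; the variable names are illustrative and not part of
+        the codebase::
+
+            lstm = nn.LSTM(input_size=300, hidden_size=512, batch_first=True)
+            encoder = DynamicRNN(lstm)
+            outputs, (h_n, c_n) = encoder(padded_captions, caption_lens)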
+ """ + max_sequence_length = seq_input.size(1) + sorted_len, fwd_order, bwd_order = self._get_sorted_order(seq_lens) + sorted_seq_input = seq_input.index_select(0, fwd_order) + packed_seq_input = pack_padded_sequence( + sorted_seq_input, lengths=sorted_len, batch_first=True + ) + + if initial_state is not None: + hx = initial_state + assert hx[0].size(0) == self.rnn_model.num_layers + else: + sorted_hx = None + + self.rnn_model.flatten_parameters() + + outputs, (h_n, c_n) = self.rnn_model(packed_seq_input, sorted_hx) + + # pick hidden and cell states of last layer + h_n = h_n[-1].index_select(dim=0, index=bwd_order) + c_n = c_n[-1].index_select(dim=0, index=bwd_order) + + outputs = pad_packed_sequence( + outputs, batch_first=True, total_length=max_sequence_length + )[0].index_select(dim=0, index=bwd_order) + + return outputs, (h_n, c_n) + + @staticmethod + def _get_sorted_order(lens): + sorted_len, fwd_order = torch.sort( + lens.contiguous().view(-1), 0, descending=True + ) + _, bwd_order = torch.sort(fwd_order) + sorted_len = list(sorted_len) + return sorted_len, fwd_order, bwd_order + +class FeatureAggregation(nn.Module): + def __init__(self, module_dim=512): + super(FeatureAggregation, self).__init__() + self.module_dim = module_dim + + self.q_proj = nn.Linear(module_dim, module_dim, bias=False) + self.v_proj = nn.Linear(module_dim, module_dim, bias=False) + + self.cat = nn.Linear(2 * module_dim, module_dim) + self.attn = nn.Linear(module_dim, 1) + + self.activation = nn.ELU() + self.dropout = nn.Dropout(0.15) + + def forward(self, question_rep, visual_feat): + visual_feat = self.dropout(visual_feat) + q_proj = self.q_proj(question_rep) + v_proj = self.v_proj(visual_feat) + + v_q_cat = torch.cat((v_proj, q_proj.unsqueeze(1) * v_proj), dim=-1) + v_q_cat = self.cat(v_q_cat) + v_q_cat = self.activation(v_q_cat) + + attn = self.attn(v_q_cat) # (bz, k, 1) + attn = F.softmax(attn, dim=1) # (bz, k, 1) + + v_distill = (attn * visual_feat).sum(1) + + return v_distill + + +class InputUnitLinguistic(nn.Module): + def __init__(self, vocab_size, wordvec_dim=300, rnn_dim=512, module_dim=512, bidirectional=True): + super(InputUnitLinguistic, self).__init__() + + self.dim = module_dim + + self.bidirectional = bidirectional + if bidirectional: + rnn_dim = rnn_dim // 2 + + self.encoder_embed = nn.Embedding(vocab_size, wordvec_dim) + self.tanh = nn.Tanh() + self.encoder = nn.LSTM(wordvec_dim, rnn_dim, batch_first=True, bidirectional=bidirectional) + self.embedding_dropout = nn.Dropout(p=0.15) + self.question_dropout = nn.Dropout(p=0.18) + + self.module_dim = module_dim + + def forward(self, questions, question_len): + """ + Args: + question: [Tensor] (batch_size, max_question_length) + question_len: [Tensor] (batch_size) + return: + question representation [Tensor] (batch_size, module_dim) + """ + questions_embedding = self.encoder_embed(questions) # (batch_size, seq_len, dim_word) + embed = self.tanh(self.embedding_dropout(questions_embedding)) + + embed = nn.utils.rnn.pack_padded_sequence(embed, question_len.cpu(), batch_first=True, + enforce_sorted=False) + + self.encoder.flatten_parameters() + _, (question_embedding, _) = self.encoder(embed) + if self.bidirectional: + question_embedding = torch.cat([question_embedding[0], question_embedding[1]], -1) + question_embedding = self.question_dropout(question_embedding) + + return question_embedding + +class captionLinguistic(nn.Module): + def __init__(self, caption_dim, rnn_dim=512, module_dim=512, bidirectional=True): + super(captionLinguistic, 
self).__init__() + + self.dim = module_dim + + self.bidirectional = bidirectional + if bidirectional: + rnn_dim = rnn_dim // 2 + + self.encoder_embed = nn.Linear(caption_dim, rnn_dim) + self.tanh = nn.Tanh() + self.encoder = nn.LSTM(caption_dim, rnn_dim, batch_first=True, bidirectional=bidirectional) + # self.encoder = DynamicRNN(self.encoder) + self.embedding_dropout = nn.Dropout(p=0.15) + self.caption_dropout = nn.Dropout(p=0.18) + + self.module_dim = module_dim + + def forward(self, caption, caption_len): + """ + Args: + caption: [Tensor] (batch_size, max_question_length, cap_dim) + caption_len: [Tensor] (batch_size,) + return: + caption representation [Tensor] (batch_size, module_dim) + """ + bs, max_len, rnn_dim = caption.size() + #caption_embedding = self.encoder_embed(caption) # (batch_size, num_cap, max_question_length, rnn_dim) + caption_embedding = caption + embed = self.tanh(self.embedding_dropout(caption_embedding)) + + #for i in range(num_cap): + + embed_candi = nn.utils.rnn.pack_padded_sequence(embed, caption_len.cpu(), batch_first=True, + enforce_sorted=False) + self.encoder.flatten_parameters() + _, (caption_embedding, _) = self.encoder(embed_candi) + if self.bidirectional: + caption_embedding = torch.cat([caption_embedding[0], caption_embedding[1]], -1) + caption_embedding = self.caption_dropout(caption_embedding) + + return caption_embedding + + +class RetrieveNetwork(nn.Module): + def __init__(self, vision_dim, module_dim, word_dim, k_max_frame_level, k_max_clip_level, spl_resolution, vocab, question_type, caption_dim, cap_vocab): + super(RetrieveNetwork, self).__init__() + + self.question_type = question_type + self.feature_aggregation = FeatureAggregation(module_dim) + + if self.question_type in ['action', 'transition']: + encoder_vocab_size = len(vocab['question_answer_token_to_idx']) + self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, + module_dim=module_dim, rnn_dim=module_dim) + elif self.question_type == 'count': + encoder_vocab_size = len(vocab['question_token_to_idx']) + self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, + module_dim=module_dim, rnn_dim=module_dim) + else: + encoder_vocab_size = len(vocab['question_token_to_idx']) + self.num_classes = len(vocab['answer_token_to_idx']) + self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, + module_dim=module_dim, rnn_dim=module_dim) + ######################################## + if cap_vocab is not None: + self.word_embedding = nn.Embedding(len(cap_vocab), word_dim) + + self.caption_unit = captionLinguistic(caption_dim=caption_dim) + self.appearance_W = nn.Sequential( + nn.Linear(vision_dim,module_dim), + nn.Dropout(p=0.1), + nn.Tanh() + ) + self.appearance_ATT = nn.Sequential( + nn.Linear(module_dim,1), + nn.Softmax(dim=-1) + ) + self.motion_W = nn.Sequential( + nn.Linear(vision_dim,module_dim), + nn.Dropout(p=0.1), + nn.Tanh() + ) + self.motion_ATT = nn.Sequential( + nn.Linear(module_dim,1), + nn.Softmax(dim=-1) + ) + self.fusion = nn.Sequential( + nn.Linear(module_dim*4,module_dim*2), + nn.Dropout(p=0.2), + nn.Tanh() + ) + self.softmax = nn.Softmax(dim=-1) + ######################################## + init_modules(self.modules(), w_init="xavier_uniform") + if cap_vocab is not None: + nn.init.uniform_(self.word_embedding.weight, -1.0, 1.0) + + def forward(self, video_appearance_feat, video_motion_feat, caption, caption_len, question=None, question_len=None): + """ 
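+        Score each caption against each video for retrieval: appearance and motion
+        features are attention-pooled and fused, then compared to the LSTM caption
+        embeddings with cosine similarity; when a question is given, a question-to-
+        caption similarity is added to the returned score matrix.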
+ Args: + video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) + video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) + caption: [Tensor] (batch_size/num_cap num_word cap_dim) | [tgif] (batch_size/num_cap num_word) + caption_len: [Tensor] (batch_size/num_cap) | [tgif] (batch_size/num_cap) + question: [Tensor] (batch_size num_word) + question_len [Tensor] (batch_size) + return: + similarity matrix. + """ + + batch_size = video_appearance_feat.size(0) + _,_,num_frames,_ = video_appearance_feat.size() + # video -> caption + caption_embedding = caption + + if self.question_type != 'none': + + caption_embedding = self.word_embedding(caption) # batch_size num_word embedding + + caption_sentence = self.caption_unit(caption_embedding,caption_len) # batch_size module_dim + + appearance_feat = self.appearance_W(video_appearance_feat) # batch_size num_clips num_frames module_dim + + motion_feat = self.motion_W(video_motion_feat) # batch_size num_K module_dim + + #appearance_feat = torch.sum(motion_feat.unsqueeze(-2).repeat(1,1,num_frames,1)*appearance_feat,dim=-2) # batch_size num_K module_dim + appearance_feat = torch.mean(appearance_feat,dim=-2).squeeze() # batch_size num_clips module_dim + appearance_ATT = self.appearance_ATT(appearance_feat) # batch_size num_K 1 + motion_ATT = self.motion_ATT(motion_feat) # batch_size num_K 1 + + appearance = torch.bmm(appearance_ATT.permute(0,2,1),appearance_feat).squeeze(-2) # batch_size module_dim + motion = torch.bmm(motion_ATT.permute(0,2,1),motion_feat).squeeze(-2) # batch_size module_dim + + v = ((appearance + motion)/2) + inner_prod = v.mm(caption_sentence.t()) # video cap + + im_norm = torch.sqrt((v**2).sum(1).view(-1, 1) + 1e-18) + s_norm = torch.sqrt((caption_sentence**2).sum(1).view(1, -1) + 1e-18) + sim = inner_prod / (im_norm * s_norm) + # question -> caption + + if question is not None: + + question_embedding = self.linguistic_input_unit(question, question_len) # (batch_size, module_dim) + q_c_sim = question_embedding.mm(caption_sentence.t()) # ques cap + que_norm_ = torch.sqrt((q_c_sim**2).sum(1).view(-1, 1) + 1e-18) + q_c_sim = q_c_sim / (que_norm_ * s_norm) + + return sim + q_c_sim, caption_sentence + # return sim, caption_sentence + +class ContrastiveLoss(nn.Module): + '''compute contrastive loss + ''' + def __init__(self, margin=0, max_violation=False, direction='bi', topk=1): + '''Args: + direction: i2t for negative sentence, t2i for negative image, bi for both + ''' + super(ContrastiveLoss, self).__init__() + self.margin = margin + self.max_violation = max_violation + self.direction = direction + self.topk = topk + + def forward(self, scores, margin=None, average_batch=True): + ''' + Args: + scores: image-sentence score matrix, (batch, batch) + the same row of im and s are positive pairs, different rows are negative pairs + ''' + + if margin is None: + margin = self.margin + + batch_size = scores.size(0) + diagonal = scores.diag().view(batch_size, 1) # positive pairs + + # mask to clear diagonals which are positive pairs + pos_masks = torch.eye(batch_size).bool().to(scores.device) + + batch_topk = min(batch_size, self.topk) + if self.direction == 'i2t' or self.direction == 'bi': + d1 = diagonal.expand_as(scores) # same collumn for im2s (negative sentence) + # compare every diagonal score to scores in its collumn + # caption retrieval + cost_s = (margin + scores - d1).clamp(min=0) + cost_s = cost_s.masked_fill(pos_masks, 0) + if self.max_violation: + cost_s, _ = torch.topk(cost_s, batch_topk, 
dim=1) + cost_s = cost_s / batch_topk + if average_batch: + cost_s = cost_s / batch_size + else: + if average_batch: + cost_s = cost_s / (batch_size * (batch_size - 1)) + cost_s = torch.sum(cost_s) + + if self.direction == 't2i' or self.direction == 'bi': + d2 = diagonal.t().expand_as(scores) # same row for s2im (negative image) + # compare every diagonal score to scores in its row + cost_im = (margin + scores - d2).clamp(min=0) + cost_im = cost_im.masked_fill(pos_masks, 0) + if self.max_violation: + cost_im, _ = torch.topk(cost_im, batch_topk, dim=0) + cost_im = cost_im / batch_topk + if average_batch: + cost_im = cost_im / batch_size + else: + if average_batch: + cost_im = cost_im / (batch_size * (batch_size - 1)) + cost_im = torch.sum(cost_im) + + if self.direction == 'i2t': + return cost_s + elif self.direction == 't2i': + return cost_im + else: + return cost_s + cost_im \ No newline at end of file diff --git a/model/utils.py b/model/utils.py new file mode 100644 index 0000000..02e343f --- /dev/null +++ b/model/utils.py @@ -0,0 +1,31 @@ +from torch.nn import init +import torch +import torch.nn as nn + + +def init_modules(modules, w_init='kaiming_uniform'): + if w_init == "normal": + _init = init.normal_ + elif w_init == "xavier_normal": + _init = init.xavier_normal_ + elif w_init == "xavier_uniform": + _init = init.xavier_uniform_ + elif w_init == "kaiming_normal": + _init = init.kaiming_normal_ + elif w_init == "kaiming_uniform": + _init = init.kaiming_uniform_ + elif w_init == "orthogonal": + _init = init.orthogonal_ + else: + raise NotImplementedError + for m in modules: + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): + _init(m.weight) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + if isinstance(m, (nn.LSTM, nn.GRU)): + for name, param in m.named_parameters(): + if 'bias' in name: + nn.init.zeros_(param) + elif 'weight' in name: + _init(param) diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..5d4f363 --- /dev/null +++ b/readme.md @@ -0,0 +1,117 @@ +## Video Question Answering with Prior Knowledge and Object-sensitive Learning (PKOL) + +[![](https://img.shields.io/badge/python-3.7.11-orange.svg?style=for-the-badge)](https://www.python.org/) [![](https://img.shields.io/apm/l/vim-mode.svg?style=for-the-badge)](https://github.com/zchoi/S2-Transformer/blob/main/LICENSE) [![](https://img.shields.io/badge/Pytorch-1.7.0-orange?style=for-the-badge)](https://pytorch.org/) + +This is the official code implementation for the paper: + +[Video Question Answering with Prior Knowledge and Object-sensitive Learning]() + +

+  *Figure: Relationship-Sensitive Transformer (model overview image).*
+

+ + +## Table of Contents + +- [Setups](#Setups) +- [Data Preparation](#data-preparation) +- [Training](#training) +- [Evaluation](#evaluation) +- [Reference and Citation](#reference-and-citation) +- [Acknowledgements](#acknowledgements) + +## Setups + +- **Ubuntu** 20.04 +- **CUDA** 11.5 +- **Python** 3.7 +- **PyTorch** 1.7.0 + cu110 + +1. Clone this repository: + +``` +git clone https://github.com/zchoi/PKOL.git +``` + +2. Install dependencies: + +``` +conda create -n vqa python=3.7 +conda activate vqa +pip install -r requirements.txt +``` +## Data Preparation + +- ### Text Features + + Download pre-extracted text features from [here](), and place it into `data/{dataset}-qa/` for MSVD-QA, MSRVTT-QA and `data/tgif-qa/{question_type}/` for TGIF-QA, respectively. + +- ### Visual Features + + Download pre-extracted visual features (i.e., appearance, motion, object) from [here](), and place it into `data/{dataset}-qa/` for MSVD-QA, MSRVTT-QA and `data/tgif-qa/{question_type}/` for TGIF-QA, respectively. + +> **Note:** The object features are huge, (especially ~700GB for TGIF-QA), please be cautious of disk space when downloading. + +## Experiments + +### For MSVD-QA and MSRVTT-QA: + +Training: + +``` +python train_iterative.py --cfg configs/msvd_qa.yml +``` +Evaluation: + +``` +python validate_iterative.py --cfg configs/msvd_qa.yml +``` +### For TGIF-QA: + + Choose a suitable config file in `configs/{task}.yml` for one of 4 tasks: `action, transition, count, frameqa` to train/val the model. For example, to train with action task, run the following command: + +Training: + +``` +python train_iterative.py --cfg configs/tgif_qa_action.yml +``` + +Evaluation: + +``` +python validate_iterative.py --cfg configs/tgif_qa_action.yml +``` +## Results + +Performance on MSVD-QA and MSRVTT-QA datasets: + +| Model | MSVD-QA | MSRVTT-QA | +|:---------- |:-------: |:-: | +| PKOL | 41.1 | 36.9 | + +Performance on TGIF-QA dataset: + +| Model | Count ↓ | FrameQA ↑ | Trans. ↑ | Action ↑ | +| :---- | :-----: | :-------: | :------: | :------: | +| PKOL | 3.67 | 61.8 | 82.8 | 74.6 | + +## Reference +[1] Le, Thao Minh, et al. "Hierarchical conditional relation networks for video question answering." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2020. + +## Citation +``` +@inproceedings{PKOL, + author = {Pengpeng Zeng and + Haonan Zhang and + Lianli Gao and + Jingkuan Song and + Heng Tao Shen + }, + title = {Video Question Answering with Prior Knowledge and Object-sensitive Learning}, + booktitle = {TIP}, + % pages = {????--????} + year = {2022} +} +``` +## Acknowledgements +Our code implementation is based on this [repo](https://github.com/thaolmk54/hcrn-videoqa). 
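+
+> **Note:** The snippet below is only an illustrative sketch of how a config file is consumed; the field names mirror the `cfg.*` attributes read in `train_iterative.py`, and the shipped `configs/*.yml` files remain the authoritative reference (the printed values are assumptions, not guaranteed defaults).
+
+```
+from config import cfg, cfg_from_file
+
+# merge a task config into the global cfg (the same call train_iterative.py makes)
+cfg_from_file('configs/msvd_qa.yml')
+
+print(cfg.dataset.name)                        # e.g. 'msvd-qa'
+print(cfg.dataset.question_type)               # 'none' for MSVD-QA / MSRVTT-QA
+print(cfg.train.batch_size, cfg.train.lr)      # optimization hyper-parameters
+print(cfg.train.topk, cfg.train.patch_number)  # top-k retrieved captions / caption-pool chunk size
+```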
\ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5fc24e3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,87 @@ +absl-py==0.12.0 +addict==2.4.0 +albumentations==0.5.2 +antlr4-python3-runtime==4.8 +appdirs==1.4.4 +cachetools==4.2.2 +certifi==2020.12.5 +cffi==1.14.5 +chardet==4.0.0 +cityscapesScripts==2.2.0 +coloredlogs==15.0 +cycler==0.10.0 +Cython==0.29.22 +dataclasses==0.6 +decorator==4.4.2 +easydict==1.9 +fairseq==0.10.2 +future==0.18.2 +google-auth==1.30.0 +google-auth-oauthlib==0.4.4 +grpcio==1.37.1 +h5py==3.7.0 +humanfriendly==9.1 +Hydra==2.5 +hydra-core==1.0.6 +idna==2.10 +imagecorruptions==1.1.2 +imageio==2.9.0 +imgaug==0.4.0 +importlib-metadata==4.0.1 +importlib-resources==5.1.2 +install==1.3.4 +joblib==1.0.1 +kiwisolver==1.3.1 +Markdown==3.3.4 +matplotlib==3.4.1 +mmcv-full==1.2.5 +mmlvis==10.5.3 +mmpycocotools==12.0.3 +networkx==2.5.1 +numpy==1.20.2 +oauthlib==3.1.0 +omegaconf==2.0.6 +opencv-python==4.5.1.48 +opencv-python-headless==4.5.1.48 +pandas==1.1.5 +Pillow==8.2.0 +portalocker==2.0.0 +protobuf==3.16.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycparser==2.20 +pyparsing==2.4.7 +pyquaternion==0.9.9 +python-dateutil==2.8.1 +PyWavelets==1.1.1 +PyYAML==5.4.1 +regex==2021.4.4 +requests==2.25.1 +requests-oauthlib==1.3.0 +rsa==4.7.2 +sacrebleu==1.5.1 +scikit-image==0.18.1 +scikit-learn==0.24.1 +scipy==1.6.2 +setproctitle==1.2.2 +Shapely==1.7.1 +six==1.15.0 +sklearn==0.0 +tensorboard==2.5.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +termcolor==1.1.0 +terminaltables==3.1.0 +threadpoolctl==2.1.0 +tifffile==2021.3.31 +torch==1.7.0+cu110 +torch-tb-profiler==0.1.0 +torchaudio==0.7.0 +torchvision==0.8.0 +tqdm==4.60.0 +typing==3.7.4.3 +typing-extensions==3.7.4.3 +urllib3==1.26.4 +Werkzeug==1.0.1 +yapf==0.31.0 +zipp==3.4.1 diff --git a/train_iterative.py b/train_iterative.py new file mode 100644 index 0000000..97dce93 --- /dev/null +++ b/train_iterative.py @@ -0,0 +1,540 @@ +import json +import os, sys +import torch +import torch.optim as optim +import torch.nn as nn +import numpy as np +import argparse +import time +import logging +import model.PKOL as PKOL + + +logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') +logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') +rootLogger = logging.getLogger() + + +from Dataloder_iterative import VideoQADataLoader +from utils import todevice +from validate_iterative import validate +from model.retrieve_model import RetrieveNetwork +from utils import todevice +from termcolor import colored +from config import cfg, cfg_from_file + + +def margin_ranking_loss( + similary_matrix, + margin=None, + direction= 'both', + average_batch = True, + video_name = None, + video_idx2_cap_gt=None + ): + + batch_size = similary_matrix.size(0) + diagonal = similary_matrix.diag().view(batch_size, 1) + + pos_mask = torch.eye(batch_size,batch_size,device=similary_matrix.device).bool() + + # v2c + if direction == 'both' or direction == 'v2c': + diagonal_1 = diagonal.expand_as(similary_matrix) + + cost_cap = (margin + similary_matrix - diagonal_1).clamp(min=0) + cost_cap = cost_cap.masked_fill(pos_mask, 0) + if average_batch: + cost_cap = cost_cap / (batch_size * (batch_size - 1)) + cost_cap = torch.sum(cost_cap) + + # c2v + if direction == 'both' or direction == 'c2v': + diagonal_2 = diagonal.t().expand_as(similary_matrix) + cost_vid = (margin + similary_matrix - diagonal_2).clamp(min=0) + cost_vid = cost_vid.masked_fill(pos_mask,0) + 
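        # c2v hinge: cost_vid[i, j] = max(0, margin + sim[i, j] - sim[j, j]), i.e. every video i
+        # is compared against caption j's paired video; the diagonal positives were zeroed above.
+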
if average_batch: + cost_vid = cost_vid / (batch_size * (batch_size - 1)) + cost_vid = torch.sum(cost_vid) + + if direction == 'both': + return cost_cap + cost_vid + elif direction == 'v2c': + return cost_cap + else: + return cost_vid + +def train(cfg): + logging.info("Create train_loader and val_loader.........") + train_loader_kwargs = { + 'split' : 'train', + 'name' : cfg.dataset.name, + 'caption_max_num' : cfg.dataset.max_cap_num, + 'question_type': cfg.dataset.question_type, + 'question_pt': cfg.dataset.train_question_pt, + 'vocab_json': cfg.dataset.vocab_json, + 'appearance_feat': cfg.dataset.appearance_feat, + 'motion_feat': cfg.dataset.motion_feat, + 'object_feat' : cfg.dataset.object_feat, + 'train_num': cfg.train.train_num, + 'batch_size': cfg.train.batch_size, + 'num_workers': cfg.num_workers, + 'shuffle': True, + 'pin_memory': True + + } + train_loader = VideoQADataLoader(**train_loader_kwargs) + + logging.info("number of train instances: {}".format(len(train_loader.dataset))) + if cfg.val.flag: + val_loader_kwargs = { + 'split' : 'val', + 'name' : cfg.dataset.name, + 'caption_max_num' : cfg.dataset.max_cap_num, + 'question_type': cfg.dataset.question_type, + 'question_pt': cfg.dataset.val_question_pt, + 'vocab_json': cfg.dataset.vocab_json, + 'appearance_feat': cfg.dataset.appearance_feat, + 'motion_feat': cfg.dataset.motion_feat, + 'object_feat' : cfg.dataset.object_feat, + 'val_num': cfg.val.val_num, + 'batch_size': cfg.train.batch_size, + 'num_workers': cfg.num_workers, + 'shuffle': False, + 'pin_memory': True + + } + val_loader = VideoQADataLoader(**val_loader_kwargs) + + logging.info("number of val instances: {}".format(len(val_loader.dataset))) + + logging.info("Create model.........") + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + with open('data/tgif-qa/tgif-caption/tgif_cap_index.json','r') as f: + f = json.load(f) + tgif_cap = f if cfg.dataset.question_type != 'none' else None + + model_kwargs = { + 'vision_dim': cfg.train.vision_dim, + 'module_dim': cfg.train.module_dim, + 'word_dim': cfg.train.word_dim, + 'k_max_frame_level': cfg.train.k_max_frame_level, + 'k_max_clip_level': cfg.train.k_max_clip_level, + 'spl_resolution': cfg.train.spl_resolution, + 'vocab': train_loader.vocab, + 'question_type': cfg.dataset.question_type, + 'caption_dim' : cfg.train.caption_dim, + 'topk' : cfg.train.topk, + 'corpus' : None, + 'corpus_len' : None, + 'patch_number' : cfg.train.patch_number, + 'cap_vocab' : tgif_cap if cfg.dataset.question_type != 'none' else None, + 'visualization' : False + } + model_kwargs_tosave = {k: v for k, v in model_kwargs.items() if k != 'vocab'} + model = PKOL.PKOL_Net(**model_kwargs).to(device) + + retrieve_model_kwargs = { + 'vision_dim': cfg.train.vision_dim, + 'module_dim': cfg.train.module_dim, + 'word_dim': cfg.train.word_dim, + 'k_max_frame_level': cfg.train.k_max_frame_level, + 'k_max_clip_level': cfg.train.k_max_clip_level, + 'spl_resolution': cfg.train.spl_resolution, + 'vocab': train_loader.vocab, + 'question_type': cfg.dataset.question_type, + 'caption_dim' : cfg.train.caption_dim, + 'cap_vocab' : tgif_cap if cfg.dataset.question_type != 'none' else None + } + model_retrieval_kwargs_tosave = {k: v for k, v in retrieve_model_kwargs.items() if k != 'vocab'} + retrieve_model = RetrieveNetwork(**retrieve_model_kwargs).to(device) + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + pytorch_total_params_R = sum(p.numel() for p in retrieve_model.parameters() if p.requires_grad) + + 
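    # parameter count covers both the QA network (PKOL_Net) and the caption-retrieval
+    # network; a single Adam optimizer built below updates both sets of parameters jointly.
+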
logging.info('top-k trained model: {}'.format(cfg.train.topk)) + logging.info('num of params: {}'.format(pytorch_total_params + pytorch_total_params_R)) + #logging.info(model) + + if cfg.train.glove: + logging.info('load glove vectors') + train_loader.glove_matrix = torch.FloatTensor(train_loader.glove_matrix).to(device) + with torch.no_grad(): + model.linguistic_input_unit.encoder_embed.weight.set_(train_loader.glove_matrix) + retrieve_model.linguistic_input_unit.encoder_embed.weight.set_(train_loader.glove_matrix) + if cfg.dataset.question_type != 'none': + retrieve_model.word_embedding.data = torch.from_numpy(np.load('data/tgif-qa/tgif-caption/glove.npy')).to(device) + + if torch.cuda.device_count() > 1 and cfg.multi_gpus: + model = model.cuda() + logging.info("Using {} GPUs".format(torch.cuda.device_count())) + model = nn.DataParallel(model, device_ids=None) + + optimizer = optim.Adam([{'params': model.parameters()},{'params': retrieve_model.parameters()}], cfg.train.lr) + + start_epoch = 0 + + if cfg.dataset.question_type == 'count': + best_val = 100.0 + else: + best_val = 0 + best_retrieval = 0 + if cfg.train.restore: + print("Restore checkpoint and optimizer...") + ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt') + ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage) + start_epoch = ckpt['epoch'] + 1 + model.load_state_dict(ckpt['state_dict']) + optimizer.load_state_dict(ckpt['optimizer']) + + ckpt_retrieval = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model_retrieval.pt') + retrieve_model.load_state_dict(ckpt_retrieval['state_dict']) + + if cfg.dataset.question_type in ['frameqa', 'none']: + criterion = nn.CrossEntropyLoss().to(device) + elif cfg.dataset.question_type == 'count': + criterion = nn.MSELoss().to(device) + + logging.info("Start training........") + for epoch in range(start_epoch, cfg.train.max_epochs): + logging.info('>>>>>> epoch {} <<<<<<'.format(epoch)) + if epoch < 0: + retrieve_model.train() + count = 0 + batch_mse_sum = 0.0 + total_loss, avg_loss = 0.0, 0.0 + avg_loss = 0 + + for i, batch in enumerate(iter(train_loader)): + progress = epoch + i / len(train_loader) + video_idx, question_idx, answers, ans_candidates, ans_candidates_len, appearance_feat, motion_feat, _, question,\ + question_len, caption, caption_len = [todevice(x, device) for x in batch] + batch_size = appearance_feat.size(0) + optimizer.zero_grad() + sim_matrix, _ = retrieve_model(appearance_feat, motion_feat, caption, caption_len, question, question_len) + + loss = margin_ranking_loss(sim_matrix,0.2) + loss.backward() + total_loss += loss.detach() + avg_loss = total_loss / (i + 1) + nn.utils.clip_grad_norm_(retrieve_model.parameters(), max_norm=12) + optimizer.step() + + sys.stdout.write( + "\rProgress = {progress} avg_loss = {avg_loss} exp: {exp_name}".format( + progress=colored("{:.3f}".format(progress), "green", attrs=['bold']), + avg_loss=colored("{:.4f}".format(avg_loss), "red", attrs=['bold']), + exp_name=cfg.exp_name)) + + sys.stdout.flush() + sys.stdout.write("\n") + + if (epoch + 1) % 10 == 0: + optimizer = step_decay(cfg, optimizer) + sys.stdout.flush() + torch.cuda.empty_cache() + else: + model.train() + retrieve_model.train() + total_acc, count = 0, 0 + batch_mse_sum = 0.0 + total_loss, avg_loss = 0.0, 0.0 + avg_loss = 0 + train_accuracy = 0 + for i, batch in enumerate(iter(train_loader)): + progress = epoch + i / len(train_loader) + video_idx, question_idx, answers, ans_candidates, ans_candidates_len, appearance_feat, motion_feat, object_feat, question,\ + 
question_len, caption, caption_len = [todevice(x, device) for x in batch] + + answers = answers.cuda().squeeze() + batch_size = answers.size(0) + optimizer.zero_grad() + + sim, _ = retrieve_model(appearance_feat, motion_feat, caption, caption_len, question, question_len) + + with torch.no_grad(): + sim_list = [] + cap_list = [] + patch_num = cfg.train.patch_number # 40000 -msrvtt 35000 -msvd + chunk = train_loader.dataset.caption_pool.size(0) // patch_num #1 + left = train_loader.dataset.caption_pool.size(0) % patch_num #22239 + j = 0 + for j in range(chunk): + cap = train_loader.dataset.caption_pool[j*patch_num:(j+1)*patch_num].to(appearance_feat.device) + cap_len = train_loader.dataset.caption_pool_len[j*patch_num:(j+1)*patch_num].to(appearance_feat.device) + similiry_j, caption_tensor_j = retrieve_model( # batch_size patch_num / patch_num module_dim + appearance_feat, + motion_feat, + cap, + cap_len, + question, + question_len + ) + sim_list.append(similiry_j) + cap_list.append(caption_tensor_j) + + j = j+1 if chunk else j + if left: + cap = train_loader.dataset.caption_pool[j*patch_num:].to(appearance_feat.device) + cap_len = train_loader.dataset.caption_pool_len[j*patch_num:].to(appearance_feat.device) + similiry_j, caption_tensor_j = retrieve_model( # batch_size left / left module_dim + appearance_feat, + motion_feat, + cap, + cap_len, + question, + question_len) + + sim_list.append(similiry_j) + cap_list.append(caption_tensor_j) + similiry_matrix = torch.cat(sim_list, dim=-1) + caption_tensor = torch.cat(cap_list, dim=0) + + logits = model(ans_candidates, ans_candidates_len, appearance_feat, motion_feat, object_feat, question, + question_len, similarity=similiry_matrix, corpus=caption_tensor) # batch_size batch_size + + if cfg.dataset.question_type in ['action', 'transition']: + batch_agg = np.concatenate(np.tile(np.arange(batch_size).reshape([batch_size, 1]), + [1, 5])) * 5 # [0, 0, 0, 0, 0, 5, 5, 5, 5, 1, ...] 
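+                    # batch_agg is [0,0,0,0,0, 5,5,5,5,5, 10, ...]: it offsets each question's
+                    # ground-truth answer index into the flattened (batch_size * 5) logits, so the
+                    # hinge loss below compares every candidate logit with its question's correct one.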
+ answers_agg = tile(answers, 0, 5) + loss = torch.max(torch.tensor(0.0).cuda(), + 1.0 + logits - logits[answers_agg + torch.from_numpy(batch_agg).cuda()]) + loss = loss.sum() + if cfg.train.joint: + r_loss = margin_ranking_loss( + similary_matrix = sim, + margin=0.2 + ) + loss += r_loss + loss.backward() + total_loss += loss.detach() + avg_loss = total_loss / (i + 1) + nn.utils.clip_grad_norm_(model.parameters(), max_norm=12) + optimizer.step() + preds = torch.argmax(logits.view(batch_size, 5), dim=1) + aggreeings = (preds == answers) + elif cfg.dataset.question_type == 'count': + answers = answers.unsqueeze(-1) + loss = criterion(logits, answers.float()) + if cfg.train.joint: + r_loss = margin_ranking_loss( + similary_matrix = sim, + margin=0.2 + ) + loss += r_loss + loss.backward() + total_loss += loss.detach() + avg_loss = total_loss / (i + 1) + nn.utils.clip_grad_norm_(model.parameters(), max_norm=12) + optimizer.step() + preds = (logits + 0.5).long().clamp(min=1, max=10) + batch_mse = (preds - answers) ** 2 + else: + loss = criterion(logits, answers) + if cfg.train.joint: + r_loss = margin_ranking_loss( + similary_matrix = sim, + margin=0.2 + ) + loss += r_loss + loss.backward() + total_loss += loss.detach() + avg_loss = total_loss / (i + 1) + nn.utils.clip_grad_norm_(model.parameters(), max_norm=12) + optimizer.step() + aggreeings = batch_accuracy(logits, answers) + + if cfg.dataset.question_type == 'count': + batch_avg_mse = batch_mse.sum().item() / answers.size(0) + batch_mse_sum += batch_mse.sum().item() + count += answers.size(0) + avg_mse = batch_mse_sum / count + sys.stdout.write( + "\rProgress = {progress} ce_loss = {ce_loss} avg_loss = {avg_loss} train_mse = {train_mse} avg_mse = {avg_mse} exp: {exp_name}".format( + progress=colored("{:.3f}".format(progress), "green", attrs=['bold']), + ce_loss=colored("{:.4f}".format(loss.item()), "blue", attrs=['bold']), + avg_loss=colored("{:.4f}".format(avg_loss), "red", attrs=['bold']), + train_mse=colored("{:.4f}".format(batch_avg_mse), "blue", + attrs=['bold']), + avg_mse=colored("{:.4f}".format(avg_mse), "red", attrs=['bold']), + exp_name=cfg.exp_name)) + sys.stdout.flush() + else: + total_acc += aggreeings.sum().item() + count += answers.size(0) + train_accuracy = total_acc / count + if not cfg.train.joint: + sys.stdout.write( + "\rProgress = {progress} ce_loss = {ce_loss} avg_loss = {avg_loss} train_acc = {train_acc} avg_acc = {avg_acc} exp: {exp_name}".format( + progress=colored("{:.3f}".format(progress), "green", attrs=['bold']), + ce_loss=colored("{:.4f}".format(loss.item()), "blue", attrs=['bold']), + avg_loss=colored("{:.4f}".format(avg_loss), "red", attrs=['bold']), + train_acc=colored("{:.4f}".format(aggreeings.float().mean().cpu().numpy()), "blue", + attrs=['bold']), + avg_acc=colored("{:.4f}".format(train_accuracy), "red", attrs=['bold']), + exp_name=cfg.exp_name)) + else: + sys.stdout.write( + "\rProgress = {progress} ce_loss = {ce_loss} re_loss = {re_loss} avg_loss = {avg_loss} train_acc = {train_acc} avg_acc = {avg_acc} exp: {exp_name}".format( + progress=colored("{:.3f}".format(progress), "green", attrs=['bold']), + ce_loss=colored("{:.4f}".format(loss.item()), "blue", attrs=['bold']), + re_loss=colored("{:.4f}".format(r_loss.item()), "blue", attrs=['bold']), + avg_loss=colored("{:.4f}".format(avg_loss), "red", attrs=['bold']), + train_acc=colored("{:.4f}".format(aggreeings.float().mean().cpu().numpy()), "blue", + attrs=['bold']), + avg_acc=colored("{:.4f}".format(train_accuracy), "red", attrs=['bold']), + 
exp_name=cfg.exp_name)) + sys.stdout.flush() + + sys.stdout.write("\n") + if cfg.dataset.question_type == 'count': + if (epoch + 1) % 5 == 0: + optimizer = step_decay(cfg, optimizer) + else: + if (epoch + 1) % 10 == 0: + optimizer = step_decay(cfg, optimizer) + sys.stdout.flush() + + logging.info("Epoch = %s avg_loss = %.3f avg_acc = %.3f" % (epoch, avg_loss, train_accuracy)) + if cfg.val.flag: + output_dir = os.path.join(cfg.dataset.save_dir, 'preds') + if not os.path.exists(output_dir): + os.makedirs(output_dir) + else: + assert os.path.isdir(output_dir) + valid_acc, _, _, r10 = validate(cfg, model, retrieve_model, val_loader, device, write_preds=False) + if (valid_acc > best_val and cfg.dataset.question_type != 'count') or (valid_acc < best_val and cfg.dataset.question_type == 'count'): + best_val = valid_acc + # Save best model + ckpt_dir = os.path.join(cfg.dataset.save_dir, 'ckpt') + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) + else: + assert os.path.isdir(ckpt_dir) + save_checkpoint(epoch, model, optimizer, model_kwargs_tosave, os.path.join(ckpt_dir, 'model.pt')) + save_checkpoint(epoch, retrieve_model, optimizer, model_retrieval_kwargs_tosave, os.path.join(ckpt_dir, 'model_retrieval.pt')) + sys.stdout.write('\n >>>>>> save to %s <<<<<< \n' % (ckpt_dir)) + sys.stdout.flush() + + logging.info('~~~~~~ Valid Accuracy: %.4f ~~~~~~~' % valid_acc) + sys.stdout.write('~~~~~~ Valid Accuracy: {valid_acc} ~~~~~~~\n'.format( + valid_acc=colored("{:.4f}".format(valid_acc), "red", attrs=['bold']))) + sys.stdout.flush() + +# Credit https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/4 +def tile(a, dim, n_tile): + init_dim = a.size(dim) + repeat_idx = [1] * a.dim() + repeat_idx[dim] = n_tile + a = a.repeat(*(repeat_idx)) + order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda() + return torch.index_select(a, dim, order_index) + + +def step_decay(cfg, optimizer): + # compute the new learning rate based on decay rate + cfg.train.lr *= 0.5 + logging.info("Reduced learning rate to {}".format(cfg.train.lr)) + sys.stdout.flush() + for param_group in optimizer.param_groups: + param_group['lr'] = cfg.train.lr + + return optimizer + + +def batch_accuracy(predicted, true): + """ Compute the accuracies for a batch of predictions and answers """ + predicted = predicted.detach().argmax(1) + agreeing = (predicted == true) + return agreeing + + +def save_checkpoint(epoch, model, optimizer, model_kwargs, filename): + state = { + 'epoch': epoch, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'model_kwargs': model_kwargs, + } + time.sleep(10) + torch.save(state, filename) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='msvd_qa.yml', type=str) + args = parser.parse_args() + if args.cfg_file is not None: + cfg_from_file(args.cfg_file) + + assert cfg.dataset.name in ['tgif-qa', 'msrvtt-qa', 'msvd-qa'] + assert cfg.dataset.question_type in ['frameqa', 'count', 'transition', 'action', 'none'] + # check if the data folder exists + assert os.path.exists(cfg.dataset.data_dir) + # check if k_max is set correctly + assert cfg.train.k_max_frame_level <= 16 + assert cfg.train.k_max_clip_level <= 8 + + + if not cfg.multi_gpus: + torch.cuda.set_device(cfg.gpu_id) + + cfg.dataset.save_dir = os.path.join(cfg.dataset.save_dir, cfg.exp_name) + if not os.path.exists(cfg.dataset.save_dir): + os.makedirs(cfg.dataset.save_dir) + else: + 
assert os.path.isdir(cfg.dataset.save_dir) + log_file = os.path.join(cfg.dataset.save_dir, "log") + + if not cfg.train.restore and not os.path.exists(log_file): + os.mkdir(log_file) + + fileHandler = logging.FileHandler(os.path.join(log_file, 'stdout.log'), 'w+') + fileHandler.setFormatter(logFormatter) + rootLogger.addHandler(fileHandler) + # args display + for k, v in vars(cfg).items(): + logging.info(k + ':' + str(v)) + # concat absolute path of input files + + if cfg.dataset.name == 'tgif-qa': + cfg.dataset.train_question_pt = os.path.join(cfg.dataset.data_dir, + cfg.dataset.train_question_pt.format(cfg.dataset.name, cfg.dataset.question_type)) + cfg.dataset.val_question_pt = os.path.join(cfg.dataset.data_dir, + cfg.dataset.val_question_pt.format(cfg.dataset.name, cfg.dataset.question_type)) + cfg.dataset.vocab_json = os.path.join(cfg.dataset.data_dir, cfg.dataset.vocab_json.format(cfg.dataset.name, cfg.dataset.question_type)) + + cfg.dataset.appearance_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.appearance_feat.format(cfg.dataset.name, cfg.dataset.question_type)) + cfg.dataset.motion_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name, cfg.dataset.question_type)) + cfg.dataset.object_feat = '/mnt/hdd2/zhanghaonan/object_features.h5' + else: + cfg.dataset.question_type = 'none' + cfg.dataset.appearance_feat = '{}_appearance_feat.h5' + cfg.dataset.motion_feat = '{}_motion_feat.h5' + cfg.dataset.object_feat = '{}_object_feat.h5' + cfg.dataset.vocab_json = '{}_vocab.json' + cfg.dataset.train_question_pt = '{}_train_questions.pt' + cfg.dataset.val_question_pt = '{}_val_questions.pt' + cfg.dataset.train_question_pt = os.path.join(cfg.dataset.data_dir, + cfg.dataset.train_question_pt.format(cfg.dataset.name)) + cfg.dataset.val_question_pt = os.path.join(cfg.dataset.data_dir, + cfg.dataset.val_question_pt.format(cfg.dataset.name)) + cfg.dataset.vocab_json = os.path.join(cfg.dataset.data_dir, cfg.dataset.vocab_json.format(cfg.dataset.name)) + + cfg.dataset.appearance_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.appearance_feat.format(cfg.dataset.name)) + cfg.dataset.motion_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name)) + cfg.dataset.object_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.object_feat.format(cfg.dataset.name)) + + # set random seed + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(cfg.seed) + + train(cfg) + + +if __name__ == '__main__': + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..3cc53b5 --- /dev/null +++ b/utils.py @@ -0,0 +1,8 @@ +import torch + +def todevice(tensor, device): + if isinstance(tensor, list) or isinstance(tensor, tuple): + assert isinstance(tensor[0], torch.Tensor) + return [todevice(t, device) for t in tensor] + elif isinstance(tensor, torch.Tensor): + return tensor.to(device) \ No newline at end of file diff --git a/validate_iterative.py b/validate_iterative.py new file mode 100644 index 0000000..b7021ee --- /dev/null +++ b/validate_iterative.py @@ -0,0 +1,440 @@ +from email.policy import strict +import os +# os.environ["CUDA_VISIBLE_DEVICES"] = "3" +import torch +import logging +import numpy as np +from tqdm import tqdm +import argparse +import sys +import json +import pickle +from termcolor import colored +from model.retrieve_model import RetrieveNetwork +from Dataloder_iterative import VideoQADataLoader +from utils import todevice + +import model.PKOL 
as PKOL + +from config import cfg, cfg_from_file +logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') +logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') +rootLogger = logging.getLogger() + +# msvd +# what 2 +# who 21 +# how 54 +# where 505 +# when 405 + +# msrvtt +# what 10 +# who 2 +# how 64 +# where 457 +# when 310 + +question_type_acc_msvd = {2:0,21:0,54:0,505:0,405:0} +question_type_total_length_msvd = {2:0,21:0,54:0,505:0,405:0} + +question_type_acc_msrvtt = {10:0,2:0,64:0,457:0,310:0} +question_type_total_length_msrvtt = {10:0,2:0,64:0,457:0,310:0} + +def validate(cfg, model, retrieve_model, data, device, write_preds=False, test = False): + model.eval() + retrieve_model.eval() + print('validating...') + total_acc, count = 0.0, 0 + all_preds = [] + gts = [] + v_ids = [] + q_ids = [] + video_idx2_cap_gt = data.dataset.video_idx2_cap_gt + + caption_pool = data.dataset.caption_pool + caption_pool_len = data.dataset.caption_pool_len + + model.topk = cfg.val.topk + + if test: + if cfg.dataset.name == 'msvd-qa': + with open('data/msvd-qa/ref_captions.json','r') as f: + raw_caption = json.load(f) + + logging.info('top-k validation model: {}'.format(cfg.val.topk)) + logging.info("num of retrieve captions: {}".format(caption_pool.size(0))) + logging.info('validation video length:{}'.format(len(video_idx2_cap_gt))) + # model.corpus = caption_pool + # model.corpus_len = caption_pool_len + video_name = [] + all_scores = [] + caption_visualization = {} + with torch.no_grad(): + for batch in tqdm(data, total=len(data)): + video_idx, question_idx, answers, ans_candidates, ans_candidates_len, appearance_feat,\ + motion_feat, object_feat, question, question_len = [todevice(x, device) for x in batch] + # caption_pool = caption_pool.to(device) + # caption_pool_len = caption_pool_len.to(device) + if cfg.train.batch_size == 1: + answers = answers.to(device) + else: + answers = answers.to(device).squeeze() + batch_size = motion_feat.size(0) + + with torch.no_grad(): + sim_list = [] + cap_list = [] + patch_num = cfg.train.patch_number # 40000 -msrvtt 35000 -msvd + chunk = data.dataset.caption_pool.size(0) // patch_num #1 + left = data.dataset.caption_pool.size(0) % patch_num #22239 + j = 0 + for j in range(chunk): + cap = data.dataset.caption_pool[j*patch_num:(j+1)*patch_num].to(appearance_feat.device) + cap_len = data.dataset.caption_pool_len[j*patch_num:(j+1)*patch_num].to(appearance_feat.device) + similiry_j, caption_tensor_j = retrieve_model( # batch_size patch_num / patch_num module_dim + appearance_feat, + motion_feat, + cap, + cap_len, + question, + question_len + ) + sim_list.append(similiry_j) + cap_list.append(caption_tensor_j) + + j = j+1 if chunk else j + if left: + cap = data.dataset.caption_pool[j*patch_num:].to(appearance_feat.device) + cap_len = data.dataset.caption_pool_len[j*patch_num:].to(appearance_feat.device) + similiry_j, caption_tensor_j = retrieve_model( # batch_size left / left module_dim + appearance_feat, + motion_feat, + cap, + cap_len, + question, + question_len + ) + + sim_list.append(similiry_j) + cap_list.append(caption_tensor_j) + sim = torch.cat(sim_list, dim=-1) + caption_tensor = torch.cat(cap_list, dim=0) + + # sim, caption_tensor = retrieve_model(appearance_feat, motion_feat, caption_pool, caption_pool_len, question, question_len) + if not test: + logits = model(ans_candidates, ans_candidates_len, appearance_feat, motion_feat, object_feat, question, + question_len, similarity=sim, corpus=caption_tensor) + 
else: + logits, index, caption_att = model(ans_candidates, ans_candidates_len, appearance_feat, motion_feat, object_feat, question, + question_len, similarity=sim, corpus=caption_tensor) + + all_scores.append(sim.data.cpu().numpy()) + video_name.extend(list(video_idx.data.cpu().numpy())) + + if cfg.dataset.question_type in ['action', 'transition']: + preds = torch.argmax(logits.view(batch_size, 5), dim=1) + agreeings = (preds == answers) + elif cfg.dataset.question_type == 'count': + answers = answers.unsqueeze(-1) + preds = (logits + 0.5).long().clamp(min=1, max=10) + batch_mse = (preds - answers) ** 2 + else: + preds = logits.detach().argmax(1) + agreeings = (preds == answers) + + if write_preds: + if cfg.dataset.question_type not in ['action', 'transition', 'count']: + preds = logits.argmax(1) + if cfg.dataset.question_type in ['action', 'transition']: + answer_vocab = data.vocab['question_answer_idx_to_token'] + else: + answer_vocab = data.vocab['answer_idx_to_token'] + for predict in preds: + if cfg.dataset.question_type in ['count', 'transition', 'action']: + all_preds.append(predict.item()) + else: + all_preds.append(answer_vocab[predict.item()]) + for gt in answers: + if cfg.dataset.question_type in ['count', 'transition', 'action']: + gts.append(gt.item()) + else: + gts.append(answer_vocab[gt.item()]) + for id in video_idx: + v_ids.append(id.cpu().numpy()) + for ques_id in question_idx: + q_ids.append(ques_id.cpu().numpy()) + + if cfg.dataset.question_type == 'count': + total_acc += batch_mse.float().sum().item() + count += answers.size(0) + else: + total_acc += agreeings.float().sum().item() + count += answers.size(0) + if cfg.dataset.name == 'msvd-qa': + for h in range(question.size(0)): + if agreeings[h]: + question_type_acc_msvd[question[h][0].item()] += 1 + question_type_total_length_msvd[question[h][0].item()] += 1 + elif cfg.dataset.name == 'msrvtt-qa': + for h in range(question.size(0)): + if agreeings[h]: + question_type_acc_msrvtt[question[h][0].item()] += 1 + question_type_total_length_msrvtt[question[h][0].item()] += 1 + if test: + vocab = data.vocab['question_idx_to_token'] + answer_vocab = data.vocab['answer_idx_to_token'] + dict = {} + with open(cfg.dataset.test_question_pt, 'rb') as f: + obj = pickle.load(f) + questions = obj['questions'] + org_v_ids = obj['video_ids'] + org_v_names = obj['video_names'] + org_q_ids = obj['question_id'] + + for idx in range(len(org_q_ids)): + dict[str(org_q_ids[idx])] = [org_v_names[idx], questions[idx]] + for k, qid in enumerate(question_idx): + if answer_vocab[answers[k].item()] != answer_vocab[preds[k].item()]: #or answer_vocab[answers[k].item()]=='man' or answer_vocab[answers[k].item()]=='woman': + continue + for n, topk_i in enumerate(index[k]): + # caption_visualization.setdefault(qid.item(),[]).append((data.dataset.visualization[topk_i], video_idx[k].item())) + # if video_idx[k].item() != data.dataset.visualization[topk_i][0]: + # continue + question = '' + for word in dict[str(qid.item())][1]: + if word != 0: + question += vocab[word.item()] + ' ' + + caption_visualization.setdefault(qid.item(),[]).append( + { 'caption': raw_caption['video'+str(video_idx[k].item())+'.mp4'][0], + 'video_id': video_idx[k].item(), + 'retrieval_vid': data.dataset.visualization[topk_i][0], + 'top-'+str(n)+'retrieval_cap': raw_caption['video'+str(data.dataset.visualization[topk_i][0])+'.mp4'][data.dataset.visualization[topk_i][1]], + 'question': question, + 'answer': answer_vocab[answers[k].item()], + 'prediction': answer_vocab[preds[k].item()] + 
}) + ############################################# + all_scores = np.concatenate(all_scores, axis= 0) # all_v all_c + + n_q, n_m = all_scores.shape + gt_ranks = np.zeros((n_q, ), np.int32) + + for i in range(n_q): + s = all_scores[i] + sorted_idxs = np.argsort(-s) + + rank = n_m + for k in video_idx2_cap_gt[str(video_name[i])]: + tmp = np.where(sorted_idxs == k)[0][0] + if tmp < rank: + rank = tmp + gt_ranks[i] = rank + + r1 = 100 * len(np.where(gt_ranks < 1)[0]) / n_q + r5 = 100 * len(np.where(gt_ranks < 5)[0]) / n_q + r10 = 100 * len(np.where(gt_ranks < 10)[0]) / n_q + # r1, r5, r10 = 0,0,0 + + logging.info("r1: {:.4f} r5: {:.4f} r10: {:.4f}".format(r1,r5,r10)) + ############################################# + + acc = total_acc / count + if test: + output_dir = os.path.join(cfg.dataset.save_dir,'visualization') + if not os.path.exists(output_dir): + os.makedirs(output_dir) + else: + assert os.path.isdir(output_dir) + preds_file = os.path.join(output_dir, "visualization.json") + with open(preds_file,'w') as f: + json.dump(caption_visualization, f) +# what 2 +# who 21 +# how 54 +# where 505 +# when 405 +# question_type_acc = {2:0,21:0,54:0,505:0,405:0} + if cfg.dataset.name == 'msvd-qa': + logging.info("What:{:.4f} Who: {:.4f} How: {:.4f} When: {:.4f} Where: {:.4f}".format( + question_type_acc_msvd[2]/question_type_total_length_msvd[2], + question_type_acc_msvd[21]/question_type_total_length_msvd[21], + question_type_acc_msvd[54]/question_type_total_length_msvd[54], + question_type_acc_msvd[405]/question_type_total_length_msvd[405], + question_type_acc_msvd[505]/question_type_total_length_msvd[505] + )) + elif cfg.dataset.name == 'msrvtt-qa': + logging.info("What:{:.4f} Who: {:.4f} How: {:.4f} When: {:.4f} Where: {:.4f}".format( + question_type_acc_msrvtt[10]/question_type_total_length_msrvtt[10], + question_type_acc_msrvtt[2]/question_type_total_length_msrvtt[2], + question_type_acc_msrvtt[64]/question_type_total_length_msrvtt[64], + question_type_acc_msrvtt[310]/question_type_total_length_msrvtt[310], + question_type_acc_msrvtt[457]/question_type_total_length_msrvtt[457] + )) + if not write_preds: + return acc, r1, r5, r10 + else: + return acc, all_preds, gts, v_ids, q_ids + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='configs/tgif_qa_action.yml', type=str) + args = parser.parse_args() + if args.cfg_file is not None: + cfg_from_file(args.cfg_file) + + assert cfg.dataset.name in ['tgif-qa', 'msrvtt-qa', 'msvd-qa'] + assert cfg.dataset.question_type in ['frameqa', 'count', 'transition', 'action', 'none'] + # check if the data folder exists + assert os.path.exists(cfg.dataset.data_dir) + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + # cfg.dataset.save_dir = os.path.join(cfg.dataset.save_dir, cfg.exp_name, cfg.dataset.pretrained) + cfg.dataset.save_dir = os.path.join(cfg.dataset.save_dir, cfg.exp_name) + ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt') + ckpt_retrieval = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model_retrieval.pt') + + assert os.path.exists(ckpt) and os.path.exists(ckpt_retrieval) + # load pretrained model + loaded = torch.load(ckpt, map_location='cpu') + model_kwargs = loaded['model_kwargs'] + + loaded_retrieval = torch.load(ckpt_retrieval, map_location='cpu') + model_kwargs_retrieval = loaded_retrieval['model_kwargs'] + + if cfg.dataset.name == 'tgif-qa': + cfg.dataset.test_question_pt = os.path.join(cfg.dataset.data_dir, + 
cfg.dataset.test_question_pt.format(cfg.dataset.name, cfg.dataset.question_type)) + cfg.dataset.vocab_json = os.path.join(cfg.dataset.data_dir, cfg.dataset.vocab_json.format(cfg.dataset.name, cfg.dataset.question_type)) + + cfg.dataset.appearance_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.appearance_feat.format(cfg.dataset.name, cfg.dataset.question_type)) + cfg.dataset.motion_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name, cfg.dataset.question_type)) + cfg.dataset.object_feat = '/mnt/hdd2/zhanghaonan/object_features.h5' + else: + cfg.dataset.question_type = 'none' + cfg.dataset.appearance_feat = '{}_appearance_feat.h5' + cfg.dataset.motion_feat = '{}_motion_feat.h5' + cfg.dataset.object_feat = '{}_object_feat.h5' + cfg.dataset.vocab_json = '{}_vocab.json' + cfg.dataset.test_question_pt = '{}_test_questions.pt' + + cfg.dataset.test_question_pt = os.path.join(cfg.dataset.data_dir, + cfg.dataset.test_question_pt.format(cfg.dataset.name)) + cfg.dataset.vocab_json = os.path.join(cfg.dataset.data_dir, cfg.dataset.vocab_json.format(cfg.dataset.name)) + + cfg.dataset.appearance_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.appearance_feat.format(cfg.dataset.name)) + cfg.dataset.motion_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name)) + cfg.dataset.object_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.object_feat.format(cfg.dataset.name)) + + test_loader_kwargs = { + 'split' : 'test', + 'name' : cfg.dataset.name, + 'caption_max_num' : cfg.dataset.max_cap_num, + 'question_type': cfg.dataset.question_type, + 'question_pt': cfg.dataset.test_question_pt, + 'vocab_json': cfg.dataset.vocab_json, + 'appearance_feat': cfg.dataset.appearance_feat, + 'motion_feat': cfg.dataset.motion_feat, + 'object_feat' : cfg.dataset.object_feat, + 'val_num': cfg.val.val_num, + 'batch_size': cfg.train.batch_size, + 'num_workers': cfg.num_workers, + 'shuffle': False + + } + + test_loader = VideoQADataLoader(**test_loader_kwargs) + model_kwargs.update({'vocab': test_loader.vocab}) + model_kwargs.update({'visualization': cfg.test.visualization}) + model = PKOL.PKOL_Net(**model_kwargs).to(device) + model.load_state_dict(loaded['state_dict'], strict=False) + + model_kwargs_retrieval.update({'vocab': test_loader.vocab}) + retrieve_model = RetrieveNetwork(**model_kwargs_retrieval).to(device) + retrieve_model.load_state_dict(loaded_retrieval['state_dict'], strict=False) + + if cfg.test.write_preds: + acc, preds, gts, v_ids, q_ids = validate(cfg, model, retrieve_model, test_loader, device, write_preds=True, test=cfg.test.visualization) + + sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format( + test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold']))) + sys.stdout.flush() + + # write predictions for visualization purposes + output_dir = os.path.join(cfg.dataset.save_dir, 'preds') + if not os.path.exists(output_dir): + os.makedirs(output_dir) + else: + assert os.path.isdir(output_dir) + preds_file = os.path.join(output_dir, "test_preds.json") + + if cfg.dataset.question_type in ['action', 'transition']: \ + # Find groundtruth questions and corresponding answer candidates + vocab = test_loader.vocab['question_answer_idx_to_token'] + dict = {} + with open(cfg.dataset.test_question_pt, 'rb') as f: + obj = pickle.load(f) + questions = obj['questions'] + org_v_ids = obj['video_idx'] + org_v_names = obj['video_names'] + org_q_ids = obj['question_id'] + ans_candidates = obj['ans_candidates'] + + for idx in 
range(len(org_q_ids)): + dict[str(org_q_ids[idx])] = [org_v_names[idx], questions[idx], ans_candidates[idx]] + instances = [ + {'video_id': video_id, 'question_id': q_id, 'video_name': dict[str(q_id)][0], 'question': [vocab[word.item()] for word in dict[str(q_id)][1] if word != 0], + 'answer': answer, + 'prediction': pred} for video_id, q_id, answer, pred in + zip(np.hstack(v_ids).tolist(), np.hstack(q_ids).tolist(), gts, preds)] + # write preditions to json file + with open(preds_file, 'w') as f: + json.dump(instances, f) + sys.stdout.write('Display 10 samples...\n') + # Display 10 samples + for idx in range(10): + print('Video name: {}'.format(dict[str(q_ids[idx].item())][0])) + cur_question = [vocab[word.item()] for word in dict[str(q_ids[idx].item())][1] if word != 0] + print('Question: ' + ' '.join(cur_question) + '?') + all_answer_cands = dict[str(q_ids[idx].item())][2] + for cand_id in range(len(all_answer_cands)): + cur_answer_cands = [vocab[word.item()] for word in all_answer_cands[cand_id] if word + != 0] + print('({}): '.format(cand_id) + ' '.join(cur_answer_cands)) + print('Prediction: {}'.format(preds[idx])) + print('Groundtruth: {}'.format(gts[idx])) + else: + vocab = test_loader.vocab['question_idx_to_token'] + dict = {} + with open(cfg.dataset.test_question_pt, 'rb') as f: + obj = pickle.load(f) + questions = obj['questions'] + org_v_ids = obj['video_ids'] + org_v_names = obj['video_names'] + org_q_ids = obj['question_id'] + + for idx in range(len(org_q_ids)): + dict[str(org_q_ids[idx])] = [org_v_names[idx], questions[idx]] + instances = [ + {'video_id': video_id, 'question_id': q_id, 'video_name': str(dict[str(q_id)][0]), 'question': [vocab[word.item()] for word in dict[str(q_id)][1] if word != 0], + 'answer': answer, + 'prediction': pred} for video_id, q_id, answer, pred in + zip(np.hstack(v_ids).tolist(), np.hstack(q_ids).tolist(), gts, preds)] + # write preditions to json file + with open(preds_file, 'w') as f: + json.dump(instances, f) + sys.stdout.write('Display 10 samples...\n') + # Display 10 examples + for idx in range(10): + print('Video name: {}'.format(dict[str(q_ids[idx].item())][0])) + cur_question = [vocab[word.item()] for word in dict[str(q_ids[idx].item())][1] if word != 0] + print('Question: ' + ' '.join(cur_question) + '?') + print('Prediction: {}'.format(preds[idx])) + print('Groundtruth: {}'.format(gts[idx])) + else: + acc, _, _, _ = validate(cfg, model, retrieve_model, test_loader, device, write_preds=False, test=cfg.test.visualization) + # acc = validate(cfg, model, test_loader, device, cfg.test.write_preds) + sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format( + test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold']))) + sys.stdout.flush()