Skip to content

Commit

Permalink
"fix: rename data-dir variable (RAD_dir → VQA_dir)"
Browse files Browse the repository at this point in the history
  • Loading branch information
xuanbinh-nguyen96 committed Jun 22, 2021
1 parent 45aa122 commit 59bf082
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 289 deletions.
22 changes: 11 additions & 11 deletions base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,22 @@ def build_BAN(dataset, args, priotize_using_counter=False):
if len(args.maml_nums) > 1:
maml_v_emb = []
for model_t in args.maml_nums:
weight_path = args.RAD_dir + '/' + 't%s_'%(model_t) + args.maml_model_path
weight_path = args.VQA_dir + '/' + 't%s_'%(model_t) + args.maml_model_path
print('load initial weights MAML from: %s' % (weight_path))
# maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn)
maml_v_emb_temp = MAML(args.RAD_dir)
maml_v_emb_temp = MAML(args.VQA_dir)
maml_v_emb_temp.load_state_dict(torch.load(weight_path))
maml_v_emb.append(maml_v_emb_temp)
else:
weight_path = args.RAD_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path
weight_path = args.VQA_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path
print('load initial weights MAML from: %s' % (weight_path))
# maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn)
maml_v_emb = MAML(args.RAD_dir)
maml_v_emb = MAML(args.VQA_dir)
maml_v_emb.load_state_dict(torch.load(weight_path))
# build and load pre-trained Auto-encoder model
if args.autoencoder:
ae_v_emb = Auto_Encoder_Model()
weight_path = args.RAD_dir + '/' + args.ae_model_path
weight_path = args.VQA_dir + '/' + args.ae_model_path
print('load initial weights DAE from: %s'%(weight_path))
ae_v_emb.load_state_dict(torch.load(weight_path))
# Loading tfidf weighted embedding
Expand Down Expand Up @@ -214,28 +214,28 @@ def build_SAN(dataset, args):
args.dropout)
# build and load pre-trained MAML model
if args.maml:
# weight_path = args.RAD_dir + '/' + args.maml_model_path
# weight_path = args.VQA_dir + '/' + args.maml_model_path
# print('load initial weights MAML from: %s' % (weight_path))
# maml_v_emb = SimpleCNN(weight_path, args.eps_cnn, args.momentum_cnn)
if len(args.maml_nums) > 1:
maml_v_emb = []
for model_t in args.maml_nums:
weight_path = args.RAD_dir + '/' + 't%s_'%(model_t) + args.maml_model_path
weight_path = args.VQA_dir + '/' + 't%s_'%(model_t) + args.maml_model_path
print('load initial weights MAML from: %s' % (weight_path))
# maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn)
maml_v_emb_temp = MAML(args.RAD_dir)
maml_v_emb_temp = MAML(args.VQA_dir)
maml_v_emb_temp.load_state_dict(torch.load(weight_path))
maml_v_emb.append(maml_v_emb_temp)
else:
weight_path = args.RAD_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path
weight_path = args.VQA_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path
print('load initial weights MAML from: %s' % (weight_path))
# maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn)
maml_v_emb = MAML(args.RAD_dir)
maml_v_emb = MAML(args.VQA_dir)
maml_v_emb.load_state_dict(torch.load(weight_path))
# build and load pre-trained Auto-encoder model
if args.autoencoder:
ae_v_emb = Auto_Encoder_Model()
weight_path = args.RAD_dir + '/' + args.ae_model_path
weight_path = args.VQA_dir + '/' + args.ae_model_path
print('load initial weights DAE from: %s'%(weight_path))
ae_v_emb.load_state_dict(torch.load(weight_path))
# Loading tfidf weighted embedding
Expand Down
12 changes: 6 additions & 6 deletions dataset_pathVQA.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(self, name, args, dictionary, dataroot='data', question_len=12):
super(VQAFeatureDataset, self).__init__()
self.args = args
assert name in ['train', 'val', 'test']
dataroot = args.RAD_dir
dataroot = args.VQA_dir
ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl')
label2ans_path = os.path.join(dataroot, 'cache', 'trainval_label2ans.pkl')
self.ans2label = cPickle.load(open(ans2label_path, 'rb'))
Expand All @@ -169,7 +169,7 @@ def __init__(self, name, args, dictionary, dataroot='data', question_len=12):
# load image data for Auto-encoder module
if args.autoencoder:
# TODO: load images
if 'RAD' in self.args.RAD_dir:
if 'RAD' in self.args.VQA_dir:
images_path = os.path.join(dataroot, 'images128x128.pkl')
else:
images_path = os.path.join(dataroot, 'pytorch_images128_ae.pkl')
Expand Down Expand Up @@ -206,7 +206,7 @@ def tensorize(self):
self.maml_images_data = torch.stack(self.maml_images_data)
self.maml_images_data = self.maml_images_data.type('torch.FloatTensor')
if self.args.autoencoder:
if 'RAD' in self.args.RAD_dir:
if 'RAD' in self.args.VQA_dir:
self.ae_images_data = torch.from_numpy(self.ae_images_data)
else:
self.ae_images_data = torch.stack(self.ae_images_data)
Expand Down Expand Up @@ -237,7 +237,7 @@ def __getitem__(self, index):

image_data = [0, 0]
if self.args.maml:
if 'RAD' in self.args.RAD_dir:
if 'RAD' in self.args.VQA_dir:
maml_images_data = self.maml_images_data[entry['image']].reshape(self.args.img_size * self.args.img_size)
else:
maml_images_data = self.maml_images_data[entry['image']].reshape(3 * self.args.img_size * self.args.img_size)
Expand All @@ -264,8 +264,8 @@ def tfidf_from_questions(names, args, dictionary, dataroot='data', target=['rad'
inds = [[], []] # rows, cols for uncoalesce sparse matrix
df = dict()
N = len(dictionary)
if args.use_RAD:
dataroot = args.RAD_dir
if args.use_VQA:
dataroot = args.VQA_dir
def populate(inds, df, text):
tokens = dictionary.tokenize(text, True)
for t in tokens:
Expand Down
12 changes: 6 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def parse_args():
help='dropout of rate of final classifier')

# Train with RAD
parser.add_argument('--use_RAD', action='store_true', default=False,
parser.add_argument('--use_VQA', action='store_true', default=False,
help='Using TDIUC dataset to train')
parser.add_argument('--RAD_dir', type=str,
parser.add_argument('--VQA_dir', type=str,
help='RAD dir')

# Optimization hyper-parameters
Expand Down Expand Up @@ -135,10 +135,10 @@ def parse_args():
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
# Load dictionary and RAD training dataset
if args.use_RAD:
dictionary = dataset_pathVQA.Dictionary.load_from_file(os.path.join(args.RAD_dir, 'dictionary.pkl'))
if args.use_VQA:
dictionary = dataset_pathVQA.Dictionary.load_from_file(os.path.join(args.VQA_dir, 'dictionary.pkl'))
train_dset = dataset_pathVQA.VQAFeatureDataset('train', args, dictionary, question_len=args.question_len)
if 'RAD' not in args.RAD_dir:
if 'RAD' not in args.VQA_dir:
val_dset = dataset_pathVQA.VQAFeatureDataset('val', args, dictionary, question_len=args.question_len)
batch_size = args.batch_size
# Create VQA model
Expand All @@ -157,7 +157,7 @@ def parse_args():
epoch = model_data['epoch'] + 1
# create training dataloader
train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0, collate_fn=utils.trim_collate, pin_memory=True)
if 'RAD' not in args.RAD_dir:
if 'RAD' not in args.VQA_dir:
eval_loader = DataLoader(val_dset, batch_size, shuffle=False, num_workers=0, collate_fn=utils.trim_collate, pin_memory=True)
else:
eval_loader = None
Expand Down
Loading

0 comments on commit 59bf082

Please sign in to comment.