Skip to content

Commit

Permalink
"fix: maml path"
Browse files Browse the repository at this point in the history
  • Loading branch information
xuanbinh-nguyen96 committed Jul 2, 2021
1 parent 6d8bc9f commit 100e42b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,14 @@ def build_BAN(dataset, args, priotize_using_counter=False):
if len(args.maml_nums) > 1:
maml_v_emb = []
for model_t in args.maml_nums:
weight_path = args.VQA_dir + '/' + 't%s_'%(model_t) + args.maml_model_path
weight_path = args.VQA_dir + '/maml/' + 't%s_'%(model_t) + args.maml_model_path
print('load initial weights MAML from: %s' % (weight_path))
# maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn)
maml_v_emb_temp = MAML(args.VQA_dir)
maml_v_emb_temp.load_state_dict(torch.load(weight_path))
maml_v_emb.append(maml_v_emb_temp)
else:
weight_path = args.VQA_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path
weight_path = args.VQA_dir + '/maml/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path
print('load initial weights MAML from: %s' % (weight_path))
# maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn)
maml_v_emb = MAML(args.VQA_dir)
Expand Down Expand Up @@ -223,14 +223,14 @@ def build_SAN(dataset, args):
if len(args.maml_nums) > 1:
maml_v_emb = []
for model_t in args.maml_nums:
weight_path = args.VQA_dir + '/' + 't%s_'%(model_t) + args.maml_model_path
weight_path = args.VQA_dir + '/maml/' + 't%s_'%(model_t) + args.maml_model_path
print('load initial weights MAML from: %s' % (weight_path))
# maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn)
maml_v_emb_temp = MAML(args.VQA_dir)
maml_v_emb_temp.load_state_dict(torch.load(weight_path))
maml_v_emb.append(maml_v_emb_temp)
else:
weight_path = args.VQA_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path
weight_path = args.VQA_dir + '/maml/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path
print('load initial weights MAML from: %s' % (weight_path))
maml_v_emb = MAML(args.VQA_dir)
maml_v_emb.load_state_dict(torch.load(weight_path))
Expand Down

0 comments on commit 100e42b

Please sign in to comment.