From 100e42b8f89b1f3552e35ac8a9a16f9c89dffd33 Mon Sep 17 00:00:00 2001 From: Nguyen Xuan Binh Date: Fri, 2 Jul 2021 16:45:54 +0700 Subject: [PATCH] "fix: maml path" --- base_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/base_model.py b/base_model.py index a58a2d0..6c520bd 100644 --- a/base_model.py +++ b/base_model.py @@ -157,14 +157,14 @@ def build_BAN(dataset, args, priotize_using_counter=False): if len(args.maml_nums) > 1: maml_v_emb = [] for model_t in args.maml_nums: - weight_path = args.VQA_dir + '/' + 't%s_'%(model_t) + args.maml_model_path + weight_path = args.VQA_dir + '/maml/' + 't%s_'%(model_t) + args.maml_model_path print('load initial weights MAML from: %s' % (weight_path)) # maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn) maml_v_emb_temp = MAML(args.VQA_dir) maml_v_emb_temp.load_state_dict(torch.load(weight_path)) maml_v_emb.append(maml_v_emb_temp) else: - weight_path = args.VQA_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path + weight_path = args.VQA_dir + '/maml/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path print('load initial weights MAML from: %s' % (weight_path)) # maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn) maml_v_emb = MAML(args.VQA_dir) @@ -223,14 +223,14 @@ def build_SAN(dataset, args): if len(args.maml_nums) > 1: maml_v_emb = [] for model_t in args.maml_nums: - weight_path = args.VQA_dir + '/' + 't%s_'%(model_t) + args.maml_model_path + weight_path = args.VQA_dir + '/maml/' + 't%s_'%(model_t) + args.maml_model_path print('load initial weights MAML from: %s' % (weight_path)) # maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn) maml_v_emb_temp = MAML(args.VQA_dir) maml_v_emb_temp.load_state_dict(torch.load(weight_path)) maml_v_emb.append(maml_v_emb_temp) else: - weight_path = args.VQA_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path + weight_path = args.VQA_dir + '/maml/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path print('load initial weights MAML from: %s' % (weight_path)) maml_v_emb = MAML(args.VQA_dir) maml_v_emb.load_state_dict(torch.load(weight_path))