diff --git a/base_model.py b/base_model.py index a58a2d0..6c520bd 100644 --- a/base_model.py +++ b/base_model.py @@ -157,14 +157,14 @@ def build_BAN(dataset, args, priotize_using_counter=False): if len(args.maml_nums) > 1: maml_v_emb = [] for model_t in args.maml_nums: - weight_path = args.VQA_dir + '/' + 't%s_'%(model_t) + args.maml_model_path + weight_path = args.VQA_dir + '/maml/' + 't%s_'%(model_t) + args.maml_model_path print('load initial weights MAML from: %s' % (weight_path)) # maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn) maml_v_emb_temp = MAML(args.VQA_dir) maml_v_emb_temp.load_state_dict(torch.load(weight_path)) maml_v_emb.append(maml_v_emb_temp) else: - weight_path = args.VQA_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path + weight_path = args.VQA_dir + '/maml/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path print('load initial weights MAML from: %s' % (weight_path)) # maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn) maml_v_emb = MAML(args.VQA_dir) @@ -223,14 +223,14 @@ def build_SAN(dataset, args): if len(args.maml_nums) > 1: maml_v_emb = [] for model_t in args.maml_nums: - weight_path = args.VQA_dir + '/' + 't%s_'%(model_t) + args.maml_model_path + weight_path = args.VQA_dir + '/maml/' + 't%s_'%(model_t) + args.maml_model_path print('load initial weights MAML from: %s' % (weight_path)) # maml_v_emb = SimpleCNN32(weight_path, args.eps_cnn, args.momentum_cnn) maml_v_emb_temp = MAML(args.VQA_dir) maml_v_emb_temp.load_state_dict(torch.load(weight_path)) maml_v_emb.append(maml_v_emb_temp) else: - weight_path = args.VQA_dir + '/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path + weight_path = args.VQA_dir + '/maml/' + 't%s_' % (args.maml_nums[0]) + args.maml_model_path print('load initial weights MAML from: %s' % (weight_path)) maml_v_emb = MAML(args.VQA_dir) maml_v_emb.load_state_dict(torch.load(weight_path))