-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathevaluate_interact_single.py
More file actions
122 lines (93 loc) · 5.21 KB
/
evaluate_interact_single.py
File metadata and controls
122 lines (93 loc) · 5.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
repo_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(repo_dir)
from repo.tools.interaction import *
from repo.utils import misc
import argparse
import torch
from scipy import spatial as sci_spatial
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--result_path', type=str, default='../results/denovo/diffbp/selftrain/ABL2_HUMAN_274_551_0/4xli_B_rec_4xli_1n1_lig_tt_min_0_pocket10')
parser.add_argument('--pdb_path', type=str, default='../data/crossdocked_test/ABL2_HUMAN_274_551_0/4xli_B_rec_4xli_1n1_lig_tt_min_0_pocket10.pdb')
parser.add_argument('--save_path', type=str, default='./tmp/interaction')
args = parser.parse_args()
logger = misc.get_logger('evaluate_interaction', log_dir=args.result_path)
reports = {'num_hydrophobic': 0, 'num_hydrogen': 0, 'num_wb': 0, 'num_pi_stack': 0, 'num_pi_cation': 0, 'num_halogen': 0, 'num_metal': 0}
interaction_detected = 0
result_path = args.result_path
file_list = os.listdir(result_path)
file_list = [file_name for file_name in file_list if file_name.endswith('.sdf')]
report_dicts = []
n_eval_success = 0
for file_name in file_list:
try:
protein_file = args.pdb_path
ligand_file = os.path.join(result_path, file_name)
interaction_analyzer = InteractionAnalyzer(protein_file, ligand_file)
report_path = interaction_analyzer.plip_analysis(args.save_path)
report_dict = interaction_analyzer.plip_parser(report_path)
report_dicts.append(report_dict)
n_eval_success += 1
except:
pass
interact_ratio_per_mol = {key: [] for key in report_dicts[0].keys()}
for report_dict in report_dicts:
for key, val in report_dict.items():
interact_ratio_per_mol[key].append(val)
for key, val in interact_ratio_per_mol.items():
interact_ratio_per_mol[key] = np.sum(val) / n_eval_success
interact_distribution = {key: [] for key in report_dicts[0].keys()}
interact_distribution_sum = []
for report_dict in report_dicts:
for key, val in report_dict.items():
interact_distribution[key].append(val)
interact_distribution_sum.append(val)
interact_distribution_sum = np.sum(interact_distribution_sum)
for key, val in interact_distribution.items():
interact_distribution[key] = np.sum(interact_distribution[key]) / interact_distribution_sum
gen_dict = {'dist': interact_distribution, 'ratio': interact_ratio_per_mol, 'n_eval_success': n_eval_success}
ref_mol_path = os.path.join(os.path.dirname(args.pdb_path), '_'.join(os.path.basename(args.pdb_path).split('_')[:-1]) + '.sdf')
interaction_analyzer = InteractionAnalyzer(protein_file, ligand_file)
report_path = interaction_analyzer.plip_analysis(args.save_path)
report_ref_dict = interaction_analyzer.plip_parser(report_path)
num_interact = np.sum([v for k,v in report_dict.items()])
report_ref_dist_dict = {k: v/num_interact for k,v in report_dict.items()}
ref_dict = {'dist': report_ref_dist_dict, 'ratio': report_ref_dict}
torch.save(gen_dict, os.path.join(result_path, 'interact_gen_results.pt'))
torch.save(ref_dict, os.path.join(result_path, 'interact_ref_results.pt'))
# jsd = sci_spatial.distance.jensenshannon([v for _, v in ref_dict['dist'].items()],
# [v for _, v in gen_dict['dist'].items()]) # 100 jsd - > mean
# mae = np.abs([v for _, v in ref_dict['ratio'].items()] - [v for _, v in gen_dict['ratio'].items()]).mean()
# print('jsd: ', jsd, 'mae: ', mae)
# num_success = 0
# jsds = []
# maes = []
# num_interactions = []
# num_interactions_ref = []
# for file_gen, file_ref in files:
# dist_gen = file_gen['dist']
# dist_ref = file_ref['dist']
# jsd = sci_spatial.distance.jensenshannon([v for _, v in dist_ref.items()],
# [v for _, v in dist_gen.items()])
# jsds.append(jsd)
# ratio_gen = file_gen['ratio']
# ratio_ref = file_ref['ratio']
# mae = ([v for _, v in ratio_gen.items()] - [v for _, v in ratio_ref.items()]).abs().mean()
# maes.append(mae)
# n_eval_success = file_gen['n_eval_success']
# num_interact_per_struct = np.array([v for _,v in ratio_gen.items()]) * n_eval_success
# num_success += n_eval_success
# num_interactions.append(num_interact_per_struct)
# num_interactions_ref.append(ratio_ref)
# num_all_interact = np.stack(num_interactions, dim=0).sum(0)
# ratio_all_interact = num_all_interact / num_success
# dist_all_interact = num_all_interact / num_all_interact.sum()
# num_ref_interact = np.stack(num_interactions_ref, dim=0).sum(0)
# ratio_ref_interact = num_ref_interact / len(num_interactions_ref)
# dist_ref_interact = num_ref_interact / num_ref_interact.sum()
# jsd_overall = sci_spatial.distance.jensenshannon(dist_ref_interact, dist_all_interact)
# mae_overall = compute_mae(ratio_ref_interact, ratio_all_interact)
# print(maes.mean(), jsds.mean())