-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtmh_main.py
64 lines (52 loc) · 2.31 KB
/
tmh_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
"""
import json
import sys
import torch
import esm
from deepTMpred.model import FineTuneEsmCNN, OrientationNet
from deepTMpred.utils import tmh_predict, load_model_and_alphabet_core
from torch.utils.data import DataLoader
from deepTMpred.data import FineTuneDataset, batch_collate
def data_iter(data_path, pssm_dir, hmm_dir, batch_converter, label=False):
    """Build a DataLoader that serves the whole dataset as a single batch.

    Args:
        data_path: Data source handed to ``FineTuneDataset``.
        pssm_dir: Forwarded to ``FineTuneDataset`` as ``pssm_dir``.
        hmm_dir: Forwarded to ``FineTuneDataset`` as ``hmm_file``.
        batch_converter: Passed to ``batch_collate`` to build the collate
            function.
        label: Forwarded to both the dataset and the collate function.

    Returns:
        A ``DataLoader`` whose batch size equals the dataset length.
    """
    dataset = FineTuneDataset(
        data_path,
        pssm_dir=pssm_dir,
        hmm_file=hmm_dir,
        label=label,
    )
    # Batch size is the dataset length, so all samples land in one batch.
    batch_size = len(dataset)
    collate = batch_collate(batch_converter, label=label)
    return DataLoader(dataset, batch_size, collate_fn=collate)
def test(model, orientation_model, test_loader, device):
    """Run TMH and orientation inference over every batch of ``test_loader``.

    Both networks are switched to eval mode and the forward passes run under
    ``torch.no_grad()``.

    Args:
        model: Network exposing an ``esm`` sub-module and a ``predict``
            method.
        orientation_model: Network whose output is arg-maxed over dim 1 to
            pick one class index per sample.
        test_loader: Iterable yielding ``(tokens, ids, matrix,
            token_lengths)`` batches.
        device: Device the batch tensors are moved to.

    Returns:
        The merged ``tmh_predict`` results of all batches, or an empty dict
        when the loader yields no batch.
    """
    model.eval()
    # The orientation network is used for inference as well, so it must not
    # stay in training mode (dropout / batch-norm would otherwise be active).
    orientation_model.eval()

    # Bound up front so an empty loader returns {} instead of raising
    # UnboundLocalError at the return statement.
    tmh_dict = {}
    with torch.no_grad():
        for tokens, _ids, matrix, token_lengths in test_loader:
            tokens = tokens.to(device)
            results = model.esm(tokens, repr_layers=[12], return_contacts=False)
            # Layer-12 representations with index 0 along dim 1 dropped
            # (presumably the prepended start token - TODO confirm).
            token_embeddings = results["representations"][12][:, 1:, :]
            token_lengths = token_lengths.to(device)
            matrix = matrix.to(device)
            # Concatenate the extra feature matrix with the embeddings along
            # the last (feature) dimension.
            embedings = torch.cat((matrix, token_embeddings), dim=2)
            predict_list, prob = model.predict(embedings, token_lengths)
            orientation_out = orientation_model(embedings)
            predict = torch.argmax(orientation_out, dim=1)
            batch_result = tmh_predict(_ids, predict_list, prob, predict.tolist())
            # Merge instead of overwrite so earlier batches are not lost.
            # NOTE(review): assumes tmh_predict returns a dict - confirm.
            tmh_dict.update(batch_result)
    return tmh_dict
def main():
    """Load both models, run prediction on the input file, write ``test.json``.

    Command line:
        <script> <tmh_model_path> <orientation_model_path> <test_file>

    Reads ``./args.pt`` from the working directory to rebuild the pretrained
    ESM model and its alphabet, loads the two state dicts given on the
    command line, runs ``test`` on CPU and dumps the result as JSON to
    ``test.json``.

    Raises:
        SystemExit: If fewer than three command-line arguments are given.
    """
    if len(sys.argv) < 4:
        # Fail with a usage hint instead of a bare IndexError.
        raise SystemExit(
            "usage: {} <tmh_model_path> <orientation_model_path> <test_file>".format(
                sys.argv[0]
            )
        )
    ###############
    tmh_model_path = sys.argv[1]
    orientation_model_path = sys.argv[2]
    test_file = sys.argv[3]
    device = torch.device('cpu')
    ###############

    model = FineTuneEsmCNN(768)
    # The ESM backbone is rebuilt from locally saved arguments; this replaces
    # the earlier call kept here for reference:
    # pretrain_model, alphabet = esm.pretrained.esm1_t12_85M_UR50S()
    # map_location keeps checkpoints saved on a GPU loadable on this CPU-only
    # run.
    args_dict = torch.load('./args.pt', map_location=device)
    pretrain_model, alphabet = load_model_and_alphabet_core(args_dict)
    batch_converter = alphabet.get_batch_converter()

    # The backbone must be attached before load_state_dict so that the
    # checkpoint is loaded into the model with its 'esm' sub-module present.
    model.add_module('esm', pretrain_model.to(device))
    model.load_state_dict(torch.load(tmh_model_path, map_location=device))
    model = model.to(device)

    orientation_model = OrientationNet()
    orientation_model.load_state_dict(
        torch.load(orientation_model_path, map_location=device)
    )
    orientation_model = orientation_model.to(device)

    test_iter = data_iter(test_file, None, None, batch_converter, label=False)
    tmh_dict = test(model, orientation_model, test_iter, device)

    # Context manager guarantees the output file is flushed and closed.
    with open('test.json', 'w') as out_file:
        json.dump(tmh_dict, out_file)
# Script entry point: run the prediction pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()