Commit 9dd007a (1 parent: 40ffb9f) · 11 changed files with 9,107 additions and 0 deletions.
`.gitignore` (new file, 6 lines)
```
/dist
/NeuroPred-PLM.egg-info
/NeuroPred-PLM/__pycache__
/.vscode
.idea
*.iml
```
`MANIFEST.in` (new file, 1 line)
```
include NeuroPredPLM/args.pt
```
Empty file (presumably the package `__init__.py`).
Binary file not shown (presumably the bundled `NeuroPredPLM/args.pt` referenced by `MANIFEST.in` and `model.py`).
`NeuroPredPLM/model.py` (new file, 55 lines)
```python
"""
Main model: a pretrained ESM encoder followed by a convolutional,
multi-head attention pooling classifier.
"""
import os

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from .utils import length_to_mask, load_model_and_alphabet_core


class EsmModel(nn.Module):
    def __init__(self, hidden_size=64, num_labels=2, projection_size=24, head=12):
        super().__init__()

        # Build the ESM backbone from the architecture arguments packaged in args.pt.
        basedir = os.path.abspath(os.path.dirname(__file__))
        self.esm, self.alphabet = load_model_and_alphabet_core(os.path.join(basedir, 'args.pt'))
        self.num_labels = num_labels
        self.head = head
        self.hidden_size = hidden_size
        self.projection = nn.Linear(hidden_size, projection_size)
        self.cov_1 = nn.Conv1d(projection_size, projection_size, kernel_size=3, padding='same')
        self.cov_2 = nn.Conv1d(projection_size, int(projection_size / 2), kernel_size=1, padding='same')
        # self.gating = nn.Linear(projection_size, projection_size)
        # One learnable attention query vector per head.
        self.W = nn.Parameter(torch.randn((head, int(projection_size / 2))))
        # self.mu = nn.Parameter(torch.randn((1, 768)))
        self.fcn = nn.Sequential(nn.Linear(int(projection_size / 2) * head, int(projection_size / 2)),
                                 nn.ReLU(), nn.Linear(int(projection_size / 2), num_labels))

    def forward(self, peptide_list, device='cpu'):
        peptide_length = [len(i[1]) for i in peptide_list]
        batch_converter = self.alphabet.get_batch_converter()
        _, _, batch_tokens = batch_converter(peptide_list)
        batch_tokens = batch_tokens.to(device)
        # Per-residue embeddings from the 12th (last) ESM layer; drop the BOS token.
        protein_dict = self.esm(batch_tokens, repr_layers=[12], return_contacts=False)
        protein_embeddings = protein_dict["representations"][12][:, 1:, :]
        # Split the embedding dimension into `head` chunks and fold the heads into the batch.
        protein_embed = rearrange(protein_embeddings, 'b l (h d) -> (b h) l d', h=self.head)
        representations = self.projection(protein_embed)
        representations = rearrange(representations, 'b l d -> b d l')
        representation_cov = F.relu(self.cov_1(representations))
        representation_cov = F.relu(self.cov_2(representation_cov))
        representations = rearrange(representation_cov, '(b h) d l -> b h l d', h=self.head)
        # Per-head attention scores over residues, masked at padding positions.
        att = torch.einsum('bhld,hd->bhl', representations, self.W)
        mask = length_to_mask(torch.tensor(peptide_length)).to(device)
        att = att.masked_fill(mask.unsqueeze(1) == 0, -np.inf)
        att = F.softmax(att, dim=-1)
        # Attention-weighted sum over the sequence, concatenating the heads.
        representations = rearrange(representations * att.unsqueeze(-1), 'b h l (h d)' if False else 'b h l d -> b l (h d)')
        representations = torch.sum(representations, dim=1)
        return self.fcn(representations), att
```
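For orientation, a freshly constructed `EsmModel` carries only the ESM architecture with untrained weights (the trained parameters are downloaded and loaded in `predict.py`), but the forward signature and output shapes can be sketched as follows; this is an illustrative example, and the variable names are not part of the repository:

```python
import torch
from NeuroPredPLM.model import EsmModel

model = EsmModel()  # ESM backbone built from the bundled args.pt; weights untrained at this point
model.eval()
peptides = [("peptide_1", "IGLRLPNMLKF")]
with torch.no_grad():
    logits, attention = model(peptides, device="cpu")

print(logits.shape)     # (batch, 2): class scores for non-neuropeptide / neuropeptide
print(attention.shape)  # (batch, num_heads, sequence_length): attention over residues
```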
`NeuroPredPLM/predict.py` (new file, 20 lines)
```python
import torch

from .model import EsmModel
from .utils import load_hub_workaround

MODEL_URL = "https://zenodo.org/record/7042286/files/model.pth"


def predict(peptide_list, device='cpu'):
    with torch.no_grad():
        neuroPred_model = EsmModel()
        neuroPred_model.eval()
        # Download the trained weights from Zenodo (cached by torch.hub).
        state_dict = load_hub_workaround(MODEL_URL)
        # state_dict = torch.load("/mnt/d/protein-net/Neuropep-ESM/model.pth", map_location="cpu")
        neuroPred_model.load_state_dict(state_dict)
        neuroPred_model = neuroPred_model.to(device)
        prob, att = neuroPred_model(peptide_list, device)
        pred = torch.argmax(prob, dim=-1).cpu().tolist()
        att = att.cpu().numpy()
        # {peptide_id: [predicted label, per-head attention trimmed to the peptide length]}
        out = {i[0]: [j, m[:, :len(i[1])]] for i, j, m in zip(peptide_list, pred, att)}
    return out
```
`NeuroPredPLM/utils.py` (new file, 50 lines)
```python
import pathlib
import urllib.error
from argparse import Namespace

import esm
import torch


def length_to_mask(length, max_len=None, dtype=None):
    """length: B.
    return B x max_len.
    If max_len is None, then max of length will be used.
    """
    assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
    max_len = max_len or length.max().item()
    mask = torch.arange(max_len, device=length.device,
                        dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1)
    if dtype is not None:
        mask = torch.as_tensor(mask, dtype=dtype, device=length.device)
    return mask


def load_model_and_alphabet_core(args_path, regression_data=None):
    # args.pt stores only the pretraining arguments, not the weights: the ESM
    # backbone is built from those arguments with fresh parameters, and the
    # trained weights are loaded later in predict.py.
    args_dict = torch.load(args_path)
    alphabet = esm.Alphabet.from_architecture(args_dict["args"].arch)

    # upgrade state dict
    pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s)
    prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s)
    model_args = {pra(arg[0]): arg[1] for arg in vars(args_dict["args"]).items()}
    model_type = esm.ProteinBertModel

    model = model_type(
        Namespace(**model_args),
        alphabet,
    )
    return model, alphabet


def load_hub_workaround(url):
    try:
        data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # PyTorch version issue - see https://github.com/pytorch/pytorch/issues/43106
        fn = pathlib.Path(url).name
        data = torch.load(
            f"{torch.hub.get_dir()}/checkpoints/{fn}",
            map_location="cpu",
        )
    except urllib.error.HTTPError:
        raise Exception(f"Could not load {url}, check your network!")
    return data
```
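As a quick check of `length_to_mask`, a minimal, self-contained example (not part of the repository):

```python
import torch
from NeuroPredPLM.utils import length_to_mask

# Two sequences of lengths 2 and 4 -> boolean mask of shape (2, 4),
# True at valid positions and False at padding.
print(length_to_mask(torch.tensor([2, 4])))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
```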
`README.md` (new file, 26 lines)
## NeuroPred-PLM: an interpretable and robust model for prediction of neuropeptides by protein language model

### Requirements
To install requirements:

```
pip install git+https://github.com/ISYSLAB-HUST/NeuroPred-PLM.git
```

### Usage

```python
import torch
from NeuroPredPLM.predict import predict

data = [
    ("peptide_1", "IGLRLPNMLKF"),
    ("peptide_2", "QAAQFKVWSASELVD"),
    ("peptide_3", "LRSPKMMHKSGCFGRRLDRIGSLSGLGCNVLRKY"),
]
device = "cuda" if torch.cuda.is_available() else "cpu"
neuropeptide_pred = predict(data, device)
# {peptide_id: [type: int (1 -> neuropeptide, 0 -> non-neuropeptide), attention score: np.ndarray]}
```
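The returned dictionary maps each peptide identifier to its predicted label and per-head attention scores, so it can be consumed directly, for example (a minimal sketch; the loop variable names are illustrative, not part of the package):

```python
for name, (label, attention) in neuropeptide_pred.items():
    kind = "neuropeptide" if label == 1 else "non-neuropeptide"
    # attention has shape (num_heads, peptide_length); averaging over heads
    # gives one importance score per residue.
    print(name, kind, attention.mean(axis=0).round(3))
```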

### Contact
If you have any questions, comments, or would like to report a bug, please file a GitHub issue or contact me at [email protected].