Skip to content

Commit

Permalink
upload
Browse files Browse the repository at this point in the history
  • Loading branch information
wangleiofficial committed Sep 2, 2022
1 parent 40ffb9f commit 9dd007a
Show file tree
Hide file tree
Showing 11 changed files with 9,107 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/dist
/NeuroPred-PLM.egg-info
/NeuroPred-PLM/__pycache__
/.vscode
.idea
*.iml
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include NeuroPredPLM/args.pt
Empty file added NeuroPredPLM/__init__.py
Empty file.
Binary file added NeuroPredPLM/args.pt
Binary file not shown.
55 changes: 55 additions & 0 deletions NeuroPredPLM/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
main model
"""
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from einops import rearrange
import os

from .utils import length_to_mask, load_model_and_alphabet_core


class EsmModel(nn.Module):
def __init__(self, hidden_size=64, num_labels=2, projection_size=24, head=12):
super().__init__()

basedir = os.path.abspath(os.path.dirname(__file__))
self.esm, self.alphabet = load_model_and_alphabet_core(os.path.join(basedir, 'args.pt'))
self.num_labels = num_labels
self.head = head
self.hidden_size = hidden_size
self.projection = nn.Linear(hidden_size, projection_size)
self.cov_1 = nn.Conv1d(projection_size, projection_size, kernel_size=3, padding='same')
self.cov_2 = nn.Conv1d(projection_size, int(projection_size/2), kernel_size=1, padding='same')
# self.gating = nn.Linear(projection_size, projection_size)
self.W = nn.Parameter(torch.randn((head, int(projection_size/2))))
# self.mu = nn.Parameter(torch.randn((1, 768)))
self.fcn = nn.Sequential(nn.Linear(int(projection_size/2)*head, int(projection_size/2)),
nn.ReLU(), nn.Linear(int(projection_size/2), num_labels))


def forward(self, peptide_list, device='cpu'):
peptide_length = [len(i[1]) for i in peptide_list]
batch_converter = self.alphabet.get_batch_converter()
_, _, batch_tokens = batch_converter(peptide_list)
batch_tokens = batch_tokens.to(device)
protein_dict = self.esm(batch_tokens, repr_layers=[12], return_contacts=False)
protein_embeddings = protein_dict["representations"][12][:, 1:, :]
protein_embed = rearrange(protein_embeddings, 'b l (h d)-> (b h) l d', h=self.head)
representations = self.projection(protein_embed)
representations = rearrange(representations, 'b l d -> b d l')
representation_cov = F.relu(self.cov_1(representations))
representation_cov = F.relu(self.cov_2(representation_cov))
representations = rearrange(representation_cov, '(b h) d l -> b h l d', h=self.head)
att = torch.einsum('bhld,hd->bhl', representations, self.W)
mask = length_to_mask(torch.tensor(peptide_length)).to(device)
att = att.masked_fill(mask.unsqueeze(1)==0, -np.inf)
att= F.softmax(att, dim=-1)
# print(att)
representations = rearrange(representations * att.unsqueeze(-1), 'b h l d -> b l (h d)')
representations = torch.sum(representations, dim=1)
return self.fcn(representations), att


20 changes: 20 additions & 0 deletions NeuroPredPLM/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from .model import EsmModel
from .utils import load_hub_workaround
import torch

MODEL_URL = "https://zenodo.org/record/7042286/files/model.pth"

def predict(peptide_list, device='cpu'):
with torch.no_grad():
neuroPred_model = EsmModel()
neuroPred_model.eval()
state_dict = load_hub_workaround(MODEL_URL)
# state_dict = torch.load("/mnt/d/protein-net/Neuropep-ESM/model.pth", map_location="cpu")
neuroPred_model.load_state_dict(state_dict)
neuroPred_model = neuroPred_model.to(device)
prob, att = neuroPred_model(peptide_list, device)
pred = torch.argmax(prob, dim=-1).cpu().tolist()
att = att.cpu().numpy()
out = {i[0]:[j,m[:, :len(i[1])]] for i, j, m in zip(peptide_list, pred, att)}
return out

50 changes: 50 additions & 0 deletions NeuroPredPLM/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import esm
from argparse import Namespace
import pathlib
import urllib

def length_to_mask(length, max_len=None, dtype=None):
"""length: B.
return B x max_len.
If max_len is None, then max of length will be used.
"""
assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
max_len = max_len or length.max().item()
mask = torch.arange(max_len, device=length.device,
dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1)
if dtype is not None:
mask = torch.as_tensor(mask, dtype=dtype, device=length.device)
return mask


def load_model_and_alphabet_core(args_dict, regression_data=None):
args_dict = torch.load(args_dict)
alphabet = esm.Alphabet.from_architecture(args_dict["args"].arch)

# upgrade state dict
pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s)
prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s)
model_args = {pra(arg[0]): arg[1] for arg in vars(args_dict["args"]).items()}
model_type = esm.ProteinBertModel

model = model_type(
Namespace(**model_args),
alphabet,
)
return model, alphabet


def load_hub_workaround(url):
try:
data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
except RuntimeError:
# Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106
fn = pathlib.Path(url).name
data = torch.load(
f"{torch.hub.get_dir()}/checkpoints/{fn}",
map_location="cpu",
)
except urllib.error.HTTPError as e:
raise Exception(f"Could not load {url}, check your network!")
return data
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
## NeuroPred-PLM: an interpretable and robust model for prediction of neuropeptides by protein language model

### Requirements
To install requirements:

```
pip install git+https://github.com/ISYSLAB-HUST/NeuroPred-PLM.git
```
### Usage

```
import torch
from NeuroPredPLM.predict import predict
data = [
("peptide_1", "IGLRLPNMLKF"),
("peptide_2", "QAAQFKVWSASELVD"),
("peptide_3","LRSPKMMHKSGCFGRRLDRIGSLSGLGCNVLRKY")
]
device = "cuda" if torch.cuda.is_available() else "cpu"
neuropeptide_pred = predict(data,device)
# {peptide_id:[Type:int(1->neuropeptide,0->non-neuropeptide), attention score:nd.array]}
```

### Contact
If you have any questions, comments, or would like to report a bug, please file a Github issue or contact me at [email protected].
Loading

0 comments on commit 9dd007a

Please sign in to comment.