Skip to content

Commit

Permalink
Code clean-ups, #24
Browse files Browse the repository at this point in the history
  • Loading branch information
chao1224 committed Apr 19, 2024
1 parent 64d75c9 commit 9ba1584
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 200 deletions.
153 changes: 2 additions & 151 deletions MoleculeSTM/datasets/MoleculeNet_Graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,161 +6,12 @@
import numpy as np
import pandas as pd
import torch
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from torch.utils import data
from torch_geometric.data import (Data, InMemoryDataset, download_url, extract_zip)


def mol_to_graph_data_obj_simple(mol):
    """Turn an RDKit molecule into a PyTorch Geometric ``Data`` graph.

    Used by the MoleculeNetGraphDataset class. Atoms and bonds are encoded
    as the integer index vectors returned by ``atom_to_feature_vector`` and
    ``bond_to_feature_vector``.

    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr
    """
    # Node feature matrix: one row of feature indices per atom.
    node_rows = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(np.array(node_rows), dtype=torch.long)

    if len(mol.GetBonds()) > 0:  # mol has bonds
        endpoint_pairs = []
        pair_features = []
        for bond in mol.GetBonds():
            begin = bond.GetBeginAtomIdx()
            end = bond.GetEndAtomIdx()
            feature = bond_to_feature_vector(bond)

            # Each bond is stored twice, once per direction, as two
            # consecutive entries sharing the same feature vector.
            endpoint_pairs.append((begin, end))
            pair_features.append(feature)
            endpoint_pairs.append((end, begin))
            pair_features.append(feature)

        # Graph connectivity in COO format with shape [2, num_edges].
        edge_index = torch.tensor(np.array(endpoint_pairs).T, dtype=torch.long)
        # Edge feature matrix with shape [num_edges, num_edge_features].
        edge_attr = torch.tensor(np.array(pair_features), dtype=torch.long)
    else:  # mol has no bonds
        # Column count used for the empty edge feature matrix.
        empty_bond_width = 3
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, empty_bond_width), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def graph_data_obj_to_nx_simple(data):
    """Convert a PyTorch Geometric ``Data`` object into a networkx graph.

    Node features are stored as node attributes ``x0`` .. ``x8`` and edge
    features as edge attributes ``e0`` .. ``e2``.

    NB: possible issues with recapitulating relative stereochemistry since
    the edges in the nx object are unordered.

    :param data: pytorch geometric Data object
    :return: networkx object
    """
    graph = nx.Graph()

    # atoms: one node per row of data.x, carrying its first nine columns
    node_matrix = data.x.cpu().numpy()
    node_count = node_matrix.shape[0]
    node_keys = ('x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8')
    for node_id in range(node_count):
        row = node_matrix[node_id]
        attributes = {key: row[pos] for pos, key in enumerate(node_keys)}
        graph.add_node(node_id, **attributes)

    # bonds: only every second column of edge_index is read, which relies on
    # each bond being stored as two consecutive directed edges.
    connectivity = data.edge_index.cpu().numpy()
    bond_matrix = data.edge_attr.cpu().numpy()
    directed_edge_count = connectivity.shape[1]
    for column in range(0, directed_edge_count, 2):
        source = int(connectivity[0, column])
        target = int(connectivity[1, column])
        bond_row = bond_matrix[column]
        if graph.has_edge(source, target):
            continue
        graph.add_edge(source, target,
                       e0=bond_row[0],
                       e1=bond_row[1],
                       e2=bond_row[2])

    return graph


def nx_to_graph_data_obj_simple(G):
    """Convert a networkx graph back into a PyTorch Geometric ``Data`` object.

    Inverse of graph_data_obj_to_nx_simple(). Node attributes ``x0`` .. ``x8``
    become the rows of ``x`` and edge attributes ``e0`` .. ``e2`` become the
    rows of ``edge_attr``.

    Assume node indices are numbered from 0 to num_nodes - 1.
    NB: Uses simplified atom and bond features, and represent as indices.
    NB: possible issues with recapitulating relative stereochemistry
    since the edges in the nx object are unordered.

    :param G: networkx graph carrying the attributes described above
    :return: graph data object with the attributes: x, edge_index, edge_attr
    """
    # atoms
    node_keys = ('x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8')
    node_rows = [
        [attributes[key] for key in node_keys]
        for _, attributes in G.nodes(data=True)
    ]
    x = torch.tensor(np.array(node_rows), dtype=torch.long)

    # bonds
    bond_feature_width = 3
    if len(G.edges()) == 0:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, bond_feature_width), dtype=torch.long)
    else:  # mol has bonds
        endpoint_pairs = []
        pair_features = []
        for source, target, attributes in G.edges(data=True):
            feature = [attributes['e0'], attributes['e1'], attributes['e2']]
            # Emit both directions of the undirected edge back to back.
            endpoint_pairs.append((source, target))
            pair_features.append(feature)
            endpoint_pairs.append((target, source))
            pair_features.append(feature)

        # Graph connectivity in COO format with shape [2, num_edges].
        edge_index = torch.tensor(np.array(endpoint_pairs).T, dtype=torch.long)
        # Edge feature matrix with shape [num_edges, num_edge_features].
        edge_attr = torch.tensor(np.array(pair_features), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def create_standardized_mol_id(smiles):
    """Map a SMILES string to an InChI string (smiles -> inchi).

    :param smiles: SMILES string
    :return: InChI of the molecule, or None when ``check_smiles_validity``
        rejects the input or the stereo-free SMILES cannot be re-parsed
    """
    if not check_smiles_validity(smiles):
        return None

    # remove stereochemistry
    flat_smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
                                      isomericSmiles=False)
    mol = AllChem.MolFromSmiles(flat_smiles)

    # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)\
    # c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
    if mol is None:
        return None

    if '.' in flat_smiles:  # if multiple species, pick largest molecule
        species = split_rdkit_mol_obj(mol)
        return AllChem.MolToInchi(get_largest_mol(species))

    return AllChem.MolToInchi(mol)
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from MoleculeSTM.datasets.utils import mol_to_graph_data_obj_simple


class MoleculeNetGraphDataset(InMemoryDataset):
Expand Down
60 changes: 11 additions & 49 deletions MoleculeSTM/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,26 @@
import torch
from rdkit import Chem
from torch_geometric.data import Data


# Lookup table of the values accepted for each atom / bond property.
# Featurization code in this module converts a property value to an integer
# feature by taking its position in the matching list (via ``.index(...)``).
allowable_features = {
    # atomic numbers 1 through 118
    'possible_atomic_num_list': list(range(1, 119)),
    'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list': [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8],
    'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
    'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs': [  # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector


def mol_to_graph_data_obj_simple(mol):
""" used in MoleculeNetGraphDataset() class
Converts rdkit mol objects to graph data object in pytorch geometric
NB: Uses simplified atom and bond features, and represent as indices
:param mol: rdkit mol object
:return: graph data object with the attributes: x, edge_index, edge_attr """

# atoms
# num_atom_features = 2 # atom type, chirality tag
atom_features_list = []
for atom in mol.GetAtoms():
atomic_num = atom.GetAtomicNum()
chiral_tag = atom.GetChiralTag()
if atomic_num == 0:
atomic_num = 118 # Only for one extreme case
atom_feature = [allowable_features['possible_atomic_num_list'].index(atomic_num)] + \
[allowable_features['possible_chirality_list'].index(chiral_tag)]
atom_feature = atom_to_feature_vector(atom)
atom_features_list.append(atom_feature)
x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

# bonds
if len(mol.GetBonds()) <= 0: # mol has no bonds
num_bond_features = 2 # bond type & direction
num_bond_features = 3 # bond type & direction
edge_index = torch.empty((2, 0), dtype=torch.long)
edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
else: # mol has bonds
Expand All @@ -66,12 +32,8 @@ def mol_to_graph_data_obj_simple(mol):
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
bond_type = bond.GetBondType()
bond_dir = bond.GetBondDir()
if bond_dir not in allowable_features['possible_bond_dirs']:
bond_dir = 0
edge_feature = [allowable_features['possible_bonds'].index(bond_type)] + \
[allowable_features['possible_bond_dirs'].index(bond_dir)]
edge_feature = bond_to_feature_vector(bond)

edges_list.append((i, j))
edge_features_list.append(edge_feature)
edges_list.append((j, i))
Expand Down

0 comments on commit 9ba1584

Please sign in to comment.