From af49d4dd9da72c1ae6a785d6dad1965dfe78d68b Mon Sep 17 00:00:00 2001 From: chao1224 Date: Fri, 19 Apr 2024 01:19:12 -0400 Subject: [PATCH] Code clean-ups, #24 --- MoleculeSTM/datasets/MoleculeNet_Graph.py | 153 +--------------------- MoleculeSTM/datasets/utils.py | 60 ++------- 2 files changed, 13 insertions(+), 200 deletions(-) diff --git a/MoleculeSTM/datasets/MoleculeNet_Graph.py b/MoleculeSTM/datasets/MoleculeNet_Graph.py index 4392598..f2f84d2 100644 --- a/MoleculeSTM/datasets/MoleculeNet_Graph.py +++ b/MoleculeSTM/datasets/MoleculeNet_Graph.py @@ -6,161 +6,12 @@ import numpy as np import pandas as pd import torch -from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector from rdkit import Chem from rdkit.Chem import AllChem, Descriptors from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect from torch.utils import data -from torch_geometric.data import (Data, InMemoryDataset, download_url, extract_zip) - - -def mol_to_graph_data_obj_simple(mol): - """ used in MoleculeNetGraphDataset() class - Converts rdkit mol objects to graph data object in pytorch geometric - NB: Uses simplified atom and bond features, and represent as indices - :param mol: rdkit mol object - :return: graph data object with the attributes: x, edge_index, edge_attr """ - - # atoms - # num_atom_features = 2 # atom type, chirality tag - atom_features_list = [] - for atom in mol.GetAtoms(): - atom_feature = atom_to_feature_vector(atom) - atom_features_list.append(atom_feature) - x = torch.tensor(np.array(atom_features_list), dtype=torch.long) - - # bonds - if len(mol.GetBonds()) <= 0: # mol has no bonds - num_bond_features = 3 # bond type & direction - edge_index = torch.empty((2, 0), dtype=torch.long) - edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) - else: # mol has bonds - edges_list = [] - edge_features_list = [] - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - edge_feature = bond_to_feature_vector(bond) - - edges_list.append((i, j)) - edge_features_list.append(edge_feature) - edges_list.append((j, i)) - edge_features_list.append(edge_feature) - - # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] - edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) - - # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] - edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) - - data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) - - return data - - -def graph_data_obj_to_nx_simple(data): - """ torch geometric -> networkx - NB: possible issues with recapitulating relative - stereochemistry since the edges in the nx object are unordered. - :param data: pytorch geometric Data object - :return: networkx object """ - G = nx.Graph() - - # atoms - atom_features = data.x.cpu().numpy() - num_atoms = atom_features.shape[0] - for i in range(num_atoms): - temp_feature = atom_features[i] - G.add_node( - i, - x0=temp_feature[0], - x1=temp_feature[1], - x2=temp_feature[2], - x3=temp_feature[3], - x4=temp_feature[4], - x5=temp_feature[5], - x6=temp_feature[6], - x7=temp_feature[7], - x8=temp_feature[8]) - pass - - # bonds - edge_index = data.edge_index.cpu().numpy() - edge_attr = data.edge_attr.cpu().numpy() - num_bonds = edge_index.shape[1] - for j in range(0, num_bonds, 2): - begin_idx = int(edge_index[0, j]) - end_idx = int(edge_index[1, j]) - temp_feature= edge_attr[j] - if not G.has_edge(begin_idx, end_idx): - G.add_edge(begin_idx, end_idx, - e0=temp_feature[0], - e1=temp_feature[1], - e2=temp_feature[2]) - - return G - - -def nx_to_graph_data_obj_simple(G): - """ vice versa of graph_data_obj_to_nx_simple() - Assume node indices are numbered from 0 to num_nodes - 1. - NB: Uses simplified atom and bond features, and represent as indices. - NB: possible issues with recapitulating relative stereochemistry - since the edges in the nx object are unordered. """ - - # atoms - # num_atom_features = 2 # atom type, chirality tag - atom_features_list = [] - for _, node in G.nodes(data=True): - atom_feature = [node['x0'], node['x1'], node['x2'], node['x3'], node['x4'], node['x5'], node['x6'], node['x7'], node['x8']] - atom_features_list.append(atom_feature) - x = torch.tensor(np.array(atom_features_list), dtype=torch.long) - - # bonds - num_bond_features = 3 # bond type, bond direction - if len(G.edges()) > 0: # mol has bonds - edges_list = [] - edge_features_list = [] - for i, j, edge in G.edges(data=True): - edge_feature = [edge['e0'], edge['e1'], edge['e2']] - edges_list.append((i, j)) - edge_features_list.append(edge_feature) - edges_list.append((j, i)) - edge_features_list.append(edge_feature) - - # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] - edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) - - # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] - edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) - else: # mol has no bonds - edge_index = torch.empty((2, 0), dtype=torch.long) - edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) - - data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) - - return data - - -def create_standardized_mol_id(smiles): - """ smiles -> inchi """ - - if check_smiles_validity(smiles): - # remove stereochemistry - smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), - isomericSmiles=False) - mol = AllChem.MolFromSmiles(smiles) - if mol is not None: - # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)\ - # c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21 - if '.' in smiles: # if multiple species, pick largest molecule - mol_species_list = split_rdkit_mol_obj(mol) - largest_mol = get_largest_mol(mol_species_list) - inchi = AllChem.MolToInchi(largest_mol) - else: - inchi = AllChem.MolToInchi(mol) - return inchi - return +from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip +from MoleculeSTM.datasets.utils import mol_to_graph_data_obj_simple class MoleculeNetGraphDataset(InMemoryDataset): diff --git a/MoleculeSTM/datasets/utils.py b/MoleculeSTM/datasets/utils.py index 38446aa..2f9b1e4 100644 --- a/MoleculeSTM/datasets/utils.py +++ b/MoleculeSTM/datasets/utils.py @@ -4,60 +4,26 @@ import torch from rdkit import Chem from torch_geometric.data import Data - - -allowable_features = { - 'possible_atomic_num_list': list(range(1, 119)), - 'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], - 'possible_chirality_list': [ - Chem.rdchem.ChiralType.CHI_UNSPECIFIED, - Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, - Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, - Chem.rdchem.ChiralType.CHI_OTHER - ], - 'possible_hybridization_list': [ - Chem.rdchem.HybridizationType.S, - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2, - Chem.rdchem.HybridizationType.UNSPECIFIED - ], - 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8], - 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6], - 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - 'possible_bonds': [ - Chem.rdchem.BondType.SINGLE, - Chem.rdchem.BondType.DOUBLE, - Chem.rdchem.BondType.TRIPLE, - Chem.rdchem.BondType.AROMATIC - ], - 'possible_bond_dirs': [ # only for double bond stereo information - Chem.rdchem.BondDir.NONE, - Chem.rdchem.BondDir.ENDUPRIGHT, - Chem.rdchem.BondDir.ENDDOWNRIGHT - ] -} +from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector def mol_to_graph_data_obj_simple(mol): + """ used in MoleculeNetGraphDataset() class + Converts rdkit mol objects to graph data object in pytorch geometric + NB: Uses simplified atom and bond features, and represent as indices + :param mol: rdkit mol object + :return: graph data object with the attributes: x, edge_index, edge_attr """ + # atoms - # num_atom_features = 2 # atom type, chirality tag atom_features_list = [] for atom in mol.GetAtoms(): - atomic_num = atom.GetAtomicNum() - chiral_tag = atom.GetChiralTag() - if atomic_num == 0: - atomic_num = 118 # Only for one extreme case - atom_feature = [allowable_features['possible_atomic_num_list'].index(atomic_num)] + \ - [allowable_features['possible_chirality_list'].index(chiral_tag)] + atom_feature = atom_to_feature_vector(atom) atom_features_list.append(atom_feature) x = torch.tensor(np.array(atom_features_list), dtype=torch.long) # bonds if len(mol.GetBonds()) <= 0: # mol has no bonds - num_bond_features = 2 # bond type & direction + num_bond_features = 3 # bond type & direction edge_index = torch.empty((2, 0), dtype=torch.long) edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) else: # mol has bonds @@ -66,12 +32,8 @@ def mol_to_graph_data_obj_simple(mol): for bond in mol.GetBonds(): i = bond.GetBeginAtomIdx() j = bond.GetEndAtomIdx() - bond_type = bond.GetBondType() - bond_dir = bond.GetBondDir() - if bond_dir not in allowable_features['possible_bond_dirs']: - bond_dir = 0 - edge_feature = [allowable_features['possible_bonds'].index(bond_type)] + \ - [allowable_features['possible_bond_dirs'].index(bond_dir)] + edge_feature = bond_to_feature_vector(bond) + edges_list.append((i, j)) edge_features_list.append(edge_feature) edges_list.append((j, i))