About mol_to_graph_data_obj_simple functions #24

Open
lhkhiem28 opened this issue Apr 19, 2024 · 2 comments

@lhkhiem28

Thank you for an interesting repo.

I went through the code and noticed that you use two different `mol_to_graph_data_obj_simple` functions, one for contrastive pre-training and one for property-prediction fine-tuning:
pre-training: https://github.com/chao1224/MoleculeSTM/blob/main/MoleculeSTM/datasets/utils.py#L44
fine-tuning: https://github.com/chao1224/MoleculeSTM/blob/main/MoleculeSTM/datasets/MoleculeNet_Graph.py#L17

Could you explain why this is necessary? Since you use the same GNN architecture for pre-training and fine-tuning, does using different `mol_to_graph_data_obj_simple` functions affect the GNN's behavior?

Looking forward to hearing from you soon.

Thanks.
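One way to see why the choice of featurizer matters: GNNs that take integer-indexed node features typically give each feature column its own embedding table and sum the looked-up rows, so the pretrained weights are tied to a specific column layout and index meaning. The sketch below is a toy, pure-Python illustration of that assumption; `ColumnSumEncoder`, the vocabulary sizes, and the feature rows are hypothetical and are not taken from MoleculeSTM or OGB.

```python
import random
from typing import List, Sequence


class ColumnSumEncoder:
    """Toy per-column embedding encoder (hypothetical, for illustration only).

    Each integer feature column gets its own lookup table; a node's input
    vector is the sum of the rows selected by its feature values.
    """

    def __init__(self, vocab_sizes: Sequence[int], dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.vocab_sizes = list(vocab_sizes)
        self.dim = dim
        # tables[c][v] is the length-`dim` embedding for value v in column c.
        self.tables = [
            [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(size)]
            for size in self.vocab_sizes
        ]

    def encode(self, row: Sequence[int]) -> List[float]:
        # The number of columns is fixed when the tables are built.
        if len(row) != len(self.tables):
            raise ValueError(
                f"expected {len(self.tables)} feature columns, got {len(row)}"
            )
        out = [0.0] * self.dim
        for col, value in enumerate(row):
            if not 0 <= value < self.vocab_sizes[col]:
                raise IndexError(f"value {value} out of range for column {col}")
            for k, weight in enumerate(self.tables[col][value]):
                out[k] += weight
        return out


# Illustrative vocabulary sizes for a 9-column atom layout (hypothetical values).
encoder = ColumnSumEncoder(vocab_sizes=[119, 5, 12, 12, 10, 6, 6, 2, 2], dim=4)

row_9_columns = [5, 0, 3, 5, 0, 0, 2, 0, 0]  # toy 9-column atom row
row_2_columns = [5, 0]                       # toy 2-column atom row

print(len(encoder.encode(row_9_columns)))  # 4
try:
    encoder.encode(row_2_columns)
except ValueError as err:
    print(err)  # expected 9 feature columns, got 2
```

A column-count mismatch fails loudly, as above. A featurizer with the same number of columns but different index meanings would not raise at all; it would silently look up the wrong pretrained rows, which is the harder failure to notice.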

@chao1224 (Owner)

Hi @lhkhiem28,

Thank you for checking this, and please use the OGB version for the featurization.

We tested both versions, as we did in GraphMVP, but I merged the wrong one in the previous code release.
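Since the recommendation is to use the OGB featurization, a cheap safeguard is to check that featurized graphs have the OGB layout before they reach a pretrained GNN. OGB's `atom_to_feature_vector` returns 9 integer features per atom and `bond_to_feature_vector` returns 3 per bond. The helper below is a hypothetical pure-Python sketch (not part of MoleculeSTM) that works on plain nested lists, for example the output of `data.x.tolist()`, `data.edge_index.tolist()`, and `data.edge_attr.tolist()`.

```python
from typing import List, Sequence

# Column counts produced by OGB's atom_to_feature_vector / bond_to_feature_vector.
OGB_NUM_ATOM_FEATURES = 9
OGB_NUM_BOND_FEATURES = 3


def check_ogb_layout(
    x: Sequence[Sequence[int]],
    edge_index: Sequence[Sequence[int]],
    edge_attr: Sequence[Sequence[int]],
) -> List[str]:
    """Hypothetical helper: list layout problems (empty if the graph looks OGB-featurized)."""
    problems: List[str] = []
    for n, row in enumerate(x):
        if len(row) != OGB_NUM_ATOM_FEATURES:
            problems.append(
                f"atom {n}: {len(row)} features, expected {OGB_NUM_ATOM_FEATURES}"
            )
    # edge_index must be [2, num_edges]; edge_attr must have one row per edge.
    if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
        problems.append("edge_index must have shape [2, num_edges]")
    elif len(edge_index[0]) != len(edge_attr):
        problems.append("edge_attr must have one row per edge")
    for e, row in enumerate(edge_attr):
        if len(row) != OGB_NUM_BOND_FEATURES:
            problems.append(
                f"edge {e}: {len(row)} features, expected {OGB_NUM_BOND_FEATURES}"
            )
    return problems


# Toy 3-atom chain with 2 bonds, each stored in both directions.
x = [[0] * 9 for _ in range(3)]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
edge_attr = [[0, 0, 0] for _ in range(4)]
print(check_ogb_layout(x, edge_index, edge_attr))  # []
```

A graph produced by a featurizer with a different column layout (for example, 2 atom features) would come back with one problem entry per offending atom or edge.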

chao1224 added a commit that referenced this issue Apr 19, 2024
@lhkhiem28 (Author)

lhkhiem28 commented Apr 19, 2024

So, is the function in utils.py the correct one?
And is that function consistent with this checkpoint? https://huggingface.co/chao1224/MoleculeSTM/tree/main/pretrained_MoleculeSTM/SciBERT-Graph-3e-5-1-1e-4-1-InfoNCE-0.1-32-32

@chao1224, can you confirm?

```python
import numpy as np
import torch
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector
from torch_geometric.data import Data


def mol_to_graph_data_obj_simple(mol):
    """ used in MoleculeNetGraphDataset() class
    Converts rdkit mol objects to graph data object in pytorch geometric
    NB: Uses OGB atom and bond features, represented as integer indices
    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr """

    # atoms: one OGB feature vector (9 integer indices) per atom
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = atom_to_feature_vector(atom)
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    if mol.GetNumBonds() == 0:  # mol has no bonds
        num_bond_features = 3  # OGB bond features: bond type, stereo, is_conjugated
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    else:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)

            # store each bond in both directions with the same features
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data
```
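The bond loop above stores every bond twice, once per direction, so a molecule with B bonds yields an `edge_index` of shape [2, 2B] and an `edge_attr` with 2B rows. The following is a dependency-free sketch of just that step; `bonds_to_coo` is a hypothetical name, and bonds are passed as plain `(i, j, features)` tuples instead of RDKit bond objects.

```python
from typing import List, Sequence, Tuple

Bond = Tuple[int, int, Sequence[int]]  # (begin atom index, end atom index, bond features)


def bonds_to_coo(bonds: Sequence[Bond]) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (edge_index, edge_attr) with each bond added in both directions."""
    sources: List[int] = []
    targets: List[int] = []
    edge_attr: List[List[int]] = []
    for i, j, feat in bonds:
        # i -> j
        sources.append(i)
        targets.append(j)
        edge_attr.append(list(feat))
        # j -> i, with the same features
        sources.append(j)
        targets.append(i)
        edge_attr.append(list(feat))
    # COO layout: row 0 holds source nodes, row 1 holds target nodes.
    return [sources, targets], edge_attr


edge_index, edge_attr = bonds_to_coo([(0, 1, [0, 0, 0]), (1, 2, [0, 0, 0])])
print(edge_index)      # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(len(edge_attr))  # 4
```

With no bonds, the sketch returns `[[], []]` and `[]`, which mirrors the empty `(2, 0)` and `(0, 3)` tensors built in the no-bond branch.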
