Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9150bb6
Create .gitignore
aditya0by0 May 7, 2025
06a71a6
update precommit + github action
aditya0by0 May 7, 2025
ff1adc9
pre-commit format files
aditya0by0 May 7, 2025
d7f30d3
change graph from directed to UNDIRECTED
aditya0by0 May 7, 2025
b8189d1
add test data
aditya0by0 May 14, 2025
ad301e6
edge_features should be calculated after undirected graph
aditya0by0 May 14, 2025
344d828
directed edges which form an undirected edge should be adjacent
aditya0by0 May 14, 2025
8a69828
add test for GraphPropertyReader
aditya0by0 May 14, 2025
e0064b8
add gt test data for aspirin
aditya0by0 May 14, 2025
a9c7228
Update test_data.py
aditya0by0 May 14, 2025
0a9760d
add more graph test
aditya0by0 May 14, 2025
1a8dcb6
first src to tgt edges then tgt to src
aditya0by0 May 14, 2025
5d4c174
add test for duplicate directed edges
aditya0by0 May 14, 2025
945ef7c
restore import
aditya0by0 May 14, 2025
53a240a
concat edge attr for undirected graph
aditya0by0 May 14, 2025
b1f2da3
concat prop values instead of edge_attr
aditya0by0 May 15, 2025
3615fb1
Merge branch 'dev' into fix/directed-to-undirected-graph
aditya0by0 May 21, 2025
7a8b664
Merge branch 'dev' into fix/directed-to-undirected-graph
aditya0by0 May 21, 2025
53ca438
inherit from ChebiOverX
aditya0by0 May 23, 2025
4319e47
`nan_to_num` numpy2.x compatibility
aditya0by0 May 25, 2025
7c0a484
add print statements
aditya0by0 May 25, 2025
ffc0b75
remove print for dataloader phase
aditya0by0 May 25, 2025
0a39749
Update .gitignore
aditya0by0 May 28, 2025
8e56851
merge from dev
aditya0by0 Jul 24, 2025
f1e3eb6
fix import error
aditya0by0 Jul 25, 2025
b9e081e
Merge branch 'dev' into fix/directed-to-undirected-graph
aditya0by0 Oct 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions chebai_graph/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import importlib
import os
from abc import ABC
from typing import Callable, List, Optional

import pandas as pd
import torch
import tqdm
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.preprocessing.datasets.chebi import (
ChEBIOver50,
ChEBIOver100,
ChEBIOverX,
ChEBIOverXPartial,
)
from lightning_utilities.core.rank_zero import rank_zero_info
Expand Down Expand Up @@ -48,7 +49,7 @@ def _resolve_property(
return getattr(graph_properties, property)()


class GraphPropertiesMixIn(XYBaseDataModule):
class GraphPropertiesMixIn(ChEBIOverX, ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for subclassing ChEBIOverX instead of XYBaseDataModule? This is not specific to ChEBI, and we might want to use it for PubChem data as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn’t notice that a class from chebai_graph/preprocessing/datasets/pubchem.py uses a class from chebi.py. Ideally, all classes defined in chebi.py should be specific to the ChEBI dataset. For functionalities shared across multiple datasets, we can introduce a base dataset class within this repository (similar to the structure used in the chebai repository).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that this is confusing. We should move GraphPropertiesMixIn (and everything related to properties) to a different file.

READER = GraphPropertyReader

def __init__(
Expand Down Expand Up @@ -107,10 +108,12 @@ def enc_if_not_none(encode, value):
if not os.path.isfile(self.get_property_path(property)):
rank_zero_info(f"Processing property {property.name}")
# read all property values first, then encode
rank_zero_info("\tReading property valeus...")
property_values = [
self.reader.read_property(feat, property)
for feat in tqdm.tqdm(features)
]
rank_zero_info("\tEncoding property values...")
property.encoder.on_start(property_values=property_values)
encoded_values = [
enc_if_not_none(property.encoder.encode, value)
Expand Down Expand Up @@ -166,7 +169,11 @@ def _merge_props_into_base(self, row):
if isinstance(property, AtomProperty):
x = torch.cat([x, property_values], dim=1)
elif isinstance(property, BondProperty):
edge_attr = torch.cat([edge_attr, property_values], dim=1)
# Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges
edge_attr = torch.cat(
[edge_attr, torch.cat([property_values, property_values], dim=0)],
dim=1,
)
else:
molecule_attr = torch.cat([molecule_attr, property_values], dim=1)
return GeomData(
Expand Down
16 changes: 8 additions & 8 deletions chebai_graph/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def _read_data(self, raw_data):

x = torch.zeros((mol.GetNumAtoms(), 0))

edge_attr = torch.zeros((mol.GetNumBonds(), 0))

edge_index = torch.tensor(
[
[bond.GetBeginAtomIdx() for bond in mol.GetBonds()],
[bond.GetEndAtomIdx() for bond in mol.GetBonds()],
]
)
# First source to target edges, then target to source edges
src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()]
tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()]
edge_index = torch.tensor([src + tgt, tgt + src], dtype=torch.long)

# edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features]
edge_attr = torch.zeros((edge_index.size(1), 0))

return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr)

def on_finish(self):
Expand Down
Empty file added tests/__init__.py
Empty file.
Empty file added tests/unit/__init__.py
Empty file.
Empty file added tests/unit/readers/__init__.py
Empty file.
86 changes: 86 additions & 0 deletions tests/unit/readers/testGraphPropertyReader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import unittest

import torch
from torch_geometric.data import Data as GeomData

from chebai_graph.preprocessing.reader import GraphPropertyReader
from tests.unit.test_data import MoleculeGraph


class TestGraphPropertyReader(unittest.TestCase):
    """Unit tests for GraphPropertyReader, which converts SMILES strings to torch_geometric Data objects."""

    def setUp(self) -> None:
        """Create a fresh reader and the provider of the reference molecule graphs."""
        self.reader: GraphPropertyReader = GraphPropertyReader()
        self.molecule_graph: MoleculeGraph = MoleculeGraph()

    def test_read_data(self) -> None:
        """Parse the aspirin SMILES and compare the resulting graph with the reference aspirin graph."""
        smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O"  # Aspirin

        data: GeomData = self.reader._read_data(smiles)  # noqa

        self.assertIsInstance(
            data,
            GeomData,
            msg="The output should be an instance of torch_geometric.data.Data.",
        )

        # --- Structural sanity checks on the parsed graph itself ---
        self.assertEqual(
            data.edge_index.shape[0],
            2,
            msg=f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}",
        )

        self.assertEqual(
            data.edge_index.shape[1],
            data.edge_attr.shape[0],
            msg=f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})",
        )

        # Every aspirin atom takes part in at least one bond, so once both
        # directions of each bond are present every node shows up as a source.
        num_source_nodes = len(set(data.edge_index[0].tolist()))
        self.assertEqual(
            num_source_nodes,
            data.x.shape[0],
            msg=f"Number of unique source nodes in edge_index ({num_source_nodes}) does not match number of nodes in x ({data.x.shape[0]})",
        )

        # Check for duplicates by checking if the rows are the same (direction matters)
        _, counts = torch.unique(data.edge_index.t(), dim=0, return_counts=True)
        self.assertFalse(
            torch.any(counts > 1),
            msg="There are duplicates of directed edge in edge_index",
        )

        # --- Comparison against the hand-built reference graph ---
        expected_data: GeomData = self.molecule_graph.get_aspirin_graph()

        # Compare shapes first: the element-wise `!=` used in the message below
        # cannot broadcast tensors with a different number of edges and would
        # raise an error instead of producing a readable test failure.
        self.assertEqual(
            tuple(data.edge_index.shape),
            tuple(expected_data.edge_index.shape),
            msg=(
                "edge_index shapes do not match.\n"
                f"Parsed: {tuple(data.edge_index.shape)}, Expected: {tuple(expected_data.edge_index.shape)}"
            ),
        )

        self.assertTrue(
            torch.equal(data.edge_index, expected_data.edge_index),
            msg=(
                "edge_index tensors do not match.\n"
                f"Differences at indices: {(data.edge_index != expected_data.edge_index).nonzero()}.\n"
                f"Parsed edge_index:\n{data.edge_index}\nExpected edge_index:\n{expected_data.edge_index}\n"
                "If fails in future, check if there is change in RDKIT version, the expected graph is generated with RDKIT 2024.9.6"
            ),
        )

        self.assertEqual(
            data.x.shape[0],
            expected_data.x.shape[0],
            msg=(
                "The number of atoms (nodes) in the parsed graph does not match the reference graph.\n"
                f"Parsed: {data.x.shape[0]}, Expected: {expected_data.x.shape[0]}"
            ),
        )

        self.assertEqual(
            data.edge_attr.shape[0],
            expected_data.edge_attr.shape[0],
            msg=(
                "The number of edge attributes does not match the expected value.\n"
                f"Parsed: {data.edge_attr.shape[0]}, Expected: {expected_data.edge_attr.shape[0]}"
            ),
        )


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
102 changes: 102 additions & 0 deletions tests/unit/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import torch
from torch_geometric.data import Data


class MoleculeGraph:
    """Provider of hand-built reference molecular graphs used as ground truth in tests."""

    def get_aspirin_graph(self) -> "Data":
        r"""
        Construct the reference PyTorch Geometric graph of aspirin.

        Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365

        Node labels (atom indices)::

            O2              C5———C6
              \            /       \
               C1———O3———C4         C7
              /            \       /
            C0              C9———C8
                           /
                         C10
                        /   \
                     O12     O11

        The docstring is a raw string so that the backslashes of the diagram are
        kept literally instead of being read as escape sequences or line
        continuations.

        Returns:
            torch_geometric.data.Data: A Data object with attributes:
                - x (FloatTensor): Node feature matrix of shape (13, 1) holding atomic numbers.
                - edge_index (LongTensor): Graph connectivity in COO format of shape (2, 26).
                - edge_attr (FloatTensor): Dummy edge feature matrix of shape (26, 1).

        Refer:
            For graph construction: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html
        """

        # --- Node features: atomic numbers (C=6, O=8) ---
        # Shape of x: num_nodes x num_node_features (13 x 1).
        # Row i holds the feature of the atom referred to as i in edge_index.
        # fmt: off
        x = torch.tensor(
            [
                [6],  # C0
                [6],  # C1
                [8],  # O2
                [8],  # O3
                [6],  # C4
                [6],  # C5
                [6],  # C6
                [6],  # C7
                [6],  # C8
                [6],  # C9
                [6],  # C10
                [8],  # O11
                [8],  # O12
            ],
            dtype=torch.float,
        )
        # fmt: on

        # --- Edge list, one direction per bond (2 x 13) ---
        # Generated using RDKIT 2024.9.6
        # fmt: off
        _edge_index = torch.tensor([
            [0, 1, 1, 3, 4, 5, 6, 7, 8,  9, 10, 10, 9],  # Start atoms (u)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4],  # End atoms (v)
        ], dtype=torch.long)
        # fmt: on

        # Swap the two rows to obtain every edge in the opposite direction.
        reversed_edge_index = _edge_index[[1, 0], :]

        # Undirected graph (2 x 26): first all source -> target edges, then all
        # target -> source edges. This ordering is required.
        undirected_edge_index = torch.cat([_edge_index, reversed_edge_index], dim=1)

        # --- Dummy edge features, one row per bond (13 x 1) ---
        # Row i describes the bond stored in column i of `_edge_index`; the same
        # row is reused for the reversed copy of that bond.
        # fmt: off
        _edge_attr = torch.tensor([
            [1],  # C0  - C1
            [2],  # C1  - O2
            [2],  # C1  - O3
            [2],  # O3  - C4
            [1],  # C4  - C5
            [1],  # C5  - C6
            [1],  # C6  - C7
            [1],  # C7  - C8
            [1],  # C8  - C9
            [1],  # C9  - C10
            [1],  # C10 - O11
            [1],  # C10 - O12
            [1],  # C9  - C4
        ], dtype=torch.float)
        # fmt: on

        # Edge attributes must follow the same order as the columns of
        # `undirected_edge_index`, hence the block is repeated once (26 x 1).
        undirected_edge_attr = torch.cat([_edge_attr, _edge_attr], dim=0)

        # Create graph data object
        return Data(
            x=x, edge_index=undirected_edge_index, edge_attr=undirected_edge_attr
        )