Fix to make Directed graph to Undirected #7
Closed
+206
−11
26 commits
9150bb6 Create .gitignore (aditya0by0)
06a71a6 update precommit + github action (aditya0by0)
ff1adc9 pre-commit format files (aditya0by0)
d7f30d3 change graph from directed to UNDIRECTED (aditya0by0)
b8189d1 add test data (aditya0by0)
ad301e6 edge_features should be calculated after undirected graph (aditya0by0)
344d828 directed edge which form an un-dir edge should be adjancent (aditya0by0)
8a69828 add test for GraphPropertyReader (aditya0by0)
e0064b8 add gt test data for aspirin (aditya0by0)
a9c7228 Update test_data.py (aditya0by0)
0a9760d add more graph test (aditya0by0)
1a8dcb6 first src to tgt edges then tgt to src (aditya0by0)
5d4c174 add test for duplicate directed edges (aditya0by0)
945ef7c restore import (aditya0by0)
53a240a concat edge attr for undirected graph (aditya0by0)
b1f2da3 concat prop values instead of edge_attr (aditya0by0)
3615fb1 Merge branch 'dev' into fix/directed-to-undirected-graph (aditya0by0)
7a8b664 Merge branch 'dev' into fix/directed-to-undirected-graph (aditya0by0)
53ca438 inherit from ChebiOverX (aditya0by0)
4319e47 `nan_to_num` numpy2.x compatibility (aditya0by0)
7c0a484 add print statements (aditya0by0)
ffc0b75 remove print for dataloader phase (aditya0by0)
0a39749 Update .gitignore (aditya0by0)
8e56851 merge from dev (aditya0by0)
f1e3eb6 fix import error (aditya0by0)
b9e081e Merge branch 'dev' into fix/directed-to-undirected-graph (aditya0by0)
Empty file.
Empty file.
Empty file.
@@ -0,0 +1,86 @@
import unittest

import torch
from torch_geometric.data import Data as GeomData

from chebai_graph.preprocessing.reader import GraphPropertyReader
from tests.unit.test_data import MoleculeGraph


class TestGraphPropertyReader(unittest.TestCase):
    """Unit tests for the GraphPropertyReader class, which converts SMILES strings to torch_geometric Data objects."""

    def setUp(self) -> None:
        """Initialize the reader and the reference molecule graph."""
        self.reader: GraphPropertyReader = GraphPropertyReader()
        self.molecule_graph: MoleculeGraph = MoleculeGraph()

    def test_read_data(self) -> None:
        """Test that the reader correctly parses a SMILES string into a graph and matches expected aspirin structure."""
        smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O"  # Aspirin

        data: GeomData = self.reader._read_data(smiles)  # noqa

        self.assertIsInstance(
            data,
            GeomData,
            msg="The output should be an instance of torch_geometric.data.Data.",
        )

        self.assertEqual(
            data.edge_index.shape[0],
            2,
            msg=f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}",
        )

        self.assertEqual(
            data.edge_index.shape[1],
            data.edge_attr.shape[0],
            msg=f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})",
        )

        self.assertEqual(
            len(set(data.edge_index[0].tolist())),
            data.x.shape[0],
            msg=f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})",
        )

        # Check for duplicate directed edges: each column of edge_index becomes a row, and direction matters
        _, counts = torch.unique(data.edge_index.t(), dim=0, return_counts=True)
        self.assertFalse(
            torch.any(counts > 1),
            msg="edge_index contains duplicate directed edges",
        )

        expected_data: GeomData = self.molecule_graph.get_aspirin_graph()
        self.assertTrue(
            torch.equal(data.edge_index, expected_data.edge_index),
            msg=(
                "edge_index tensors do not match.\n"
                f"Differences at indices: {(data.edge_index != expected_data.edge_index).nonzero()}.\n"
                f"Parsed edge_index:\n{data.edge_index}\nExpected edge_index:\n{expected_data.edge_index}\n"
                "If this fails in the future, check whether the RDKit version changed; the expected graph was generated with RDKit 2024.9.6."
            ),
        )

        self.assertEqual(
            data.x.shape[0],
            expected_data.x.shape[0],
            msg=(
                "The number of atoms (nodes) in the parsed graph does not match the reference graph.\n"
                f"Parsed: {data.x.shape[0]}, Expected: {expected_data.x.shape[0]}"
            ),
        )

        self.assertEqual(
            data.edge_attr.shape[0],
            expected_data.edge_attr.shape[0],
            msg=(
                "The number of edge attributes does not match the expected value.\n"
                f"Parsed: {data.edge_attr.shape[0]}, Expected: {expected_data.edge_attr.shape[0]}"
            ),
        )


if __name__ == "__main__":
    unittest.main()
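The duplicate check in this test treats direction as significant: (u, v) and (v, u) are distinct rows, so an undirected graph passes as long as no ordered pair repeats. A minimal sketch of the same idea follows, assuming `edge_index` is given as two plain Python lists instead of a tensor so that it runs without PyTorch. The helper name `find_duplicate_directed_edges` is hypothetical and not part of chebai_graph.

```python
from collections import Counter


def find_duplicate_directed_edges(edge_index):
    """Return the ordered (source, target) pairs that occur more than once.

    `edge_index` is assumed to be a pair of equal-length lists,
    [sources, targets], mirroring the COO layout of shape [2, num_edges].
    """
    sources, targets = edge_index
    counts = Counter(zip(sources, targets))
    return sorted(pair for pair, n in counts.items() if n > 1)


# An undirected edge stored as two directed edges is not a duplicate.
clean = [[0, 1, 1, 2], [1, 0, 2, 1]]
# Here the directed edge (0, 1) appears twice.
dirty = [[0, 1, 0], [1, 0, 1]]

print(find_duplicate_directed_edges(clean))  # []
print(find_duplicate_directed_edges(dirty))  # [(0, 1)]
```

Returning the offending pairs, not just a boolean, makes a failing assertion easier to diagnose.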
@@ -0,0 +1,102 @@
import torch
from torch_geometric.data import Data


class MoleculeGraph:
    """Class representing molecular graph data."""

    def get_aspirin_graph(self):
        r"""
        Constructs and returns a PyTorch Geometric Data object representing the molecular graph of Aspirin.

        Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365

        Node labels (atom indices):
            O2             C5———C6
              \           /       \
               C1———O3———C4        C7
              /           \       /
            C0             C9———C8
                          /
                        C10
                       /   \
                    O12     O11


        Returns:
            torch_geometric.data.Data: A Data object with attributes:
                - x (FloatTensor): Node feature matrix of shape (num_nodes, 1).
                - edge_index (LongTensor): Graph connectivity in COO format of shape (2, num_edges).
                - edge_attr (FloatTensor): Edge feature matrix of shape (num_edges, 1).

        Refer:
            For graph construction: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html
        """

        # --- Node features: atomic numbers (C=6, O=8) ---
        # Shape of x : num_nodes x num_of_node_features
        # fmt: off
        x = torch.tensor(
            [
                [6],  # C0  - feature of the atom/node referenced as 0 in edge_index
                [6],  # C1  - feature of the atom/node referenced as 1 in edge_index
                [8],  # O2  - feature of the atom/node referenced as 2 in edge_index
                [8],  # O3  - feature of the atom/node referenced as 3 in edge_index
                [6],  # C4  - feature of the atom/node referenced as 4 in edge_index
                [6],  # C5  - feature of the atom/node referenced as 5 in edge_index
                [6],  # C6  - feature of the atom/node referenced as 6 in edge_index
                [6],  # C7  - feature of the atom/node referenced as 7 in edge_index
                [6],  # C8  - feature of the atom/node referenced as 8 in edge_index
                [6],  # C9  - feature of the atom/node referenced as 9 in edge_index
                [6],  # C10 - feature of the atom/node referenced as 10 in edge_index
                [8],  # O11 - feature of the atom/node referenced as 11 in edge_index
                [8],  # O12 - feature of the atom/node referenced as 12 in edge_index
            ],
            dtype=torch.float,
        )
        # fmt: on

        # --- Edge list (bidirectional) ---
        # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26)
        # Generated using RDKit 2024.9.6
        # fmt: off
        _edge_index = torch.tensor([
            [0, 1, 1, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9],  # Start atoms (u)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4]  # End atoms (v)
        ], dtype=torch.long)
        # fmt: on

        # Reverse the edges
        reversed_edge_index = _edge_index[[1, 0], :]

        # First all directed edges from source to target are placed, then all directed edges
        # from target to source; the test compares this tensor to the reader output with torch.equal, so order matters
        undirected_edge_index = torch.cat([_edge_index, reversed_edge_index], dim=1)

        # --- Dummy edge features ---
        # Shape of undirected_edge_attr: num_of_edges x num_of_edges_features (26 x 1)
        # fmt: off
        _edge_attr = torch.tensor([
            [1],  # C0 - C1:   feature of the edge in column 0 of `_edge_index`
            [2],  # C1 - O2:   feature of the edge in column 1 of `_edge_index`
            [2],  # C1 - O3:   feature of the edge in column 2 of `_edge_index`
            [2],  # O3 - C4:   feature of the edge in column 3 of `_edge_index`
            [1],  # C4 - C5:   feature of the edge in column 4 of `_edge_index`
            [1],  # C5 - C6:   feature of the edge in column 5 of `_edge_index`
            [1],  # C6 - C7:   feature of the edge in column 6 of `_edge_index`
            [1],  # C7 - C8:   feature of the edge in column 7 of `_edge_index`
            [1],  # C8 - C9:   feature of the edge in column 8 of `_edge_index`
            [1],  # C9 - C10:  feature of the edge in column 9 of `_edge_index`
            [1],  # C10 - O11: feature of the edge in column 10 of `_edge_index`
            [1],  # C10 - O12: feature of the edge in column 11 of `_edge_index`
            [1],  # C9 - C4:   feature of the edge in column 12 of `_edge_index`
        ], dtype=torch.float)
        # fmt: on

        # Edge attributes must be aligned in the same order as edge_index
        undirected_edge_attr = torch.cat([_edge_attr, _edge_attr], dim=0)

        # Create graph data object
        return Data(
            x=x, edge_index=undirected_edge_index, edge_attr=undirected_edge_attr
        )
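The reference graph relies on one layout rule: all source-to-target edges come first, then all target-to-source edges, with the edge attributes repeated in the same order. A rough sketch of that rule using plain Python lists is given below. The function name `to_undirected` here is a hypothetical local helper, not the chebai_graph implementation and not the torch_geometric utility of the same name, and the attribute values are placeholders.

```python
def to_undirected(edge_index, edge_attr):
    """Append the reversed edges after the original ones.

    edge_index: [sources, targets], two equal-length lists.
    edge_attr:  one feature list per directed edge, in edge_index order.
    """
    sources, targets = edge_index
    if not (len(sources) == len(targets) == len(edge_attr)):
        raise ValueError("edge_index and edge_attr must describe the same edges")
    # Source-to-target edges first, then target-to-source edges.
    undirected_index = [sources + targets, targets + sources]
    # Repeat the attributes in the same order, so position i and
    # position i + num_edges carry identical features.
    undirected_attr = edge_attr + edge_attr
    return undirected_index, undirected_attr


# The 13 bonds of aspirin as listed in the reference graph.
src = [0, 1, 1, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9]
tgt = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4]
attr = [[1.0] for _ in src]  # placeholder features, one per bond

index, features = to_undirected([src, tgt], attr)
num_bonds = len(src)
print(len(index[0]), len(features))  # 26 26
# Position i + num_bonds holds the reverse of position i.
print((index[0][num_bonds], index[1][num_bonds]))  # (1, 0)
```

With this layout, the reverse of edge `i` is always found at `i + num_bonds`, which keeps attribute alignment easy to verify.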
Is there a reason for subclassing ChEBIOverX instead of XYBaseDataModule? This is not specific to ChEBI, and we might want to use it for PubChem data as well.
Sorry, I didn’t notice that a class from `chebai_graph/preprocessing/datasets/pubchem.py` uses a class from `chebi.py`. Ideally, all classes defined in `chebi.py` should be specific to the ChEBI dataset. For functionality shared across multiple datasets, we can introduce a base dataset class within this repository (similar to the structure used in the chebai repository).
I agree that this is confusing. We should move GraphPropertiesMixIn (and everything related to properties) to a different file.
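One possible shape for the refactor discussed in these comments is sketched below. Every class here is a local stand-in written for illustration: the real `XYBaseDataModule` and `GraphPropertiesMixIn` have different and larger interfaces, and `ChEBIGraphData`, `PubChemGraphData`, and the property names are hypothetical. The only point is the dependency direction: the shared mixin sits outside `chebi.py`, so a PubChem dataset does not have to inherit from a ChEBI class.

```python
class XYBaseDataModule:
    """Stand-in for chebai's base data module; the real class is far richer."""

    def __init__(self, batch_size=1):
        self.batch_size = batch_size


class GraphPropertiesMixIn:
    """Stand-in for the shared property handling, kept free of ChEBI specifics."""

    # Hypothetical placeholder names, not real chebai_graph properties.
    properties = ("example_atom_property", "example_bond_property")

    def describe(self):
        return f"{type(self).__name__} with {len(self.properties)} properties"


class ChEBIGraphData(GraphPropertiesMixIn, XYBaseDataModule):
    """Hypothetical ChEBI-specific dataset; would stay in chebi.py."""


class PubChemGraphData(GraphPropertiesMixIn, XYBaseDataModule):
    """Hypothetical PubChem dataset; needs nothing from chebi.py."""


pubchem = PubChemGraphData(batch_size=32)
print(pubchem.describe())  # PubChemGraphData with 2 properties
print(issubclass(PubChemGraphData, ChEBIGraphData))  # False
```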