Skip to content

Commit de9bc40

Browse files
authored
Export to pytorch geometric (#875)
1 parent 709b2c6 commit de9bc40

File tree

4 files changed

+103
-4
lines changed

4 files changed

+103
-4
lines changed

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,10 @@ def get_tag(self):
10181018
"plotly>=5.3.0",
10191019
"Pillow>=9; platform_python_implementation != 'PyPy'",
10201020
],
1021+
"test-pyg": [
1022+
"torch>=2.0.0; platform_python_implementation != 'PyPy'",
1023+
"torch-geometric>=2.0.0; platform_python_implementation != 'PyPy'",
1024+
],
10211025
# Dependencies needed for testing on Windows ARM64; only those that are either
10221026
# pure Python or have Windows ARM64 wheels as we don't want to compile wheels
10231027
# in CI

src/igraph/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@
211211
_export_graph_to_networkx,
212212
_construct_graph_from_graph_tool,
213213
_export_graph_to_graph_tool,
214+
_export_graph_to_torch_geometric,
214215
)
215216
from igraph.io.random import (
216217
_construct_random_geometric_graph,
@@ -463,6 +464,8 @@ def __init__(self, *args, **kwds):
463464
from_graph_tool = classmethod(_construct_graph_from_graph_tool)
464465
to_graph_tool = _export_graph_to_graph_tool
465466

467+
to_torch_geometric = _export_graph_to_torch_geometric
468+
466469
# Files
467470
Read_DIMACS = classmethod(_construct_graph_from_dimacs_file)
468471
write_dimacs = _write_graph_to_dimacs_file
@@ -708,7 +711,9 @@ def es(self):
708711

709712
###########################
710713
# Paths/traversals
711-
def get_all_simple_paths(self, v, to=None, minlen=0, maxlen=-1, mode="out", max_results=None):
714+
def get_all_simple_paths(
715+
self, v, to=None, minlen=0, maxlen=-1, mode="out", max_results=None
716+
):
712717
"""Calculates all the simple paths from a given node to some other nodes
713718
(or all of them) in a graph.
714719
@@ -973,15 +978,14 @@ def Incidence(cls, *args, **kwds):
973978
def are_connected(self, *args, **kwds):
974979
"""Deprecated alias to L{Graph.are_adjacent()}."""
975980
deprecated(
976-
"Graph.are_connected() is deprecated; use Graph.are_adjacent() " "instead"
981+
"Graph.are_connected() is deprecated; use Graph.are_adjacent() instead"
977982
)
978983
return self.are_adjacent(*args, **kwds)
979984

980985
def get_incidence(self, *args, **kwds):
981986
"""Deprecated alias to L{Graph.get_biadjacency()}."""
982987
deprecated(
983-
"Graph.get_incidence() is deprecated; use Graph.get_biadjacency() "
984-
"instead"
988+
"Graph.get_incidence() is deprecated; use Graph.get_biadjacency() instead"
985989
)
986990
return self.get_biadjacency(*args, **kwds)
987991

src/igraph/io/libraries.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,48 @@ def _construct_graph_from_graph_tool(cls, g):
270270
graph.add_edges(edges, eattr)
271271

272272
return graph
273+
274+
275+
def _export_graph_to_torch_geometric(
276+
graph, vertex_attributes=None, edge_attributes=None
277+
):
278+
"""Converts the graph to torch geometric
279+
280+
Data types: graph-tool only accepts specific data types. See the
281+
following web page for a list:
282+
283+
https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data
284+
285+
@param g: graph-tool Graph
286+
@param vertex_attributes: dictionary of vertex attributes to transfer.
287+
Keys are attributes from the vertices, values are data types (see
288+
below). C{None} means no vertex attributes are transferred.
289+
@param edge_attributes: dictionary of edge attributes to transfer.
290+
Keys are attributes from the edges, values are data types (see
291+
below). C{None} means no vertex attributes are transferred.
292+
"""
293+
import torch
294+
from torch_geometric.data import Data
295+
296+
if vertex_attributes is None:
297+
vertex_attributes = graph.vertex_attributes()
298+
if edge_attributes is None:
299+
edge_attributes = graph.edge_attributes()
300+
301+
# Edge index
302+
edge_index = torch.tensor(graph.get_edgelist(), dtype=torch.long)
303+
304+
# Node attributes
305+
x = torch.tensor([graph.vs[attr] for attr in vertex_attributes])
306+
if x.ndim > 1:
307+
x = x.permute(*torch.arange(x.ndim - 1, -1, -1))
308+
309+
# Edge attributes
310+
edge_attr = torch.tensor([graph.es[attr] for attr in edge_attributes])
311+
if edge_attr.ndim > 1:
312+
edge_attr = edge_attr.permute(*torch.arange(edge_attr.ndim - 1, -1, -1))
313+
314+
# Wrap into correct data structure
315+
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
316+
317+
return data

tests/test_foreign.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
pd = None
2424

2525

26+
try:
27+
import torch
28+
from torch_geometric.data import Data as PyGData
29+
except ImportError:
30+
torch = None
31+
PyGData = None
32+
33+
2634
GRAPHML_EXAMPLE_FILE = """\
2735
<?xml version="1.0" encoding="UTF-8"?>
2836
<graphml xmlns="http://graphml.graphdrawing.org/xmlns"
@@ -821,6 +829,44 @@ def testGraphGraphTool(self):
821829
self.assertEqual(g.vcount(), g2.vcount())
822830
self.assertEqual(sorted(g.get_edgelist()), sorted(g2.get_edgelist()))
823831

832+
@unittest.skipIf(PyGData is None, "test case depends on torch_geometric")
833+
def testGraphTorchGeometric(self):
834+
# Undirected
835+
g = Graph.Ring(10)
836+
g.vs["vattr"] = list(range(g.vcount()))
837+
g.es["eattr"] = list(range(len(g.es)))
838+
839+
# Go to torch geometric
840+
data_pyg = g.to_torch_geometric()
841+
842+
self.assertEqual(g.vcount(), data_pyg.num_nodes)
843+
self.assertEqual(
844+
sorted([list(x) for x in g.get_edgelist()]),
845+
sorted(data_pyg.edge_index.tolist()),
846+
)
847+
848+
# Test attributes
849+
self.assertEqual(g.vs["vattr"], data_pyg.x[:, 0].tolist())
850+
self.assertEqual(g.es["eattr"], data_pyg.edge_attr[:, 0].tolist())
851+
852+
# Directed
853+
g = Graph.Ring(10, directed=True)
854+
g.vs["vattr"] = list(range(g.vcount()))
855+
g.es["eattr"] = list(range(len(g.es)))
856+
857+
# Go to torch geometric
858+
data_pyg = g.to_torch_geometric()
859+
860+
self.assertEqual(g.vcount(), data_pyg.num_nodes)
861+
self.assertEqual(
862+
sorted([list(x) for x in g.get_edgelist()]),
863+
sorted(data_pyg.edge_index.tolist()),
864+
)
865+
866+
# Test attributes
867+
self.assertEqual(g.vs["vattr"], data_pyg.x[:, 0].tolist())
868+
self.assertEqual(g.es["eattr"], data_pyg.edge_attr[:, 0].tolist())
869+
824870
@unittest.skipIf(gt is None, "test case depends on graph-tool")
825871
def testMultigraphGraphTool(self):
826872
# Undirected

0 commit comments

Comments
 (0)