Skip to content

Commit 4675350

Browse files
committed
Merge branch 'theo-long-configuration-complex'
2 parents 776686b + cefd9e4 commit 4675350

File tree

9 files changed

+1616
-1
lines changed

9 files changed

+1616
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
data_domain: graph
2+
data_type: toy_dataset
3+
data_name: simple_configuration_graphs
4+
data_dir: datasets/${data_domain}/${data_type}
5+
6+
# Dataset parameters
7+
num_features: 1
8+
num_classes: 2
9+
task: classification
10+
loss_type: cross_entropy
11+
monitor_metric: accuracy
12+
task_level: graph
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
transform_type: 'lifting'
2+
transform_name: "DiscreteConfigurationComplexLifting"
3+
k: 2
4+
feature_aggregation: "concat"
5+
preserve_edge_attr: True
6+
feature_lifting: ProjectionSum

modules/data/load/loaders.py

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
load_random_points,
3232
load_random_shape_point_cloud,
3333
load_senate_committee,
34+
load_simple_configuration_graphs,
3435
load_simplicial_dataset,
3536
)
3637

@@ -153,6 +154,9 @@ def load(self) -> torch_geometric.data.Dataset:
153154
elif self.parameters.data_name in ["manual_rings"]:
154155
data = load_manual_mol()
155156
dataset = CustomDataset([data], self.data_dir)
157+
elif self.parameters.data_name in ["simple_configuration_graphs"]:
158+
data = load_simple_configuration_graphs()
159+
dataset = CustomDataset([*data], self.data_dir)
156160

157161
else:
158162
raise NotImplementedError(

modules/data/utils/utils.py

+39
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,45 @@ def load_senate_committee(
973973
)
974974

975975

976+
def _build_configuration_graph(edges, label):
    """Build one toy PyG graph from an edge list.

    Parameters
    ----------
    edges : list[tuple[int, int]]
        Undirected edges. Nodes must be labelled ``0..n-1`` because node ``i``
        receives the single feature value ``i``.
    label : int
        Graph-level class label stored in ``y``.

    Returns
    -------
    torch_geometric.data.Data
        Graph with node features ``0, 1, ..., n-1`` and edge features
        ``-1, -2, ..., -m`` (in the order networkx reports the edges).
    """
    graph = nx.Graph()
    graph.add_edges_from(edges)

    # Derive the counts from the graph itself so that x, edge_attr and
    # num_nodes can never disagree with the edge list.
    num_nodes = graph.number_of_nodes()
    num_edges = graph.number_of_edges()

    return torch_geometric.data.Data(
        x=torch.arange(num_nodes).unsqueeze(1).float(),
        y=torch.tensor([label]),
        edge_index=torch.Tensor(list(graph.edges())).T.long(),
        num_nodes=num_nodes,
        edge_attr=-torch.arange(1, num_edges + 1).unsqueeze(1).float(),
    )


def load_simple_configuration_graphs():
    """Generate small graphs to illustrate the discrete configuration complex.

    Returns
    -------
    tuple[torch_geometric.data.Data, ...]
        ``(x_data, y_data, g_data)``: an X-shaped star with four leaves
        (5 nodes, label 0), a Y-shaped star with three leaves (4 nodes,
        label 0) and a g-shaped graph made of a triangle with a pendant
        edge (4 nodes, label 1).
    """
    # Y shaped graph
    y_data = _build_configuration_graph([(0, 1), (0, 2), (0, 3)], label=0)

    # X shaped graph (5 nodes: the centre plus four leaves)
    x_data = _build_configuration_graph(
        [(0, 1), (0, 2), (0, 3), (0, 4)], label=0
    )

    # g shaped graph
    g_data = _build_configuration_graph(
        [(0, 1), (1, 2), (2, 0), (2, 3)], label=1
    )

    return x_data, y_data, g_data
1013+
1014+
9761015
def get_Planetoid_pyg(cfg):
9771016
r"""Loads Planetoid graph datasets from torch_geometric.
9781017

modules/transforms/data_transform.py

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from modules.transforms.liftings.graph2cell.cycle_lifting import (
1212
CellCycleLifting,
1313
)
14+
from modules.transforms.liftings.graph2cell.discrete_configuration_complex_lifting import (
15+
DiscreteConfigurationComplexLifting,
16+
)
1417
from modules.transforms.liftings.graph2combinatorial.curve_lifting import (
1518
CurveLifting,
1619
)
@@ -95,6 +98,7 @@
9598
"LatentCliqueLifting": LatentCliqueLifting,
9699
# Graph -> Cell Complex
97100
"CellCycleLifting": CellCycleLifting,
101+
"DiscreteConfigurationComplexLifting": DiscreteConfigurationComplexLifting,
98102
# Graph -> Combinatorial Complex
99103
"CombinatorialRingCloseAtomsLifting": CombinatorialRingCloseAtomsLifting,
100104
"CurveLifting": CurveLifting,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
from itertools import permutations
2+
from typing import ClassVar
3+
4+
import networkx as nx
5+
import torch
6+
import torch_geometric
7+
from toponetx.classes import CellComplex
8+
9+
from modules.transforms.liftings.graph2cell.base import Graph2CellLifting
10+
from modules.utils.utils import edge_cycle_to_vertex_cycle
11+
12+
Vertex = int
13+
Edge = tuple[Vertex, Vertex]
14+
ConfigurationTuple = tuple[Vertex | Edge]
15+
16+
17+
class DiscreteConfigurationComplexLifting(Graph2CellLifting):
    r"""Lifts graphs to cell complexes by generating the k-th *discrete configuration complex* $D_k(G)$ of the graph.

    This is a cube complex, which is similar to a simplicial complex except each n-dimensional cell is homeomorphic to an n-dimensional cube rather than an n-dimensional simplex.

    The discrete configuration complex of order k consists of all tuples of k unique edges or vertices of $G$, with the additional constraint that if an edge e is in a cell, then neither of the endpoints of e are in the cell. For examples of different graphs and their configuration complexes, see the tutorial.

    Note that since TopoNetX only supports cell complexes of dimension 2, if you generate a configuration complex of order k > 2 this will only produce the 2-skeleton.

    Parameters
    ----------
    k : int
        The order of the configuration complex, i.e. the number of 'agents' in a single configuration. Must be at least 1.
    preserve_edge_attr : bool, optional
        Whether to preserve edge attributes. Default is True.
    feature_aggregation : str, optional
        For a k-agent configuration, the method by which the features are aggregated. Can be "mean", "sum", or "concat". Default is "concat".
    **kwargs : optional
        Additional arguments for the class.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1 or ``feature_aggregation`` is not one of the supported values.
    """

    def __init__(
        self,
        k: int,
        preserve_edge_attr: bool = True,
        feature_aggregation="concat",
        **kwargs,
    ):
        # A configuration needs at least one agent; fail here with a clear
        # message instead of with a KeyError deep inside lift_topology.
        if k < 1:
            raise ValueError("k must be a positive integer")
        self.k = k
        self.complex_dim = 2
        if feature_aggregation not in ["mean", "sum", "concat"]:
            raise ValueError(
                "feature_aggregation must be one of 'mean', 'sum', 'concat'"
            )
        self.feature_aggregation = feature_aggregation
        super().__init__(preserve_edge_attr=preserve_edge_attr, **kwargs)

    def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
        r"""Applies the full lifting (topology + features) to the input data.

        Parameters
        ----------
        data : torch_geometric.data.Data
            The input data to be lifted.

        Returns
        -------
        torch_geometric.data.Data
            The lifted data. Only the label ``y`` is carried over from the input.
        """
        # Unlike the base class, we do not pass the initial data to the final data
        # This is because the configuration complex has a completely different 1-skeleton from the original graph
        lifted_topology = self.lift_topology(data)
        lifted_topology = self.feature_lifting(lifted_topology)
        return torch_geometric.data.Data(y=data.y, **lifted_topology)

    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        r"""Generates the cubical complex of discrete graph configurations.

        Parameters
        ----------
        data : torch_geometric.data.Data
            The input data to be lifted.

        Returns
        -------
        dict
            The lifted topology.

        Raises
        ------
        ValueError
            If the graph built from ``data`` is directed.
        """
        G = self._generate_graph_from_data(data)
        if G.is_directed():
            raise ValueError("Directed Graphs are not supported.")

        # NOTE(review): self.contains_edge_attr is not set in this class;
        # presumably the base class sets it while building G - confirm.
        Configuration = generate_configuration_class(
            G, self.feature_aggregation, self.contains_edge_attr
        )

        # The vertices of the configuration complex are just tuples of k vertices.
        # Every higher-dimensional configuration is reached from these by
        # generate_upwards_neighbors and recorded in Configuration.instances.
        for dim_0_configuration_tuple in permutations(G, self.k):
            configuration = Configuration(dim_0_configuration_tuple)
            configuration.generate_upwards_neighbors()

        # Always provide buckets for dimensions 0, 1 and 2, even when k < 2,
        # so the loops below never hit a missing key (k == 1 has no 2-cells).
        cells = {dim: [] for dim in range(max(self.k, 2) + 1)}
        for conf in Configuration.instances.values():
            features = conf.features()
            attrs = {"features": features} if features is not None else {}
            cell = (conf.contents, attrs)
            cells[conf.dim].append(cell)

        # TopoNetX only supports cells of dimension <= 2, so any bucket of
        # dimension 3 or higher is collected above but deliberately ignored.
        cc = CellComplex()
        for node, attrs in cells[0]:
            cc.add_node(node, **attrs)
        for edge, attrs in cells[1]:
            cc.add_edge(edge[0], edge[1], **attrs)
        for cell, attrs in cells[2]:
            # A 2-cell is stored as the list of its boundary edges; convert
            # it to a vertex cycle before handing it to add_cell.
            cell_vertices = edge_cycle_to_vertex_cycle(cell)
            cc.add_cell(cell_vertices, rank=2, **attrs)

        return self._get_lifted_topology(cc, G)
115+
116+
117+
def generate_configuration_class(
    graph: nx.Graph, feature_aggregation: str, edge_features: bool
):
    """Class factory for the Configuration class.

    A new class object is created on every call, so the ``instances``
    registry it carries starts empty for each call.

    Parameters
    ----------
    graph : nx.Graph
        The graph on which the configurations are defined. Node and edge
        features are read from the ``"features"`` attribute.
    feature_aggregation : str
        How the features of the agents of one configuration are combined:
        "mean", "sum" or "concat".
    edge_features : bool
        Whether edge agents contribute features.

    Returns
    -------
    type
        The ``Configuration`` class bound to ``graph``.
    """

    class Configuration:
        """Represents a single legal configuration of k agents on a graph G.

        A legal configuration is a tuple of k edges and vertices of G where all the vertices and endpoints are **distinct**, i.e. no two edges sharing an endpoint can simultaneously be in the configuration, and no adjacent (edge, vertex) pair can be contained in the configuration. Each configuration corresponds to a cell, and the number of edges in the configuration is the dimension.

        Parameters
        ----------
        configuration_tuple : ConfigurationTuple
            The agents of the configuration, each one a vertex or an edge. Constructing the class twice with the same tuple returns the same object.
        """

        # Registry of every configuration created so far, keyed by its tuple.
        instances: ClassVar[dict[ConfigurationTuple, "Configuration"]] = {}

        def __new__(cls, configuration_tuple: ConfigurationTuple):
            # Ensure that a configuration tuple corresponds to a *unique* configuration object
            key = configuration_tuple
            if key not in cls.instances:
                cls.instances[key] = super().__new__(cls)

            return cls.instances[key]

        def __init__(self, configuration_tuple: ConfigurationTuple) -> None:
            # __init__ runs on every Configuration(...) call, including the
            # ones that returned a cached object; keep the existing state then.
            if hasattr(self, "initialized"):
                return

            self.initialized = True
            self.configuration_tuple = configuration_tuple
            # All graph vertices occupied by this configuration (vertex
            # agents plus both endpoints of every edge agent).
            self.neighborhood = set()
            self.dim = 0
            for agent in configuration_tuple:
                if isinstance(agent, Vertex):
                    self.neighborhood.add(agent)
                else:
                    self.neighborhood.update(set(agent))
                    self.dim += 1

            # A 0-cell is identified by its own tuple; a higher cell is
            # described by the contents of the lower cells on its boundary,
            # which are appended as those lower cells discover it.
            if self.dim == 0:
                self.contents = configuration_tuple
            else:
                self.contents = []

            self._upwards_neighbors_generated = False

        def features(self):
            """Generate the features for the configuration by combining the edge and vertex features.

            Returns
            -------
            torch.Tensor or None
                The aggregated features, or None when no agent contributes any.

            Raises
            ------
            ValueError
                If "mean"/"sum" aggregation fails, or the aggregation mode is unknown.
            """
            features = []
            for agent in self.configuration_tuple:
                if isinstance(agent, Vertex):
                    features.append(graph.nodes[agent]["features"])
                elif edge_features:
                    features.append(graph.edges[agent]["features"])

            if not features:
                return None

            if feature_aggregation in ("mean", "sum"):
                try:
                    stacked = torch.stack(features, dim=0)
                    if feature_aggregation == "mean":
                        return stacked.mean(dim=0)
                    return stacked.sum(dim=0)
                except Exception as e:
                    raise ValueError(
                        f"Failed to {feature_aggregation} feature tensors. This may be because edge features and vertex features have different shapes. If this is the case, use feature_aggregation='concat', or disable edge features."
                    ) from e

            if feature_aggregation == "concat":
                # torch.cat is the canonical name; torch.concatenate is an alias of it.
                return torch.cat(features, dim=-1)

            raise ValueError(
                f"Unrecognized feature_aggregation: {feature_aggregation}"
            )

        def generate_upwards_neighbors(self):
            """For the configuration self of dimension d, generate the configurations of dimension d+1 containing it."""
            # Run once per configuration so that each (lower cell, upper cell)
            # pair contributes exactly one entry to the upper cell's contents.
            if self._upwards_neighbors_generated:
                return
            self._upwards_neighbors_generated = True
            for i, agent in enumerate(self.configuration_tuple):
                if isinstance(agent, Vertex):
                    for neighbor in graph[agent]:
                        self._generate_single_neighbor(i, agent, neighbor)

        def _generate_single_neighbor(
            self, index: int, vertex_agent: int, neighbor: int
        ):
            """Generate a configuration containing self by moving an agent from a vertex onto an edge."""
            # If adding the edge (vertex_agent, neighbor) would produce an illegal configuration, ignore it
            if neighbor in self.neighborhood:
                return

            # We always orient edges (min -> max) to maintain uniqueness of configuration tuples
            new_edge = (min(vertex_agent, neighbor), max(vertex_agent, neighbor))

            # Remove the vertex at index and replace it with new edge
            new_configuration_tuple = (
                *self.configuration_tuple[:index],
                new_edge,
                *self.configuration_tuple[index + 1 :],
            )
            new_configuration = Configuration(new_configuration_tuple)
            # self.contents is appended by reference: when self is itself a
            # higher cell, entries added to it later remain visible from
            # new_configuration.contents.
            new_configuration.contents.append(self.contents)
            new_configuration.generate_upwards_neighbors()

    return Configuration

modules/utils/utils.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,11 @@ def sort_vertices_ccw(vertices):
316316
incidence = data.incidence_hyperedges.coalesce()
317317

318318
# Collect vertices
319-
vertices = [i for i in range(data.x.shape[0])]
319+
if hasattr(data, "x") and data.x is not None:
320+
vertices = [i for i in range(data.x.shape[0])]
321+
else:
322+
vertices = [i for i in range(data["x_0"].shape[0])]
323+
320324
# Hyperedges
321325
if max_order == 0:
322326
n_vertices = len(vertices)
@@ -679,3 +683,8 @@ def plot_pointcloud_voronoi(
679683
ax.scatter(*points.T, s=1.0, c=color, cmap=cm.flag)
680684
ax.view_init(elev=10.0, azim=azim, roll=roll)
681685
plt.show()
686+
687+
688+
def edge_cycle_to_vertex_cycle(edge_cycle: list[list | tuple]):
    """Convert a cycle given as a list of edges into a list of its vertices.

    The edges are loaded into an undirected networkx graph, ``nx.find_cycle``
    is run on it, and the first endpoint of every edge it returns is
    collected. For example ``[(1, 2), (2, 3), (3, 1)]`` describes the cycle
    through vertices 1, 2 and 3 (start vertex and direction are whatever
    networkx picks).
    """
    cycle_graph = nx.Graph(edge_cycle)
    traversal = nx.find_cycle(cycle_graph)
    return [step[0] for step in traversal]

0 commit comments

Comments
 (0)