Skip to content

Commit 1a59dd7

Browse files
committedFeb 20, 2025
Merge branch 'feat_nc_rg' of github.com:martin-carrasco/challenge-icml-2024 into martin-carrasco-feat_nc_rg
2 parents adc6ff2 + b0bad78 commit 1a59dd7

File tree

9 files changed

+727
-3
lines changed

9 files changed

+727
-3
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
transform_name: "ElementwiseMean"
2+
transform_type: null
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
transform_type: 'lifting'
2+
transform_name: "NeighborhoodComplexLifting"
3+
preserve_edge_attr: False
4+
signed: True
5+
feature_lifting: ElementwiseMean
6+
complex_dim: 5

‎modules/transforms/data_transform.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
NodeFeaturesToFloat,
88
OneHotDegreeFeatures,
99
)
10-
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
10+
from modules.transforms.feature_liftings.feature_liftings import (
11+
ElementwiseMean,
12+
ProjectionSum,
13+
)
1114
from modules.transforms.liftings.graph2cell.cycle_lifting import (
1215
CellCycleLifting,
1316
)
@@ -59,6 +62,9 @@
5962
from modules.transforms.liftings.graph2simplicial.line_lifting import (
6063
SimplicialLineLifting,
6164
)
65+
from modules.transforms.liftings.graph2simplicial.neighborhood_complex_lifting import (
66+
NeighborhoodComplexLifting,
67+
)
6268
from modules.transforms.liftings.graph2simplicial.vietoris_rips_lifting import (
6369
SimplicialVietorisRipsLifting,
6470
)
@@ -106,6 +112,7 @@
106112
"SimplicialVietorisRipsLifting": SimplicialVietorisRipsLifting,
107113
"LatentCliqueLifting": LatentCliqueLifting,
108114
"SimplicialDnDLifting": SimplicialDnDLifting,
115+
"NeighborhoodComplexLifting": NeighborhoodComplexLifting,
109116
# Graph -> Cell Complex
110117
"CellCycleLifting": CellCycleLifting,
111118
"DiscreteConfigurationComplexLifting": DiscreteConfigurationComplexLifting,
@@ -130,6 +137,7 @@
130137
"CofaceCCLifting": CofaceCCLifting,
131138
# Feature Liftings
132139
"ProjectionSum": ProjectionSum,
140+
"ElementwiseMean": ElementwiseMean,
133141
# Data Manipulations
134142
"Identity": IdentityTransform,
135143
"NodeDegrees": NodeDegrees,

‎modules/transforms/feature_liftings/feature_liftings.py

+102
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torch.nn.functional as F
23
import torch_geometric
34

45

@@ -56,3 +57,104 @@ def forward(
5657
The lifted data.
5758
"""
5859
return self.lift_features(data)
60+
61+
62+
class ElementwiseMean(torch_geometric.transforms.BaseTransform):
63+
r"""Lifts r-cell features to r+1-cells by taking the mean of the lower
64+
dimensional features.
65+
66+
Parameters
67+
----------
68+
**kwargs : optional
69+
Additional arguments for the class.
70+
"""
71+
72+
def __init__(self, **kwargs):
73+
super().__init__()
74+
75+
def lift_features(
76+
self, data: torch_geometric.data.Data | dict
77+
) -> torch_geometric.data.Data | dict:
78+
r"""Projects r-cell features of a graph to r+1-cell structures using the incidence matrix.
79+
80+
Parameters
81+
----------
82+
data : torch_geometric.data.Data | dict
83+
The input data to be lifted.
84+
85+
Returns
86+
-------
87+
torch_geometric.data.Data | dict
88+
The lifted data."""
89+
90+
# Find the maximum dimension of the input data
91+
max_dim = max(
92+
[int(key.split("_")[-1]) for key in data if "x_idx" in key]
93+
)
94+
95+
# Create a list of all x_idx tensors
96+
x_idx_tensors = [data[f"x_idx_{i}"] for i in range(max_dim + 1)]
97+
98+
# Find the maximum sizes
99+
max_simplices = max(tensor.size(0) for tensor in x_idx_tensors)
100+
max_nodes = max(tensor.size(1) for tensor in x_idx_tensors)
101+
102+
# Pad tensors to have the same size
103+
padded_tensors = [
104+
F.pad(
105+
tensor,
106+
(
107+
0,
108+
max_nodes - tensor.size(1),
109+
0,
110+
max_simplices - tensor.size(0),
111+
),
112+
)
113+
for tensor in x_idx_tensors
114+
]
115+
116+
# Stack all x_idx tensors
117+
all_indices = torch.stack(padded_tensors)
118+
119+
# Create a mask for valid indices
120+
mask = all_indices != 0
121+
122+
# Replace 0s with a valid index (e.g., 0) to avoid indexing errors
123+
all_indices = all_indices.clamp(min=0)
124+
125+
# Get all embeddings at once
126+
all_embeddings = data["x_0"][all_indices]
127+
128+
# Apply mask to set padded embeddings to 0
129+
all_embeddings = all_embeddings * mask.unsqueeze(-1).float()
130+
131+
# Compute sum and count of non-zero elements
132+
embedding_sum = all_embeddings.sum(dim=2)
133+
count = mask.sum(dim=2).clamp(min=1) # Avoid division by zero
134+
135+
# Compute mean
136+
mean_embeddings = embedding_sum / count.unsqueeze(-1)
137+
138+
# Assign results back to data dictionary
139+
for i in range(1, max_dim + 1):
140+
original_size = x_idx_tensors[i].size(0)
141+
data[f"x_{i}"] = mean_embeddings[i, :original_size]
142+
143+
return data
144+
145+
def forward(
146+
self, data: torch_geometric.data.Data | dict
147+
) -> torch_geometric.data.Data | dict:
148+
r"""Applies the lifting to the input data.
149+
150+
Parameters
151+
----------
152+
data : torch_geometric.data.Data | dict
153+
The input data to be lifted.
154+
155+
Returns
156+
-------
157+
torch_geometric.data.Data | dict
158+
The lifted data.
159+
"""
160+
return self.lift_features(data)

‎modules/transforms/liftings/graph2simplicial/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _get_lifted_topology(
4747
list(simplicial_complex.get_simplex_attributes("features", 0).values())
4848
)
4949
# If new edges have been added during the lifting process, we discard the edge attributes
50-
if self.contains_edge_attr and simplicial_complex.shape[1] == (
50+
if self.preserve_edge_attr and simplicial_complex.shape[1] == (
5151
graph.number_of_edges()
5252
):
5353
lifted_topology["x_1"] = torch.stack(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import networkx as nx
2+
import torch
3+
from toponetx.classes import SimplicialComplex
4+
from torch_geometric.data import Data
5+
from torch_geometric.utils.convert import to_networkx
6+
7+
from modules.transforms.liftings.graph2simplicial.base import (
8+
Graph2SimplicialLifting,
9+
)
10+
11+
12+
class NeighborhoodComplexLifting(Graph2SimplicialLifting):
13+
"""Lifts graphs to a simplicial complex domain by identifying the neighborhood complex as k-simplices.
14+
The neighborhood complex of a node u is the set of nodes that share a neighbor with u.
15+
16+
"""
17+
18+
def __init__(self, **kwargs):
19+
super().__init__(**kwargs)
20+
21+
def lift_topology(self, data: Data) -> dict:
22+
graph: nx.Graph = to_networkx(data, to_undirected=True)
23+
simplicial_complex = SimplicialComplex(simplices=graph)
24+
25+
# For every node u
26+
for u in graph.nodes:
27+
neighbourhood_complex = set()
28+
neighbourhood_complex.add(u)
29+
# Check it's neighbours
30+
for v in graph.neighbors(u):
31+
# For every other node w != u ^ w != v
32+
for w in graph.nodes:
33+
# w == u
34+
if w == u:
35+
continue
36+
# w == v
37+
if w == v:
38+
continue
39+
40+
# w and u share v as it's neighbour
41+
if v in graph.neighbors(w):
42+
neighbourhood_complex.add(w)
43+
# Do not add 0-simplices
44+
if len(neighbourhood_complex) < 2:
45+
continue
46+
# Do not add i-simplices if the maximum dimension is lower
47+
if len(neighbourhood_complex) > self.complex_dim + 1:
48+
continue
49+
simplicial_complex.add_simplex(neighbourhood_complex)
50+
51+
feature_dict = {i: f for i, f in enumerate(data["x"])}
52+
53+
simplicial_complex.set_simplex_attributes(
54+
feature_dict, name="features"
55+
)
56+
57+
return self._get_lifted_topology(simplicial_complex, graph)
58+
59+
def _get_lifted_topology(
60+
self, simplicial_complex: SimplicialComplex, graph: nx.Graph
61+
) -> dict:
62+
data = super()._get_lifted_topology(simplicial_complex, graph)
63+
64+
for r in range(simplicial_complex.dim + 1):
65+
data[f"x_idx_{r}"] = torch.tensor(simplicial_complex.skeleton(r))
66+
67+
return data

‎modules/transforms/liftings/lifting.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
from modules.transforms.data_manipulations.manipulations import (
88
IdentityTransform,
99
)
10-
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
10+
from modules.transforms.feature_liftings.feature_liftings import (
11+
ElementwiseMean,
12+
ProjectionSum,
13+
)
1114

1215
# Implemented Feature Liftings
1316
FEATURE_LIFTINGS = {
1417
"ProjectionSum": ProjectionSum,
18+
"ElementwiseMean": ElementwiseMean,
1519
None: IdentityTransform,
1620
}
1721

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
2+
import networkx as nx
3+
import torch
4+
from torch_geometric.utils.convert import from_networkx
5+
6+
from modules.data.utils.utils import load_manual_graph
7+
from modules.transforms.liftings.graph2simplicial.neighborhood_complex_lifting import (
8+
NeighborhoodComplexLifting,
9+
)
10+
11+
12+
class TestNeighborhoodComplexLifting:
13+
"""Test the NeighborhoodComplexLifting class."""
14+
15+
def setup_method(self):
16+
# Load the graph
17+
self.data = load_manual_graph()
18+
19+
# Initialize the NeighborhoodComplexLifting class for dim=3
20+
self.lifting_signed = NeighborhoodComplexLifting(complex_dim=3, signed=True)
21+
self.lifting_unsigned = NeighborhoodComplexLifting(complex_dim=3, signed=False)
22+
self.lifting_high = NeighborhoodComplexLifting(complex_dim=7, signed=False)
23+
24+
# Intialize an empty graph for testing purpouses
25+
self.empty_graph = nx.empty_graph(10)
26+
self.empty_data = from_networkx(self.empty_graph)
27+
self.empty_data["x"] = torch.rand((10, 10))
28+
29+
# Intialize a start graph for testing
30+
self.star_graph = nx.star_graph(5)
31+
self.star_data = from_networkx(self.star_graph)
32+
self.star_data["x"] = torch.rand((6, 1))
33+
34+
# Intialize a random graph for testing purpouses
35+
self.random_graph = nx.fast_gnp_random_graph(5, 0.5)
36+
self.random_data = from_networkx(self.random_graph)
37+
self.random_data["x"] = torch.rand((5, 1))
38+
39+
40+
def has_neighbour(self, simplex_points: list[set]) -> tuple[bool, set[int]]:
41+
""" Verifies that the maximal simplices
42+
of Data representation of a simplicial complex
43+
share a neighbour.
44+
"""
45+
for simplex_point_a in simplex_points:
46+
for simplex_point_b in simplex_points:
47+
# Same point
48+
if simplex_point_a == simplex_point_b:
49+
continue
50+
# Search all nodes to check if they are c such that a and b share c as a neighbour
51+
for node in self.random_graph.nodes:
52+
# They share a neighbour
53+
if self.random_graph.has_edge(simplex_point_a.item(), node) and self.random_graph.has_edge(simplex_point_b.item(), node):
54+
return True
55+
return False
56+
57+
def test_lift_topology_random_graph(self):
58+
""" Verifies that the lifting procedure works on
59+
a random graph, that is, checks that the simplices
60+
generated share a neighbour.
61+
"""
62+
lifted_data = self.lifting_high.forward(self.random_data)
63+
# For each set of simplices
64+
r = max(int(key.split("_")[-1]) for key in list(lifted_data.keys()) if "x_idx_" in key)
65+
idx_str = f"x_idx_{r}"
66+
67+
# Go over each (max_dim)-simplex
68+
for simplex_points in lifted_data[idx_str]:
69+
share_neighbour = self.has_neighbour(simplex_points)
70+
assert share_neighbour, f"The simplex {simplex_points} does not have a common neighbour with all the nodes."
71+
72+
def test_lift_topology_star_graph(self):
73+
""" Verifies that the lifting procedure works on
74+
a small star graph, that is, checks that the simplices
75+
generated share a neighbour.
76+
"""
77+
lifted_data = self.lifting_high.forward(self.star_data)
78+
# For each set of simplices
79+
r = max(int(key.split("_")[-1]) for key in list(lifted_data.keys()) if "x_idx_" in key)
80+
idx_str = f"x_idx_{r}"
81+
82+
# Go over each (max_dim)-simplex
83+
for simplex_points in lifted_data[idx_str]:
84+
share_neighbour = self.has_neighbour(simplex_points)
85+
assert share_neighbour, f"The simplex {simplex_points} does not have a common neighbour with all the nodes."
86+
87+
88+
89+
def test_lift_topology_empty_graph(self):
90+
""" Test the lift_topology method with an empty graph.
91+
"""
92+
93+
lifted_data_signed = self.lifting_signed.forward(self.empty_data)
94+
95+
assert lifted_data_signed.incidence_1.shape[1] == 0, "Something is wrong with signed incidence_1 (nodes to edges)."
96+
97+
assert lifted_data_signed.incidence_2.shape[1] == 0, "Something is wrong with signed incidence_2 (edges to triangles)."
98+
99+
def test_lift_topology(self):
100+
"""Test the lift_topology method."""
101+
102+
# Test the lift_topology method
103+
lifted_data_signed = self.lifting_signed.forward(self.data.clone())
104+
lifted_data_unsigned = self.lifting_unsigned.forward(self.data.clone())
105+
106+
expected_incidence_1 = torch.tensor(
107+
[
108+
[-1., -1., -1., -1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
109+
[ 1., 0., 0., 0., -1., -1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
110+
[ 0., 1., 0., 0., 1., 0., -1., -1., -1., -1., -1., 0., 0., 0., 0.],
111+
[ 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., -1., 0., 0., 0.],
112+
[ 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
113+
[ 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., -1., -1., 0.],
114+
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., -1.],
115+
[ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.]
116+
]
117+
)
118+
assert (
119+
abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense()
120+
).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)."
121+
assert (
122+
expected_incidence_1 == lifted_data_signed.incidence_1.to_dense()
123+
).all(), "Something is wrong with signed incidence_1 (nodes to edges)."
124+
125+
expected_incidence_2 = torch.tensor(
126+
[
127+
[ 0.],
128+
[ 0.],
129+
[ 0.],
130+
[ 0.],
131+
[ 0.],
132+
[ 0.],
133+
[ 0.],
134+
[ 0.],
135+
[ 0.],
136+
[ 1.],
137+
[-1.],
138+
[ 0.],
139+
[ 0.],
140+
[ 0.],
141+
[ 1.]
142+
]
143+
)
144+
145+
assert (
146+
abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense()
147+
).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)."
148+
assert (
149+
expected_incidence_2 == lifted_data_signed.incidence_2.to_dense()
150+
).all(), "Something is wrong with signed incidence_2 (edges to triangles)."

‎tutorials/graph2simplicial/neighborhood_complex_lifting.ipynb

+385
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)
Please sign in to comment.