Skip to content

Commit 19f7e40

Browse files
committed
Merge branch 'martin-carrasco-feat_nc_rg'
2 parents adc6ff2 + 1a59dd7 commit 19f7e40

File tree

9 files changed

+727
-3
lines changed

9 files changed

+727
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
transform_name: "ElementwiseMean"
2+
transform_type: null
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
transform_type: 'lifting'
2+
transform_name: "NeighborhoodComplexLifting"
3+
preserve_edge_attr: False
4+
signed: True
5+
feature_lifting: ElementwiseMean
6+
complex_dim: 5

modules/transforms/data_transform.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
NodeFeaturesToFloat,
88
OneHotDegreeFeatures,
99
)
10-
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
10+
from modules.transforms.feature_liftings.feature_liftings import (
11+
ElementwiseMean,
12+
ProjectionSum,
13+
)
1114
from modules.transforms.liftings.graph2cell.cycle_lifting import (
1215
CellCycleLifting,
1316
)
@@ -59,6 +62,9 @@
5962
from modules.transforms.liftings.graph2simplicial.line_lifting import (
6063
SimplicialLineLifting,
6164
)
65+
from modules.transforms.liftings.graph2simplicial.neighborhood_complex_lifting import (
66+
NeighborhoodComplexLifting,
67+
)
6268
from modules.transforms.liftings.graph2simplicial.vietoris_rips_lifting import (
6369
SimplicialVietorisRipsLifting,
6470
)
@@ -106,6 +112,7 @@
106112
"SimplicialVietorisRipsLifting": SimplicialVietorisRipsLifting,
107113
"LatentCliqueLifting": LatentCliqueLifting,
108114
"SimplicialDnDLifting": SimplicialDnDLifting,
115+
"NeighborhoodComplexLifting": NeighborhoodComplexLifting,
109116
# Graph -> Cell Complex
110117
"CellCycleLifting": CellCycleLifting,
111118
"DiscreteConfigurationComplexLifting": DiscreteConfigurationComplexLifting,
@@ -130,6 +137,7 @@
130137
"CofaceCCLifting": CofaceCCLifting,
131138
# Feature Liftings
132139
"ProjectionSum": ProjectionSum,
140+
"ElementwiseMean": ElementwiseMean,
133141
# Data Manipulations
134142
"Identity": IdentityTransform,
135143
"NodeDegrees": NodeDegrees,

modules/transforms/feature_liftings/feature_liftings.py

+102
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torch.nn.functional as F
23
import torch_geometric
34

45

@@ -56,3 +57,104 @@ def forward(
5657
The lifted data.
5758
"""
5859
return self.lift_features(data)
60+
61+
62+
class ElementwiseMean(torch_geometric.transforms.BaseTransform):
63+
r"""Lifts r-cell features to r+1-cells by taking the mean of the lower
64+
dimensional features.
65+
66+
Parameters
67+
----------
68+
**kwargs : optional
69+
Additional arguments for the class.
70+
"""
71+
72+
def __init__(self, **kwargs):
73+
super().__init__()
74+
75+
def lift_features(
76+
self, data: torch_geometric.data.Data | dict
77+
) -> torch_geometric.data.Data | dict:
78+
r"""Projects r-cell features of a graph to r+1-cell structures using the incidence matrix.
79+
80+
Parameters
81+
----------
82+
data : torch_geometric.data.Data | dict
83+
The input data to be lifted.
84+
85+
Returns
86+
-------
87+
torch_geometric.data.Data | dict
88+
The lifted data."""
89+
90+
# Find the maximum dimension of the input data
91+
max_dim = max(
92+
[int(key.split("_")[-1]) for key in data if "x_idx" in key]
93+
)
94+
95+
# Create a list of all x_idx tensors
96+
x_idx_tensors = [data[f"x_idx_{i}"] for i in range(max_dim + 1)]
97+
98+
# Find the maximum sizes
99+
max_simplices = max(tensor.size(0) for tensor in x_idx_tensors)
100+
max_nodes = max(tensor.size(1) for tensor in x_idx_tensors)
101+
102+
# Pad tensors to have the same size
103+
padded_tensors = [
104+
F.pad(
105+
tensor,
106+
(
107+
0,
108+
max_nodes - tensor.size(1),
109+
0,
110+
max_simplices - tensor.size(0),
111+
),
112+
)
113+
for tensor in x_idx_tensors
114+
]
115+
116+
# Stack all x_idx tensors
117+
all_indices = torch.stack(padded_tensors)
118+
119+
# Create a mask for valid indices
120+
mask = all_indices != 0
121+
122+
# Replace 0s with a valid index (e.g., 0) to avoid indexing errors
123+
all_indices = all_indices.clamp(min=0)
124+
125+
# Get all embeddings at once
126+
all_embeddings = data["x_0"][all_indices]
127+
128+
# Apply mask to set padded embeddings to 0
129+
all_embeddings = all_embeddings * mask.unsqueeze(-1).float()
130+
131+
# Compute sum and count of non-zero elements
132+
embedding_sum = all_embeddings.sum(dim=2)
133+
count = mask.sum(dim=2).clamp(min=1) # Avoid division by zero
134+
135+
# Compute mean
136+
mean_embeddings = embedding_sum / count.unsqueeze(-1)
137+
138+
# Assign results back to data dictionary
139+
for i in range(1, max_dim + 1):
140+
original_size = x_idx_tensors[i].size(0)
141+
data[f"x_{i}"] = mean_embeddings[i, :original_size]
142+
143+
return data
144+
145+
def forward(
146+
self, data: torch_geometric.data.Data | dict
147+
) -> torch_geometric.data.Data | dict:
148+
r"""Applies the lifting to the input data.
149+
150+
Parameters
151+
----------
152+
data : torch_geometric.data.Data | dict
153+
The input data to be lifted.
154+
155+
Returns
156+
-------
157+
torch_geometric.data.Data | dict
158+
The lifted data.
159+
"""
160+
return self.lift_features(data)

modules/transforms/liftings/graph2simplicial/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _get_lifted_topology(
4747
list(simplicial_complex.get_simplex_attributes("features", 0).values())
4848
)
4949
# If new edges have been added during the lifting process, we discard the edge attributes
50-
if self.contains_edge_attr and simplicial_complex.shape[1] == (
50+
if self.preserve_edge_attr and simplicial_complex.shape[1] == (
5151
graph.number_of_edges()
5252
):
5353
lifted_topology["x_1"] = torch.stack(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import networkx as nx
2+
import torch
3+
from toponetx.classes import SimplicialComplex
4+
from torch_geometric.data import Data
5+
from torch_geometric.utils.convert import to_networkx
6+
7+
from modules.transforms.liftings.graph2simplicial.base import (
8+
Graph2SimplicialLifting,
9+
)
10+
11+
12+
class NeighborhoodComplexLifting(Graph2SimplicialLifting):
13+
"""Lifts graphs to a simplicial complex domain by identifying the neighborhood complex as k-simplices.
14+
The neighborhood complex of a node u is the set of nodes that share a neighbor with u.
15+
16+
"""
17+
18+
def __init__(self, **kwargs):
19+
super().__init__(**kwargs)
20+
21+
def lift_topology(self, data: Data) -> dict:
22+
graph: nx.Graph = to_networkx(data, to_undirected=True)
23+
simplicial_complex = SimplicialComplex(simplices=graph)
24+
25+
# For every node u
26+
for u in graph.nodes:
27+
neighbourhood_complex = set()
28+
neighbourhood_complex.add(u)
29+
# Check it's neighbours
30+
for v in graph.neighbors(u):
31+
# For every other node w != u ^ w != v
32+
for w in graph.nodes:
33+
# w == u
34+
if w == u:
35+
continue
36+
# w == v
37+
if w == v:
38+
continue
39+
40+
# w and u share v as it's neighbour
41+
if v in graph.neighbors(w):
42+
neighbourhood_complex.add(w)
43+
# Do not add 0-simplices
44+
if len(neighbourhood_complex) < 2:
45+
continue
46+
# Do not add i-simplices if the maximum dimension is lower
47+
if len(neighbourhood_complex) > self.complex_dim + 1:
48+
continue
49+
simplicial_complex.add_simplex(neighbourhood_complex)
50+
51+
feature_dict = {i: f for i, f in enumerate(data["x"])}
52+
53+
simplicial_complex.set_simplex_attributes(
54+
feature_dict, name="features"
55+
)
56+
57+
return self._get_lifted_topology(simplicial_complex, graph)
58+
59+
def _get_lifted_topology(
60+
self, simplicial_complex: SimplicialComplex, graph: nx.Graph
61+
) -> dict:
62+
data = super()._get_lifted_topology(simplicial_complex, graph)
63+
64+
for r in range(simplicial_complex.dim + 1):
65+
data[f"x_idx_{r}"] = torch.tensor(simplicial_complex.skeleton(r))
66+
67+
return data

modules/transforms/liftings/lifting.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
from modules.transforms.data_manipulations.manipulations import (
88
IdentityTransform,
99
)
10-
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
10+
from modules.transforms.feature_liftings.feature_liftings import (
11+
ElementwiseMean,
12+
ProjectionSum,
13+
)
1114

1215
# Implemented Feature Liftings
1316
FEATURE_LIFTINGS = {
1417
"ProjectionSum": ProjectionSum,
18+
"ElementwiseMean": ElementwiseMean,
1519
None: IdentityTransform,
1620
}
1721

0 commit comments

Comments
 (0)