Skip to content

Commit 74f1281

Browse files
committed
Merge branch 'martin-carrasco-feat_rsc'
2 parents e8a0c11 + e261644 commit 74f1281

File tree

5 files changed

+541
-0
lines changed

5 files changed

+541
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
transform_type: 'lifting'
2+
transform_name: "RandomFlagComplexLifting"
3+
complex_dim: 3
4+
alpha: 1.5
5+
steps: 10
6+
signed: True
7+
feature_lifting: ProjectionSum

modules/transforms/data_transform.py

+4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@
8080
from modules.transforms.liftings.pointcloud2simplicial.delaunay_lifting import (
8181
DelaunayLifting,
8282
)
83+
from modules.transforms.liftings.pointcloud2simplicial.random_flag_complex import (
84+
RandomFlagComplexLifting,
85+
)
8386

8487
TRANSFORMS = {
8588
# Graph -> Hypergraph
@@ -107,6 +110,7 @@
107110
"AlphaComplexLifting": AlphaComplexLifting,
108111
# Point-cloud -> Simplicial Complex
109112
"DelaunayLifting": DelaunayLifting,
113+
"RandomFlagComplexLifting": RandomFlagComplexLifting,
110114
# Pointcloud -> Hypergraph
111115
"VoronoiLifting": VoronoiLifting,
112116
"MoGMSTLifting": MoGMSTLifting,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from itertools import combinations
2+
3+
import gudhi
4+
import gudhi.simplex_tree
5+
import networkx as nx
6+
import numpy as np
7+
import torch
8+
from toponetx.classes import SimplicialComplex
9+
from torch_geometric.data import Data
10+
11+
from modules.data.utils.utils import get_complex_connectivity
12+
from modules.transforms.liftings.pointcloud2simplicial.base import (
13+
PointCloud2SimplicialLifting,
14+
)
15+
16+
17+
class RandomFlagComplexLifting(PointCloud2SimplicialLifting):
18+
""" Lifting of pointclouds to simplicial complexes using the Random Flag Complex construction.
19+
"""
20+
def __init__(self, steps, alpha: float | None = None, p: float | None = None, **kwargs):
21+
self.alpha = alpha
22+
self.steps = steps
23+
self.p = p
24+
super().__init__(**kwargs)
25+
26+
def lift_topology(self, data: Data) -> dict:
27+
# Get the number of points and generate an empty graph
28+
n = data["x"].size(0)
29+
if self.p is None:
30+
self.p = np.power(n, -self.alpha)
31+
self.p = float(self.p)
32+
33+
adj_mat = np.zeros((n, n))
34+
indices = np.tril_indices(n)
35+
36+
st = gudhi.SimplexTree()
37+
38+
generator = np.random.default_rng()
39+
# For each step, sample from random binomial distribution
40+
# for each edge appearign
41+
for _ in range(self.steps):
42+
number_of_edges = n*(n+1)//2
43+
prob = generator.binomial(1, self.p, size=number_of_edges)
44+
print(prob)
45+
tmp_mat = np.zeros((n, n))
46+
tmp_mat[indices] = prob
47+
np.logical_or(adj_mat, tmp_mat, out=adj_mat)
48+
np.fill_diagonal(adj_mat, 0)
49+
50+
# Insert all vertices
51+
for i in range(n):
52+
st.insert([i])
53+
54+
graph: nx.Graph = nx.from_numpy_matrix(adj_mat).to_undirected()
55+
56+
# Insert all edges
57+
for v, u in graph.edges:
58+
st.insert([v, u])
59+
60+
simplicial_complex = SimplicialComplex(graph)
61+
62+
# Add features to the vertices
63+
feats = {
64+
i: f
65+
for i, f in enumerate(data["x"])
66+
}
67+
68+
simplicial_complex.set_simplex_attributes(feats, name="features")
69+
70+
# Find the cliques up to the maximum dimension specified
71+
cliques = nx.find_cliques(graph)
72+
simplices = [set() for _ in range(2, self.complex_dim + 1)]
73+
74+
for clique in cliques:
75+
for i in range(2, self.complex_dim + 1):
76+
for c in combinations(clique, i + 1):
77+
simplices[i - 2].add(tuple(c))
78+
79+
80+
# Add the k-tuples as simplices
81+
for set_k_simplices in simplices:
82+
for k_simplex in set_k_simplices:
83+
st.insert(k_simplex)
84+
simplicial_complex.add_simplices_from(list(set_k_simplices))
85+
86+
return self._get_lifted_topology(simplicial_complex, st)
87+
88+
def _get_lifted_topology(self, simplicial_complex: SimplicialComplex, st: gudhi.SimplexTree) -> dict:
89+
90+
# Connectivity of the complex
91+
lifted_topology = get_complex_connectivity(
92+
simplicial_complex, self.complex_dim, signed=False
93+
)
94+
# Computing the persitence to obtain the Betti numbers
95+
st.compute_persistence(persistence_dim_max=True)
96+
97+
# Save the Betti numbers in the Data object
98+
lifted_topology["betti"] = torch.tensor(st.betti_numbers())
99+
100+
101+
lifted_topology["x_0"] = torch.stack(
102+
list(simplicial_complex.get_simplex_attributes("features", 0).values())
103+
)
104+
105+
# Add the indices of the simplices
106+
for r in range(simplicial_complex.dim):
107+
lifted_topology[f"x_idx_{r}"] = torch.tensor(simplicial_complex.skeleton(r))
108+
109+
return lifted_topology
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""Test the message passing module."""
2+
3+
4+
from modules.data.utils.utils import load_manual_graph
5+
from modules.transforms.liftings.pointcloud2simplicial.random_flag_complex import (
6+
RandomFlagComplexLifting,
7+
)
8+
9+
10+
class TestRandomFlagComplexLifting:
11+
"""Test the SimplicialCliqueLifting class."""
12+
13+
def setup_method(self):
14+
# Load the graph
15+
self.data = load_manual_graph()
16+
del self.data["edge_attr"]
17+
del self.data["edge_index"]
18+
19+
self.lifting_p_0 = RandomFlagComplexLifting(steps=10, p=0)
20+
self.lifting_p_1 = RandomFlagComplexLifting(steps=1, p=1)
21+
self.lifting_hp = RandomFlagComplexLifting(steps=100, alpha=0.01)
22+
23+
24+
def test_empty(self):
25+
lifted_data = self.lifting_p_0.forward(self.data.clone())
26+
assert(lifted_data.x_1.size(0) == 0)
27+
28+
def test_not_empty(self):
29+
lifted_data = self.lifting_hp.forward(self.data.clone())
30+
assert(lifted_data.x_1.size(0) > 0)
31+
32+
def test_full_graph(self):
33+
lifted_data = self.lifting_p_1.forward(self.data.clone())
34+
possible_edges = lifted_data.num_nodes * (self.data.num_nodes - 1) / 2
35+
assert(lifted_data.x_1.size(0) == possible_edges)
36+
37+
38+
assert(lifted_data)

tutorials/pointcloud2simplicial/random_flag_complex_lifting.ipynb

+383
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)