Skip to content

Commit a836af2

Browse files
committed
Merge branch 'coface_cc_lift' of github.com:martin-carrasco/challenge-icml-2024
2 parents 74f1281 + 11cad90 commit a836af2

File tree

10 files changed

+710
-2
lines changed

10 files changed

+710
-2
lines changed

configs/datasets/KarateClub.yaml

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
data_name: KarateClub
2+
data_domain: simplex
3+
data_dir: datasets/${data_domain}/${data_type}
4+
data_type: simplex
5+
6+
num_features: 2
7+
num_classes: 2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
transform_type: 'lifting'
2+
transform_name: "CofaceCCLifting"
3+
feature_lifting: ProjectionSum
4+
keep_features: False

modules/data/load/loaders.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,14 @@ def load(
223223
torch_geometric.data.Dataset
224224
torch_geometric.data.Dataset object containing the loaded data.
225225
"""
226-
return load_simplicial_dataset(self.parameters)
226+
root_folder = rootutils.find_root()
227+
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])
228+
229+
self.data_dir = os.path.join(
230+
root_data_dir, self.parameters["data_name"]
231+
)
232+
data = load_simplicial_dataset(self.parameters)
233+
return CustomDataset([data], self.data_dir)
227234

228235

229236
class HypergraphLoader(AbstractLoader):

modules/data/utils/utils.py

+125
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
fetch_spiral_2d,
2626
)
2727
from topomodelx.utils.sparse import from_sparse
28+
from toponetx.classes.combinatorial_complex import CombinatorialComplex
2829
from torch_geometric.data import Data
2930
from torch_geometric.datasets import GeometricShapes
3031
from torch_sparse import SparseTensor, coalesce
@@ -71,6 +72,90 @@ def get_ccc_connectivity(complex, max_rank):
7172
return connectivity
7273

7374

75+
def get_combinatorial_complex_connectivity_2(
76+
complex: CombinatorialComplex, max_rank, signed=False
77+
):
78+
r"""Gets the connectivity matrices for the Combinatorial Complex.
79+
80+
Parameters
81+
----------
82+
complex : topnetx.CombinatorialComplex
83+
Cell complex.
84+
max_rank : int
85+
Maximum rank of the complex.
86+
signed : bool
87+
If True, returns signed connectivity matrices.
88+
89+
Returns
90+
-------
91+
dict
92+
Dictionary containing the connectivity matrices.
93+
"""
94+
practical_shape = list(
95+
np.pad(list(complex.shape), (0, max_rank + 1 - len(complex.shape)))
96+
)
97+
connectivity = {}
98+
for rank_idx in range(max_rank + 1):
99+
for connectivity_info in [
100+
"incidence",
101+
"laplacian",
102+
"adjacency",
103+
]:
104+
try:
105+
if connectivity_info == "laplacian":
106+
connectivity[f"{connectivity_info}_{rank_idx}"] = (
107+
from_sparse(complex.laplacian_matrix(rank=rank_idx))
108+
)
109+
elif connectivity_info == "adjacency":
110+
connectivity[f"{connectivity_info}_{rank_idx}"] = (
111+
from_sparse(
112+
getattr(complex, f"{connectivity_info}_matrix")(
113+
rank_idx, rank_idx + 1
114+
)
115+
)
116+
)
117+
else: # incidence
118+
connectivity[f"{connectivity_info}_{rank_idx}"] = (
119+
from_sparse(
120+
getattr(complex, f"{connectivity_info}_matrix")(
121+
rank_idx - 1, rank_idx
122+
)
123+
)
124+
)
125+
except ValueError: # noqa: PERF203
126+
if connectivity_info == "incidence":
127+
connectivity[f"{connectivity_info}_{rank_idx}"] = (
128+
generate_zero_sparse_connectivity(
129+
m=practical_shape[rank_idx - 1],
130+
n=practical_shape[rank_idx],
131+
)
132+
)
133+
else:
134+
connectivity[f"{connectivity_info}_{rank_idx}"] = (
135+
generate_zero_sparse_connectivity(
136+
m=practical_shape[rank_idx],
137+
n=practical_shape[rank_idx],
138+
)
139+
)
140+
except AttributeError:
141+
if connectivity_info == "incidence":
142+
connectivity[f"{connectivity_info}_{rank_idx}"] = (
143+
generate_zero_sparse_connectivity(
144+
m=practical_shape[rank_idx - 1],
145+
n=practical_shape[rank_idx],
146+
)
147+
)
148+
else:
149+
connectivity[f"{connectivity_info}_{rank_idx}"] = (
150+
generate_zero_sparse_connectivity(
151+
m=practical_shape[rank_idx],
152+
n=practical_shape[rank_idx],
153+
)
154+
)
155+
connectivity["shape"] = practical_shape
156+
return connectivity
157+
158+
74159
def get_complex_connectivity(complex, max_rank, signed=False):
75160
r"""Gets the connectivity matrices for the complex.
76161
@@ -474,6 +559,46 @@ def load_point_cloud(
474559
return torch_geometric.data.Data(x=features, y=classes, pos=points)
475560

476561

562+
def load_manual_simplicial_complex():
563+
"""Create a manual simplicial complex for testing purposes."""
564+
num_feats = 2
565+
one_cells = [i for i in range(5)]
566+
two_cells = [[0, 1], [0, 2], [1, 2], [1, 3], [2, 3], [0, 4], [2, 4]]
567+
three_cells = [[0, 1, 2], [1, 2, 3], [0, 2, 4]]
568+
incidence_1 = [
569+
[1, 1, 0, 0, 0, 1, 0],
570+
[1, 0, 1, 1, 0, 0, 0],
571+
[0, 1, 1, 0, 1, 0, 1],
572+
[0, 0, 0, 1, 1, 0, 0],
573+
[0, 0, 0, 0, 0, 1, 1],
574+
]
575+
incidence_2 = [
576+
[1, 0, 0],
577+
[1, 0, 1],
578+
[1, 1, 0],
579+
[0, 1, 0],
580+
[0, 1, 0],
581+
[0, 0, 1],
582+
[0, 0, 1],
583+
]
584+
585+
y = [1]
586+
587+
return torch_geometric.data.Data(
588+
x_0=torch.rand(len(one_cells), num_feats),
589+
x_1=torch.rand(len(two_cells), num_feats),
590+
x_2=torch.rand(len(three_cells), num_feats),
591+
incidence_0=torch.zeros((1, 5)).to_sparse(),
592+
adjacency_1=torch.zeros((len(one_cells), len(one_cells))).to_sparse(),
593+
adjacency_2=torch.zeros((len(two_cells), len(two_cells))).to_sparse(),
594+
adjacency_0=torch.zeros((5, 5)).to_sparse(),
595+
incidence_1=torch.tensor(incidence_1).to_sparse(),
596+
incidence_2=torch.tensor(incidence_2).to_sparse(),
597+
num_nodes=len(one_cells),
598+
y=torch.tensor(y),
599+
)
600+
601+
477602
def load_manual_graph():
478603
"""Create a manual graph for testing purposes."""
479604
# Define the vertices (just 8 vertices)

modules/models/combinatorial/hmc.py

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def __init__(self, model_config, dataset_config):
2020
if isinstance(dataset_config["num_features"], int)
2121
else dataset_config["num_features"][0]
2222
)
23-
2423
negative_slope = model_config["negative_slope"]
2524
hidden_channels = model_config["hidden_channels"]
2625
out_channels = dataset_config["num_classes"]

modules/transforms/data_transform.py

+5
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@
8383
from modules.transforms.liftings.pointcloud2simplicial.random_flag_complex import (
8484
RandomFlagComplexLifting,
8585
)
86+
from modules.transforms.liftings.simplicial2combinatorial.coface_cc_lifting import (
87+
CofaceCCLifting,
88+
)
8689

8790
TRANSFORMS = {
8891
# Graph -> Hypergraph
@@ -119,6 +122,8 @@
119122
"UniversalStrictLifting": UniversalStrictLifting,
120123
# Hypergraph -> Simplicial Complex
121124
"HypergraphHeatLifting": HypergraphHeatLifting,
125+
# Simplicial Complex -> Combinatorial Complex
126+
"CofaceCCLifting": CofaceCCLifting,
122127
# Feature Liftings
123128
"ProjectionSum": ProjectionSum,
124129
# Data Manipulations
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from modules.transforms.liftings.lifting import SimplicialLifting
2+
3+
4+
class Simplicial2CombinatorialLifting(SimplicialLifting):
5+
r"""Abstract class for lifting graphs to combinatorial complexes.
6+
7+
Parameters
8+
----------
9+
**kwargs : optiona""l
10+
Additional arguments for the class.
11+
"""
12+
13+
def __init__(self, **kwargs):
14+
super().__init__(**kwargs)
15+
self.type = "simplicial2combinatorial"
16+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from toponetx.classes.combinatorial_complex import CombinatorialComplex
2+
from toponetx.classes.hyperedge import HyperEdge
3+
from torch_geometric.data import Data
4+
5+
from modules.data.utils.utils import get_combinatorial_complex_connectivity
6+
from modules.transforms.liftings.simplicial2combinatorial.base import (
7+
Simplicial2CombinatorialLifting,
8+
)
9+
10+
11+
class CofaceCCLifting(Simplicial2CombinatorialLifting):
12+
def __init__(self, **kwargs):
13+
super().__init__(**kwargs)
14+
self.keep_features = kwargs.get("keep_features", False)
15+
16+
def get_lower_cells(self, data: Data) -> list[HyperEdge]:
17+
""" Get the lower cells of the complex
18+
19+
Parameters:
20+
data (Data): The input data
21+
Returns:
22+
List[HyperEdge]: The lower cells of the complex
23+
"""
24+
cells: list[HyperEdge] = []
25+
26+
## Add 0-cells
27+
for cell in range(data["x_0"].size(0)):
28+
zero_cell = HyperEdge([cell], rank=0)
29+
cells.append(zero_cell)
30+
31+
## Add 1-cells
32+
for inc_c_1 in data["incidence_1"].to_dense().T:
33+
# Get the 0-cells that are incident to the 1-cell
34+
cell_0_bound = inc_c_1.nonzero().flatten().tolist()
35+
assert(len(cell_0_bound) == 2)
36+
one_cell = HyperEdge(cell_0_bound, rank=1)
37+
cells.append(one_cell)
38+
39+
## Add 2-cells
40+
for inc_c_2 in data["incidence_2"].to_dense().T:
41+
# Get the 1-cells that are incident to the 2-cell
42+
cell_1_bound = inc_c_2.nonzero().flatten()
43+
# Get the 0-cells that are incident to the 1-cells
44+
cell_0_bound = data["incidence_1"].to_dense().T[cell_1_bound].nonzero()
45+
# Get the actual 0-cells since nonzero()
46+
# indexes in 2D
47+
cell_0_bound = cell_0_bound[:, 1]
48+
# Remove redudants and convert to tuple
49+
two_cell = HyperEdge(tuple(set(cell_0_bound.tolist())), rank=2)
50+
cells.append(two_cell)
51+
52+
return cells
53+
54+
def lift_topology(self, data: Data) -> dict:
55+
""" Lift the simplicial topology to a combinatorial complex
56+
"""
57+
58+
# Check that the dataset has the required fields
59+
# assume that it's a simplicial dataset
60+
assert "incidence_1" in data
61+
assert "incidence_2" in data
62+
63+
cells = self.get_lower_cells(data)
64+
65+
ccc = CombinatorialComplex(cells, graph_based=False)
66+
67+
# Iterate over the 2-cells and add the 3-cells
68+
for r_cell in ccc.skeleton(rank=2):
69+
# Get the coface of the 2-cell
70+
indices, coface = ccc.coadjacency_matrix(2, 1, index=True)
71+
72+
# Get the indices of the 2-cell that are co-adjacent
73+
coface_indices = coface.todense()[indices[r_cell]].nonzero()[1].tolist()
74+
cell_3 = set(r_cell)
75+
76+
# Iterate over the indices of the 2-cells
77+
# and add their 0-cells as a 3-cell
78+
for idx in coface_indices:
79+
cell_3 = cell_3.union(set(ccc.skeleton(rank=2)[idx]))
80+
81+
# Adding a rank 3 cell with less than 4 vertices
82+
# will take this cell from the skeleton of 2-cells if it exists
83+
# so in the interest of keeping features the user
84+
# can choose to recompute all feature embeddings
85+
if len(cell_3) < 4 and self.keep_features:
86+
continue
87+
# Get the cofaces incident to the 2-cell `cell` and add `cell` to the set
88+
ccc.add_cell(cell_3, rank=3)
89+
90+
# Create the incidence, adjacency and laplacian matrices
91+
lifted_data = get_combinatorial_complex_connectivity(ccc, 3)
92+
93+
# If the user wants to keep the features
94+
# from the r-cells aside from the first x_0
95+
if self.keep_features:
96+
lifted_data = {"x_0": data["x_0"], "x_1": data["x_1"], "x_2": data["x_2"], **lifted_data}
97+
else:
98+
lifted_data = {"x_0": data["x_0"], **lifted_data}
99+
100+
return lifted_data
101+
102+
def forward(self, data: Data) -> Data:
103+
initial_data = data.to_dict()
104+
lifted_topology = self.lift_topology(data)
105+
lifted_topology = self.feature_lifting(lifted_topology)
106+
107+
# Make sure to remove passing of duplicated data
108+
# so that the constructor of Data does not raise an error
109+
110+
for k in lifted_topology:
111+
if k in initial_data:
112+
del initial_data[k]
113+
return Data(**initial_data, **lifted_topology)

0 commit comments

Comments
 (0)