Commit 8597e0e

Merge branch 'alvarolmartinez-universalstrict'
2 parents: 09ea7a6 + 54d3249

File tree: 12 files changed, +936 -18 lines
New dataset config file (+12 -0)

@@ -0,0 +1,12 @@
+data_domain: hypergraph
+data_type: toy_dataset
+data_name: manual
+data_dir: datasets/${data_domain}/${data_type}
+
+# Dataset parameters
+num_features: 1
+num_classes: 2
+task: classification
+loss_type: cross_entropy
+monitor_metric: accuracy
+task_level: node
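
The new config relies on ${...} interpolation for data_dir. A minimal sketch of how that resolves, assuming these files are read with an OmegaConf/Hydra-style loader (an assumption; the loader itself is not part of this diff):

# Sketch only: resolving the ${...} interpolation above with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    data_domain: hypergraph
    data_type: toy_dataset
    data_name: manual
    data_dir: datasets/${data_domain}/${data_type}
    """
)
print(cfg.data_dir)  # datasets/hypergraph/toy_dataset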

configs/models/combinatorial/hmc.yaml (+1 -1)

@@ -2,4 +2,4 @@ in_channels: null # This will be set by the dataset
 hidden_channels: 32
 out_channels: null # This will be set by the dataset
 n_layers: 2
-negative_slope: 0.2
+negative_slope: 0.2
New transform config file (+2 -0)

@@ -0,0 +1,2 @@
+transform_type: "lifting"
+transform_name: "UniversalStrictLifting"

modules/data/load/loaders.py (+20 -6)

@@ -21,6 +21,7 @@
     load_gudhi_dataset,
     load_hypergraph_pickle_dataset,
     load_manual_graph,
+    load_manual_hypergraph,
     load_manual_mol,
     load_manual_points,
     load_point_cloud,
@@ -245,6 +246,19 @@ def load(
         torch_geometric.data.Dataset
             torch_geometric.data.Dataset object containing the loaded data.
         """
+        # Manual hypergraph
+        if self.parameters.data_name in ["manual"]:
+            root_folder = rootutils.find_root()
+            root_data_dir = os.path.join(
+                root_folder, self.parameters["data_dir"]
+            )
+            self.data_dir = os.path.join(
+                root_data_dir, self.parameters["data_name"]
+            )
+
+            data = load_manual_hypergraph()
+            return CustomDataset([data], self.parameters.data_dir)
+
         return load_hypergraph_pickle_dataset(self.parameters)


@@ -303,38 +317,38 @@ def load(self) -> torch_geometric.data.Dataset:
             root_data_dir, self.parameters["data_name"]
         )

-        if self.parameters.data_name.startswith("gudhi_"):
+        if self.parameters["data_name"].startswith("gudhi_"):
             data = load_gudhi_dataset(
                 self.parameters,
                 feature_generator=self.feature_generator,
                 target_generator=self.target_generator,
             )
-        elif self.parameters.data_name == "random_points":
+        elif self.parameters["data_name"] == "random_points":
             data = load_random_points(
                 dim=self.parameters["dim"],
                 num_classes=self.parameters["num_classes"],
                 num_samples=self.parameters["num_samples"],
             )
-        elif self.parameters.data_name == "toy_point_cloud":
+        elif self.parameters["data_name"] == "toy_point_cloud":
             data = load_point_cloud(
                 num_classes=self.parameters["num_classes"],
                 num_samples=self.parameters["num_samples"],
             )
-        elif self.parameters.data_name == "manual_points":
+        elif self.parameters["data_name"] == "manual_points":
             data = load_manual_points()
         elif self.parameters.data_name == "stanford_bunny":
             self.data_dir = os.path.join(
                 root_folder, self.parameters["data_dir"]
             )
             data = load_pointcloud_dataset(self.parameters)
-        elif self.parameters.data_name == "shapes":
+        elif self.parameters["data_name"] == "shapes":
             data = load_random_shape_point_cloud(
                 num_points=self.parameters["num_points"],
                 num_classes=self.parameters["num_classes"],
             )
         else:
             raise NotImplementedError(
-                f"Dataset {self.parameters.data_name} not implemented"
+                f"Dataset {self.parameters['data_name']} not implemented"
             )

         return CustomDataset([data], self.data_dir)
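
Side note on the access styles above: this hunk switches several lookups from self.parameters.data_name to self.parameters["data_name"]. Assuming the parameters object behaves like an OmegaConf DictConfig (suggested by the ${...} interpolation in the configs, but not confirmed by this diff), the two styles resolve identically:

# Sketch under the DictConfig assumption: attribute and item access agree.
from omegaconf import OmegaConf

params = OmegaConf.create({"data_name": "manual", "data_dir": "datasets/hypergraph/toy_dataset"})
assert params.data_name == params["data_name"] == "manual"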

modules/data/utils/utils.py (+108 -8)

@@ -119,6 +119,61 @@ def get_complex_connectivity(complex, max_rank, signed=False):
     return connectivity


+def get_combinatorial_complex_connectivity(complex, max_rank=None):
+    r"""Gets the connectivity matrices for the combinatorial complex.
+
+    Parameters
+    ----------
+    complex : toponetx.CombinatorialComplex
+        Combinatorial complex.
+    max_rank : int
+        Maximum rank of the complex.
+
+    Returns
+    -------
+    dict
+        Dictionary containing the connectivity matrices.
+    """
+    if max_rank is None:
+        max_rank = complex.dim
+    practical_shape = list(
+        np.pad(list(complex.shape), (0, max_rank + 1 - len(complex.shape)))
+    )
+
+    connectivity = {}
+
+    for rank_idx in range(max_rank + 1):
+        if rank_idx > 0:
+            try:
+                connectivity[f"incidence_{rank_idx}"] = from_sparse(
+                    complex.incidence_matrix(
+                        rank=rank_idx - 1, to_rank=rank_idx
+                    )
+                )
+            except ValueError:
+                connectivity[f"incidence_{rank_idx}"] = (
+                    generate_zero_sparse_connectivity(
+                        m=practical_shape[rank_idx],
+                        n=practical_shape[rank_idx],
+                    )
+                )
+
+        try:
+            connectivity[f"adjacency_{rank_idx}"] = from_sparse(
+                complex.adjacency_matrix(rank=rank_idx, via_rank=rank_idx + 1)
+            )
+        except ValueError:
+            connectivity[f"adjacency_{rank_idx}"] = (
+                generate_zero_sparse_connectivity(
+                    m=practical_shape[rank_idx], n=practical_shape[rank_idx]
+                )
+            )
+
+    connectivity["shape"] = practical_shape
+
+    return connectivity
+
+
 def generate_zero_sparse_connectivity(m, n):
     r"""Generates a zero sparse connectivity matrix.

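
A small usage sketch for the helper added above, on a toy toponetx CombinatorialComplex (the complex built here is illustrative, not part of the PR, and it assumes the repository is importable as modules):

# Illustrative only: exercise get_combinatorial_complex_connectivity on a tiny complex.
from toponetx import CombinatorialComplex

from modules.data.utils.utils import get_combinatorial_complex_connectivity

cc = CombinatorialComplex()
cc.add_cell([0, 1], rank=1)
cc.add_cell([1, 2], rank=1)
cc.add_cell([0, 1, 2], rank=2)

connectivity = get_combinatorial_complex_connectivity(cc)
print(connectivity["shape"])              # counts per rank: 3 nodes, 2 rank-1 cells, 1 rank-2 cell
print(connectivity["incidence_1"].shape)  # (nodes x rank-1 cells) incidence, here torch.Size([3, 2])
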
@@ -285,17 +340,13 @@ def load_hypergraph_pickle_dataset(cfg):

     print(f"number of hyperedges: {len(hypergraph)}")

-    edge_idx = 0  # num_nodes
     node_list = []
     edge_list = []
-    for he in hypergraph:
-        cur_he = hypergraph[he]
-        cur_size = len(cur_he)
-
-        node_list += list(cur_he)
-        edge_list += [edge_idx] * cur_size

-        edge_idx += 1
+    for edge_idx, cur_he in enumerate(hypergraph.values()):
+        cur_size = len(cur_he)
+        node_list.extend(cur_he)
+        edge_list.extend([edge_idx] * cur_size)

     # check that every node is in some hyperedge
     if len(np.unique(node_list)) != num_nodes:
@@ -641,6 +692,55 @@ def load_manual_mol():
     )


+def load_manual_hypergraph():
+    """Create a manual hypergraph for testing purposes."""
+    # Define the vertices (just 8 vertices)
+    vertices = [i for i in range(8)]
+    y = [0, 1, 1, 1, 0, 0, 0, 0]
+    # Define the hyperedges
+    hyperedges = [
+        [0, 1, 2, 3],
+        [4, 5, 6, 7],
+        [0, 1, 2],
+        [0, 1, 3],
+        [0, 2, 3],
+        [1, 2, 3],
+        [3, 4],
+        [4, 5],
+        [4, 7],
+        [5, 6],
+        [6, 7],
+    ]
+
+    # Generate a feature for each of the 8 nodes
+    x = torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000]).unsqueeze(1).float()
+    labels = torch.tensor(y, dtype=torch.long)
+
+    node_list = []
+    edge_list = []
+
+    for edge_idx, he in enumerate(hyperedges):
+        cur_size = len(he)
+        node_list += he
+        edge_list += [edge_idx] * cur_size
+
+    edge_index = np.array([node_list, edge_list], dtype=int)
+    edge_index = torch.LongTensor(edge_index)
+
+    incidence_hyperedges = torch.sparse_coo_tensor(
+        edge_index,
+        values=torch.ones(edge_index.shape[1]),
+        size=(len(vertices), len(hyperedges)),
+    )
+
+    return Data(
+        x=x,
+        edge_index=edge_index,
+        y=labels,
+        incidence_hyperedges=incidence_hyperedges,
+    )
+
+
 def get_Planetoid_pyg(cfg):
     r"""Loads Planetoid graph datasets from torch_geometric.

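
A quick sanity check of the manual hypergraph defined above (a sketch; it assumes it is run from the repository root so that modules is importable). The fixture has 8 nodes, 11 hyperedges, and 30 node-hyperedge memberships:

# Sketch: inspect the toy hypergraph returned by load_manual_hypergraph().
from modules.data.utils.utils import load_manual_hypergraph

data = load_manual_hypergraph()
print(data.x.shape)                     # torch.Size([8, 1])
print(data.incidence_hyperedges.shape)  # torch.Size([8, 11])
print(data.edge_index.shape)            # torch.Size([2, 30])
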
modules/models/combinatorial/hmc.py (+5 -3)

@@ -33,7 +33,7 @@ def __init__(self, model_config, dataset_config):
             int_channels_l = []
             out_channels_l = []

-            for _ in range(3): # only 3 ranks
+            for _ in range(3):  # only 3 ranks
                 # First layer behavior
                 if layer == 0:
                     in_channels_l.append(in_channels)
@@ -42,11 +42,13 @@ def __init__(self, model_config, dataset_config):
                 int_channels_l.append(hidden_channels)
                 out_channels_l.append(hidden_channels)

-            channels_per_layer.append((in_channels_l, int_channels_l, out_channels_l))
+            channels_per_layer.append(
+                (in_channels_l, int_channels_l, out_channels_l)
+            )

         self.base_model = HMC(
             channels_per_layer=channels_per_layer,
-            negative_slope=negative_slope
+            negative_slope=negative_slope,
         )
         self.linear_0 = torch.nn.Linear(hidden_channels, out_channels)
         self.linear_1 = torch.nn.Linear(hidden_channels, out_channels)

modules/transforms/data_transform.py (+5 -0)

@@ -44,6 +44,9 @@
 from modules.transforms.liftings.graph2simplicial.vietoris_rips_lifting import (
     SimplicialVietorisRipsLifting,
 )
+from modules.transforms.liftings.hypergraph2combinatorial.universal_strict_lifting import (
+    UniversalStrictLifting,
+)
 from modules.transforms.liftings.pointcloud2hypergraph.mogmst_lifting import (
     MoGMSTLifting,
 )
@@ -81,6 +84,8 @@
     # Pointcloud -> Hypergraph
     "VoronoiLifting": VoronoiLifting,
     "MoGMSTLifting": MoGMSTLifting,
+    # Hypergraph -> Combinatorial Complex
+    "UniversalStrictLifting": UniversalStrictLifting,
     # Feature Liftings
     "ProjectionSum": ProjectionSum,
     # Data Manipulations
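
For context, the registry entry added above is what the new transform config (transform_name: "UniversalStrictLifting") keys into. A hypothetical sketch of that lookup; the names LIFTING_REGISTRY and transform_cfg are illustrative, not taken from this diff:

# Hypothetical names, shown only to illustrate the name-to-class lookup.
from modules.transforms.liftings.hypergraph2combinatorial.universal_strict_lifting import (
    UniversalStrictLifting,
)

LIFTING_REGISTRY = {"UniversalStrictLifting": UniversalStrictLifting}

transform_cfg = {"transform_type": "lifting", "transform_name": "UniversalStrictLifting"}
lifting_cls = LIFTING_REGISTRY[transform_cfg["transform_name"]]
print(lifting_cls.__name__)  # UniversalStrictLifting
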
New file: hypergraph-to-combinatorial lifting base class (+52 -0)

@@ -0,0 +1,52 @@
+import torch
+from toponetx import CombinatorialComplex
+
+from modules.data.utils.utils import get_combinatorial_complex_connectivity
+from modules.transforms.liftings.lifting import HypergraphLifting
+
+
+class Hypergraph2CombinatorialLifting(HypergraphLifting):
+    r"""Abstract class for lifting hypergraphs to combinatorial complexes.
+
+    Parameters
+    ----------
+    **kwargs : optional
+        Additional arguments for the class.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.type = "hypergraph2combinatorial"
+
+    def _get_lifted_topology(self, combinatorial_complex: CombinatorialComplex) -> dict:
+        r"""Returns the lifted topology.
+
+        Parameters
+        ----------
+        combinatorial_complex : CombinatorialComplex
+            The combinatorial complex.
+
+        Returns
+        -------
+        dict
+            The lifted topology.
+        """
+        lifted_topology = get_combinatorial_complex_connectivity(combinatorial_complex)
+
+        # Feature liftings
+
+        features = combinatorial_complex.get_cell_attributes("features")
+
+        for i in range(combinatorial_complex.dim + 1):
+            x = [
+                feat
+                for cell, feat in features
+                if combinatorial_complex.cells.get_rank(cell) == i
+            ]
+            if x:
+                lifted_topology[f"x_{i}"] = torch.stack(x)
+            else:
+                num_cells = len(combinatorial_complex.skeleton(i))
+                lifted_topology[f"x_{i}"] = torch.zeros(num_cells, 1)
+
+        return lifted_topology

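The per-rank feature handling in _get_lifted_topology stacks features when a rank has any and otherwise pads with a width-1 zero tensor, so every rank i gets an x_{i} entry. A minimal, self-contained sketch of that fallback in plain torch (toponetx is not needed for this part; the rank counts here are made up for illustration):

# Sketch of the stack-or-pad behaviour used for the x_{i} tensors above.
import torch

rank_features = {0: [torch.tensor([1.0]), torch.tensor([2.0])], 1: []}
num_cells_per_rank = {0: 2, 1: 3}

lifted = {}
for i, feats in rank_features.items():
    if feats:
        lifted[f"x_{i}"] = torch.stack(feats)
    else:
        lifted[f"x_{i}"] = torch.zeros(num_cells_per_rank[i], 1)

print(lifted["x_0"].shape, lifted["x_1"].shape)  # torch.Size([2, 1]) torch.Size([3, 1])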