Skip to content

Commit cf296c7

Browse files
committed
Merge branch 'peekxc-main'
2 parents e0f2c36 + 7132058 commit cf296c7

File tree

15 files changed

+1057
-14
lines changed

15 files changed

+1057
-14
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
data_domain: hypergraph
2+
data_type: contact
3+
data_name: ContactPrimarySchool
4+
data_dir: datasets/${data_domain}/${data_type}
5+
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}
6+
7+
# Dataset parameters
8+
num_nodes: 242
9+
num_hyperedges: 12704
10+
num_classes: 11
11+
max_dim: 1
12+
task: classification
+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
data_domain: hypergraph
2+
data_type: toy_hypergraph
3+
data_name: manual_hg
4+
data_dir: datasets/${data_domain}/${data_type}
5+
6+
num_nodes: 12
7+
num_hyperedges: 24
8+
max_dim: 2
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
data_domain: hypergraph
2+
data_type: interaction
3+
data_name: senate_committee
4+
data_dir: datasets/${data_domain}/${data_type}
5+
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}
6+
7+
# Dataset parameters
8+
num_nodes: 282
9+
num_hyperedges: 315
10+
num_classes: 2
11+
max_dim: 2
12+
task: classification
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
transform_type: 'lifting'
2+
transform_name: "HypergraphHeatLifting"
3+
complex_dim: 2
4+
signed: True
5+
feature_lifting: ProjectionSum

modules/data/load/loaders.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,19 @@
1818
from modules.data.utils.utils import (
1919
load_8_vertex_cubic_graphs,
2020
load_cell_complex_dataset,
21+
load_contact_primary_school,
2122
load_gudhi_dataset,
2223
load_hypergraph_pickle_dataset,
2324
load_manual_graph,
2425
load_manual_hypergraph,
26+
load_manual_hypergraph_2,
2527
load_manual_mol,
2628
load_manual_points,
2729
load_point_cloud,
2830
load_pointcloud_dataset,
2931
load_random_points,
3032
load_random_shape_point_cloud,
33+
load_senate_committee,
3134
load_simplicial_dataset,
3235
)
3336

@@ -247,18 +250,23 @@ def load(
247250
torch_geometric.data.Dataset object containing the loaded data.
248251
"""
249252
# Manual hypergraph
253+
root_folder = rootutils.find_root()
254+
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])
255+
self.data_dir = os.path.join(
256+
root_data_dir, self.parameters["data_name"]
257+
)
250258
if self.parameters.data_name in ["manual"]:
251-
root_folder = rootutils.find_root()
252-
root_data_dir = os.path.join(
253-
root_folder, self.parameters["data_dir"]
254-
)
255-
self.data_dir = os.path.join(
256-
root_data_dir, self.parameters["data_name"]
257-
)
258-
259259
data = load_manual_hypergraph()
260260
return CustomDataset([data], self.parameters.data_dir)
261-
261+
if self.parameters.data_name in ["ContactPrimarySchool"]:
262+
data = load_contact_primary_school(self.parameters, self.data_dir)
263+
return CustomDataset([data], self.data_dir)
264+
if self.parameters.data_name in ["senate_committee"]:
265+
data = load_senate_committee(self.parameters, self.data_dir)
266+
return CustomDataset([data], self.data_dir)
267+
if self.parameters.data_name in ["manual_hg"]:
268+
data = load_manual_hypergraph_2(self.parameters)
269+
return CustomDataset([data], self.data_dir)
262270
return load_hypergraph_pickle_dataset(self.parameters)
263271

264272

modules/data/preprocess/preprocessor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ class PreProcessor(torch_geometric.data.InMemoryDataset):
1212
1313
Parameters
1414
----------
15-
data_dir : str
16-
Path to the directory containing the data.
1715
data_list : list
1816
List of data objects.
1917
transforms_config : DictConfig | dict
2018
Configuration parameters for the transforms.
19+
data_dir : str
20+
Path to the directory containing the data.
2121
**kwargs: optional
2222
Additional arguments.
2323
"""

modules/data/utils/utils.py

+177-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import hashlib
2+
import itertools as it
23
import os
34
import os.path as osp
45
import pickle
6+
import tempfile
7+
import zipfile
58
from collections.abc import Callable
69
from urllib.request import urlretrieve
710

@@ -14,6 +17,7 @@
1417
import torch_geometric
1518
import torch_geometric.data
1619
import torch_geometric.transforms as T
20+
import torch_sparse
1721
from gudhi.datasets.generators import points
1822
from gudhi.datasets.remote import (
1923
fetch_bunny,
@@ -23,7 +27,7 @@
2327
from topomodelx.utils.sparse import from_sparse
2428
from torch_geometric.data import Data
2529
from torch_geometric.datasets import GeometricShapes
26-
from torch_sparse import coalesce
30+
from torch_sparse import SparseTensor, coalesce
2731

2832
rootutils.setup_root("./", indicator=".project-root", pythonpath=True)
2933

@@ -797,6 +801,178 @@ def load_manual_hypergraph():
797801
)
798802

799803

804+
def load_manual_hypergraph_2(cfg: dict):
    """Build a small, reproducible random hypergraph for testing purposes.

    Membership masks are drawn from a generator seeded with 1234, so repeated
    calls yield the same hypergraph. Identical draws are collapsed, which
    means the result may hold fewer hyperedges than the number of draws.

    Parameters
    ----------
    cfg : dict
        Dataset configuration. Only ``max_dim`` is read (default 3).

    Returns
    -------
    torch_geometric.data.Data
        Data object with an empty feature matrix, one
        (hyperedge index, node index) row per membership in ``edge_index``,
        and a sparse COO node-by-hyperedge incidence matrix in
        ``incidence_hyperedges``.
    """
    num_vertices, num_draws = 12, 24
    generator = np.random.default_rng(1234)

    # One random 0/1 mask per draw; the set drops duplicated hyperedges.
    unique_members = {
        tuple(np.flatnonzero(generator.choice([0, 1], size=num_vertices)))
        for _ in range(num_draws)
    }
    members = [np.array(nodes) for nodes in unique_members]
    sizes = [len(nodes) for nodes in members]

    # COO triplets of the incidence matrix: rows are nodes, columns hyperedges.
    node_ids = torch.tensor(np.concatenate(members), dtype=torch.long)
    hyperedge_ids = torch.tensor(
        np.repeat(np.arange(len(members)), sizes),
        dtype=torch.long,
    )
    weights = torch.tensor(np.ones(len(node_ids)))
    sparse_incidence = torch_sparse.SparseTensor(
        row=node_ids, col=hyperedge_ids, value=weights
    )
    incidence_hyperedges = (
        sparse_incidence.coalesce().to_torch_sparse_coo_tensor()
    )

    # Bipartite graph representation: one (hyperedge, node) pair per membership.
    pairs = np.array(
        [
            (he_idx, node)
            for he_idx, nodes in enumerate(members)
            for node in nodes
        ]
    )

    return Data(
        x=torch.empty((num_vertices, 0)),
        edge_index=torch.tensor(pairs, dtype=torch.long),
        num_nodes=num_vertices,
        num_node_features=0,
        num_edges=len(members),
        incidence_hyperedges=incidence_hyperedges,
        max_dim=cfg.get("max_dim", 3),
    )
840+
841+
842+
def load_contact_primary_school(cfg: dict, data_dir: str):
    """Download the contact-primary-school hypergraph and wrap it in ``Data``.

    The zip archive is fetched from Google Drive with ``gdown`` into a
    temporary file on every call. The temporary file, the archive and every
    archive member are closed before the function returns.

    Parameters
    ----------
    cfg : dict
        Dataset configuration. Only ``max_dim`` is read (default 1).
    data_dir : str
        Not used by this function.

    Returns
    -------
    torch_geometric.data.Data
        Data object with an empty ``(num_nodes, 0)`` feature matrix, integer
        node labels in ``y``, label-name strings in ``y_names``, a two-row
        ``edge_index`` (node ids on row 0, hyperedge ids on row 1) and the
        matching sparse COO incidence matrix in ``incidence_hyperedges``.
    """
    import gdown

    url = "https://drive.google.com/uc?id=1H7PGDPvjCyxbogUqw17YgzMc_GHLjbZA"

    # Context managers guarantee the temp file, the archive and each member
    # handle are released even if the download or the parsing fails.
    with tempfile.NamedTemporaryFile() as tmp_zip:
        gdown.download(url, tmp_zip.name, quiet=False)
        with zipfile.ZipFile(tmp_zip.name, "r") as archive:
            with archive.open(
                "contact-primary-school/node-labels-contact-primary-school.txt",
                "r",
            ) as handle:
                label_lines = handle.readlines()
            with archive.open(
                "contact-primary-school/hyperedges-contact-primary-school.txt",
                "r",
            ) as handle:
                hyperedge_lines = handle.readlines()
            with archive.open(
                "contact-primary-school/label-names-contact-primary-school.txt",
                "r",
            ) as handle:
                name_lines = handle.readlines()

    # Each hyperedge line is a comma-separated list of integer node ids.
    hyperedges = [
        [int(node) for node in line.decode().strip().split(",")]
        for line in hyperedge_lines
    ]
    labels = np.array([int(line.decode().strip()) for line in label_lines])
    label_names = np.array([line.decode().strip() for line in name_lines])

    # Based on: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HypergraphConv.html
    # NOTE(review): node ids are used as-is here, whereas
    # load_senate_committee subtracts 1 from them. If this file is also
    # 1-indexed the incidence matrix gets one extra row -- TODO confirm.
    HE_coo = torch.tensor(
        np.array(
            [
                np.concatenate(hyperedges),
                np.repeat(
                    np.arange(len(hyperedges)), [len(he) for he in hyperedges]
                ),
            ]
        )
    )

    incidence_hyperedges = (
        SparseTensor(
            row=HE_coo[0, :],
            col=HE_coo[1, :],
            value=torch.tensor(np.ones(HE_coo.shape[1])),
        )
        .coalesce()
        .to_torch_sparse_coo_tensor()
    )

    return Data(
        x=torch.empty((len(labels), 0)),
        edge_index=HE_coo,
        y=torch.LongTensor(labels),
        y_names=label_names,
        num_nodes=len(labels),
        num_node_features=0,
        num_edges=len(hyperedges),
        incidence_hyperedges=incidence_hyperedges,
        max_dim=cfg.get("max_dim", 1),
    )
904+
905+
906+
def load_senate_committee(
    cfg: dict, data_dir: str
) -> torch_geometric.data.Data:
    """Download the senate-committees hypergraph and wrap it in ``Data``.

    The zip archive is fetched from Google Drive with ``gdown`` into a
    temporary file on every call. The temporary file, the archive and every
    archive member are closed before the function returns.

    Parameters
    ----------
    cfg : dict
        Dataset configuration. Only ``max_dim`` is read (default 2).
    data_dir : str
        Not used by this function.

    Returns
    -------
    torch_geometric.data.Data
        Data object with an empty ``(num_nodes, 0)`` feature matrix, integer
        node labels in ``y``, the lines of the node-names file in
        ``y_names``, a two-row ``edge_index`` (node ids shifted down by one
        on row 0, hyperedge ids on row 1) and the matching sparse COO
        incidence matrix in ``incidence_hyperedges``.
    """
    import gdown
    from torch_sparse import SparseTensor

    url = "https://drive.google.com/uc?id=17ZRVwki_x_C_DlOAea5dPBO7Q4SRTRRw"

    # Context managers guarantee the temp file, the archive and each member
    # handle are released even if the download or the parsing fails.
    with tempfile.NamedTemporaryFile() as tmp_zip:
        gdown.download(url, tmp_zip.name, quiet=False)
        with zipfile.ZipFile(tmp_zip.name, "r") as archive:
            with archive.open(
                "senate-committees/node-labels-senate-committees.txt", "r"
            ) as handle:
                label_lines = handle.readlines()
            with archive.open(
                "senate-committees/hyperedges-senate-committees.txt", "r"
            ) as handle:
                hyperedge_lines = handle.readlines()
            with archive.open(
                "senate-committees/node-names-senate-committees.txt", "r"
            ) as handle:
                name_lines = handle.readlines()

    # Each hyperedge line is a comma-separated list of integer node ids.
    hyperedges = [
        [int(node) for node in line.decode().strip().split(",")]
        for line in hyperedge_lines
    ]
    labels = np.array([int(line.decode().strip()) for line in label_lines])
    # Read from the node-*names* file, but stored under ``y_names`` below.
    label_names = np.array([line.decode().strip() for line in name_lines])

    # Based on: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HypergraphConv.html
    # Node ids are shifted down by one before being used as row indices
    # (presumably because the file is 1-indexed -- verify against the data).
    HE_coo = torch.tensor(
        np.array(
            [
                np.concatenate(hyperedges) - 1,
                np.repeat(
                    np.arange(len(hyperedges)), [len(he) for he in hyperedges]
                ),
            ]
        )
    )

    incidence_hyperedges = (
        SparseTensor(
            row=HE_coo[0, :],
            col=HE_coo[1, :],
            value=torch.tensor(np.ones(HE_coo.shape[1])),
        )
        .coalesce()
        .to_torch_sparse_coo_tensor()
    )

    return Data(
        x=torch.empty((len(labels), 0)),
        edge_index=HE_coo,
        y=torch.LongTensor(labels),
        y_names=label_names,
        num_nodes=len(labels),
        num_node_features=0,
        num_edges=len(hyperedges),
        incidence_hyperedges=incidence_hyperedges,
        max_dim=cfg.get("max_dim", 2),
    )
974+
975+
800976
def get_Planetoid_pyg(cfg):
801977
r"""Loads Planetoid graph datasets from torch_geometric.
802978

modules/transforms/data_transform.py

+5
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@
5656
from modules.transforms.liftings.hypergraph2combinatorial.universal_strict_lifting import (
5757
UniversalStrictLifting,
5858
)
59+
from modules.transforms.liftings.hypergraph2simplicial.heat_lifting import (
60+
HypergraphHeatLifting,
61+
)
5962
from modules.transforms.liftings.pointcloud2hypergraph.mogmst_lifting import (
6063
MoGMSTLifting,
6164
)
@@ -102,6 +105,8 @@
102105
"PointNetLifting": PointNetLifting,
103106
# Hypergraph -> Combinatorial Complex
104107
"UniversalStrictLifting": UniversalStrictLifting,
108+
# Hypergraph -> Simplicial Complex
109+
"HypergraphHeatLifting": HypergraphHeatLifting,
105110
# Feature Liftings
106111
"ProjectionSum": ProjectionSum,
107112
# Data Manipulations

modules/transforms/feature_liftings/feature_liftings.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def lift_features(
2828
-------
2929
torch_geometric.data.Data | dict
3030
The lifted data."""
31-
keys = sorted([key.split("_")[1] for key in data.keys() if "incidence" in key]) # noqa : SIM118
31+
keys = sorted(
32+
[key.split("_")[1] for key in data.keys() if "incidence" in key] # noqa
33+
)
3234
for elem in keys:
3335
if f"x_{elem}" not in data:
3436
idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from modules.transforms.liftings.lifting import HypergraphLifting
2+
3+
4+
class Hypergraph2SimplicialLifting(HypergraphLifting):
    r"""Base class for lifting hypergraphs to simplicial complexes.

    Parameters
    ----------
    complex_dim : int, optional
        The dimension of the simplicial complex to be generated. Default is 2.
    **kwargs : optional
        Additional arguments, forwarded unchanged to the parent constructor.
        A ``signed`` entry, when present, is also stored on the instance
        (it defaults to ``False`` otherwise).
    """

    def __init__(self, complex_dim=2, **kwargs):
        super().__init__(**kwargs)
        # Tag identifying the source and target domains of this lifting.
        self.type = "hypergraph2simplicial"
        self.signed = kwargs.get("signed", False)
        self.complex_dim = complex_dim

0 commit comments

Comments
 (0)