@@ -1,7 +1,10 @@
 import hashlib
+import itertools as it
 import os
 import os.path as osp
 import pickle
+import tempfile
+import zipfile
 from collections.abc import Callable
 from urllib.request import urlretrieve
 
@@ -14,6 +17,7 @@
 import torch_geometric
 import torch_geometric.data
 import torch_geometric.transforms as T
+import torch_sparse
 from gudhi.datasets.generators import points
 from gudhi.datasets.remote import (
     fetch_bunny,
@@ -23,7 +27,7 @@
 from topomodelx.utils.sparse import from_sparse
 from torch_geometric.data import Data
 from torch_geometric.datasets import GeometricShapes
-from torch_sparse import coalesce
+from torch_sparse import SparseTensor, coalesce
 
 rootutils.setup_root("./", indicator=".project-root", pythonpath=True)
 
@@ -797,6 +801,178 @@ def load_manual_hypergraph():
 )
 
 
+def load_manual_hypergraph_2(cfg: dict) -> torch_geometric.data.Data:
+    """Create a manual hypergraph for testing purposes."""
+    rng = np.random.default_rng(1234)
+    n, m = 12, 24
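+    # Draw m random node subsets as hyperedges; set() removes duplicate subsets.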
+    hyperedges = set(
+        [tuple(np.flatnonzero(rng.choice([0, 1], size=n))) for _ in range(m)]
+    )
+    hyperedges = [np.array(he) for he in hyperedges]
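+    # Node-to-hyperedge incidence in COO form: R = node ids, C = hyperedge
+    # ids, V = all-ones values.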
+    R = torch.tensor(np.concatenate(hyperedges), dtype=torch.long)
+    C = torch.tensor(
+        np.repeat(np.arange(len(hyperedges)), [len(he) for he in hyperedges]),
+        dtype=torch.long,
+    )
+    V = torch.tensor(np.ones(len(R)))
+    incidence_hyperedges = torch_sparse.SparseTensor(row=R, col=C, value=V)
+    incidence_hyperedges = (
+        incidence_hyperedges.coalesce().to_torch_sparse_coo_tensor()
+    )
+
+    # Bipartite representation: (hyperedge index, node index) pairs.
+    edges = np.array(
+        list(
+            it.chain(
+                *[[(i, v) for v in he] for i, he in enumerate(hyperedges)]
+            )
+        )
+    )
+    return Data(
+        x=torch.empty((n, 0)),
+        edge_index=torch.tensor(edges, dtype=torch.long),
+        num_nodes=n,
+        num_node_features=0,
+        num_edges=len(hyperedges),
+        incidence_hyperedges=incidence_hyperedges,
+        max_dim=cfg.get("max_dim", 3),
+    )
+
+
+def load_contact_primary_school(
+    cfg: dict, data_dir: str
+) -> torch_geometric.data.Data:
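+    """Load the contact-primary-school hypergraph dataset.
+
+    Downloads the zipped dataset from Google Drive and reads the node
+    labels, hyperedges, and label names from the archive.
+    """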
+    import gdown
+
+    url = "https://drive.google.com/uc?id=1H7PGDPvjCyxbogUqw17YgzMc_GHLjbZA"
+    fn = tempfile.NamedTemporaryFile()
+    gdown.download(url, fn.name, quiet=False)
+    archive = zipfile.ZipFile(fn.name, "r")
+    labels = archive.open(
+        "contact-primary-school/node-labels-contact-primary-school.txt", "r"
+    ).readlines()
+    hyperedges = archive.open(
+        "contact-primary-school/hyperedges-contact-primary-school.txt", "r"
+    ).readlines()
+    label_names = archive.open(
+        "contact-primary-school/label-names-contact-primary-school.txt", "r"
+    ).readlines()
+
+    hyperedges = [
+        list(map(int, he.decode().replace("\n", "").strip().split(",")))
+        for he in hyperedges
+    ]
+    labels = np.array(
+        [int(b.decode().replace("\n", "").strip()) for b in labels]
+    )
+    label_names = np.array(
+        [b.decode().replace("\n", "").strip() for b in label_names]
+    )
+
+    # Based on: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HypergraphConv.html
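+    # Row 0: node ids, row 1: hyperedge ids (COO hyperedge index).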
+    HE_coo = torch.tensor(
+        np.array(
+            [
+                np.concatenate(hyperedges),
+                np.repeat(
+                    np.arange(len(hyperedges)), [len(he) for he in hyperedges]
+                ),
+            ]
+        )
+    )
+
+    incidence_hyperedges = (
+        SparseTensor(
+            row=HE_coo[0, :],
+            col=HE_coo[1, :],
+            value=torch.tensor(np.ones(HE_coo.shape[1])),
+        )
+        .coalesce()
+        .to_torch_sparse_coo_tensor()
+    )
+
+    return Data(
+        x=torch.empty((len(labels), 0)),
+        edge_index=HE_coo,
+        y=torch.LongTensor(labels),
+        y_names=label_names,
+        num_nodes=len(labels),
+        num_node_features=0,
+        num_edges=len(hyperedges),
+        incidence_hyperedges=incidence_hyperedges,
+        max_dim=cfg.get("max_dim", 1),
+        # x_hyperedges=torch.tensor(np.empty(shape=(len(hyperedges), 0)))
+    )
+
+
+def load_senate_committee(
+    cfg: dict, data_dir: str
+) -> torch_geometric.data.Data:
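+    """Load the senate-committees hypergraph dataset.
+
+    Downloads the zipped dataset from Google Drive and reads the node
+    labels, hyperedges, and node names from the archive.
+    """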
+    import gdown
+
+    url = "https://drive.google.com/uc?id=17ZRVwki_x_C_DlOAea5dPBO7Q4SRTRRw"
+    fn = tempfile.NamedTemporaryFile()
+    gdown.download(url, fn.name, quiet=False)
+    archive = zipfile.ZipFile(fn.name, "r")
+    labels = archive.open(
+        "senate-committees/node-labels-senate-committees.txt", "r"
+    ).readlines()
+    hyperedges = archive.open(
+        "senate-committees/hyperedges-senate-committees.txt", "r"
+    ).readlines()
+    label_names = archive.open(
+        "senate-committees/node-names-senate-committees.txt", "r"
+    ).readlines()
+
+    hyperedges = [
+        list(map(int, he.decode().replace("\n", "").strip().split(",")))
+        for he in hyperedges
+    ]
+    labels = np.array(
+        [int(b.decode().replace("\n", "").strip()) for b in labels]
+    )
+    label_names = np.array(
+        [b.decode().replace("\n", "").strip() for b in label_names]
+    )
+
+    # Based on: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HypergraphConv.html
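+    # Shift the 1-based node ids in the raw hyperedge file to 0-based indices.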
+    HE_coo = torch.tensor(
+        np.array(
+            [
+                np.concatenate(hyperedges) - 1,
+                np.repeat(
+                    np.arange(len(hyperedges)), [len(he) for he in hyperedges]
+                ),
+            ]
+        )
+    )
+
+    incidence_hyperedges = (
+        SparseTensor(
+            row=HE_coo[0, :],
+            col=HE_coo[1, :],
+            value=torch.tensor(np.ones(HE_coo.shape[1])),
+        )
+        .coalesce()
+        .to_torch_sparse_coo_tensor()
+    )
+
+    return Data(
+        x=torch.empty((len(labels), 0)),
+        edge_index=HE_coo,
+        y=torch.LongTensor(labels),
+        y_names=label_names,
+        num_nodes=len(labels),
+        num_node_features=0,
+        num_edges=len(hyperedges),
+        incidence_hyperedges=incidence_hyperedges,
+        max_dim=cfg.get("max_dim", 2),
+        # x_hyperedges=torch.tensor(np.empty(shape=(len(hyperedges), 0)))
+    )
+
+
 def get_Planetoid_pyg(cfg):
     r"""Loads Planetoid graph datasets from torch_geometric.
 