|
25 | 25 | fetch_spiral_2d,
|
26 | 26 | )
|
27 | 27 | from topomodelx.utils.sparse import from_sparse
|
| 28 | +from toponetx.classes.combinatorial_complex import CombinatorialComplex |
28 | 29 | from torch_geometric.data import Data
|
29 | 30 | from torch_geometric.datasets import GeometricShapes
|
30 | 31 | from torch_sparse import SparseTensor, coalesce
|
@@ -71,6 +72,90 @@ def get_ccc_connectivity(complex, max_rank):
|
71 | 72 | return connectivity
|
72 | 73 |
|
73 | 74 |
|
| 75 | +def get_combinatorial_complex_connectivity_2( |
| 76 | + complex: CombinatorialComplex, max_rank, signed=False |
| 77 | +): |
| 78 | + r"""Gets the connectivity matrices for the Combinatorial Complex. |
| 79 | +
|
| 80 | + Parameters |
| 81 | + ---------- |
| 82 | + complex : topnetx.CombinatorialComplex |
| 83 | + Cell complex. |
| 84 | + max_rank : int |
| 85 | + Maximum rank of the complex. |
| 86 | + signed : bool |
| 87 | + If True, returns signed connectivity matrices. |
| 88 | +
|
| 89 | + Returns |
| 90 | + ------- |
| 91 | + dict |
| 92 | + Dictionary containing the connectivity matrices. |
| 93 | + """ |
| 94 | + practical_shape = list( |
| 95 | + np.pad(list(complex.shape), (0, max_rank + 1 - len(complex.shape))) |
| 96 | + ) |
| 97 | + connectivity = {} |
| 98 | + for rank_idx in range(max_rank + 1): |
| 99 | + for connectivity_info in [ |
| 100 | + "incidence", |
| 101 | + "laplacian", |
| 102 | + "adjacency", |
| 103 | + ]: |
| 104 | + try: |
| 105 | + if connectivity_info == "laplacian": |
| 106 | + connectivity[f"{connectivity_info}_{rank_idx}"] = ( |
| 107 | + from_sparse(complex.laplacian_matrix(rank=rank_idx)) |
| 108 | + ) |
| 109 | + elif connectivity_info == "adjacency": |
| 110 | + connectivity[f"{connectivity_info}_{rank_idx}"] = ( |
| 111 | + from_sparse( |
| 112 | + getattr(complex, f"{connectivity_info}_matrix")( |
| 113 | + rank_idx, rank_idx + 1 |
| 114 | + ) |
| 115 | + ) |
| 116 | + ) |
| 117 | + else: # incidence |
| 118 | + connectivity[f"{connectivity_info}_{rank_idx}"] = ( |
| 119 | + from_sparse( |
| 120 | + getattr(complex, f"{connectivity_info}_matrix")( |
| 121 | + rank_idx - 1, rank_idx |
| 122 | + ) |
| 123 | + ) |
| 124 | + ) |
| 125 | + except ValueError: # noqa: PERF203 |
| 126 | + if connectivity_info == "incidence": |
| 127 | + connectivity[f"{connectivity_info}_{rank_idx}"] = ( |
| 128 | + generate_zero_sparse_connectivity( |
| 129 | + m=practical_shape[rank_idx - 1], |
| 130 | + n=practical_shape[rank_idx], |
| 131 | + ) |
| 132 | + ) |
| 133 | + else: |
| 134 | + connectivity[f"{connectivity_info}_{rank_idx}"] = ( |
| 135 | + generate_zero_sparse_connectivity( |
| 136 | + m=practical_shape[rank_idx], |
| 137 | + n=practical_shape[rank_idx], |
| 138 | + ) |
| 139 | + ) |
| 140 | + except AttributeError: |
| 141 | + if connectivity_info == "incidence": |
| 142 | + connectivity[f"{connectivity_info}_{rank_idx}"] = ( |
| 143 | + generate_zero_sparse_connectivity( |
| 144 | + m=practical_shape[rank_idx - 1], |
| 145 | + n=practical_shape[rank_idx], |
| 146 | + ) |
| 147 | + ) |
| 148 | + else: |
| 149 | + connectivity[f"{connectivity_info}_{rank_idx}"] = ( |
| 150 | + generate_zero_sparse_connectivity( |
| 151 | + m=practical_shape[rank_idx], |
| 152 | + n=practical_shape[rank_idx], |
| 153 | + ) |
| 154 | + ) |
| 155 | + connectivity["shape"] = practical_shape |
| 156 | + return connectivity |
| 157 | + |
| 158 | + |
74 | 159 | def get_complex_connectivity(complex, max_rank, signed=False):
|
75 | 160 | r"""Gets the connectivity matrices for the complex.
|
76 | 161 |
|
@@ -474,6 +559,46 @@ def load_point_cloud(
|
474 | 559 | return torch_geometric.data.Data(x=features, y=classes, pos=points)
|
475 | 560 |
|
476 | 561 |
|
| 562 | +def load_manual_simplicial_complex(): |
| 563 | + """Create a manual simplicial complex for testing purposes.""" |
| 564 | + num_feats = 2 |
| 565 | + one_cells = [i for i in range(5)] |
| 566 | + two_cells = [[0, 1], [0, 2], [1, 2], [1, 3], [2, 3], [0, 4], [2, 4]] |
| 567 | + three_cells = [[0, 1, 2], [1, 2, 3], [0, 2, 4]] |
| 568 | + incidence_1 = [ |
| 569 | + [1, 1, 0, 0, 0, 1, 0], |
| 570 | + [1, 0, 1, 1, 0, 0, 0], |
| 571 | + [0, 1, 1, 0, 1, 0, 1], |
| 572 | + [0, 0, 0, 1, 1, 0, 0], |
| 573 | + [0, 0, 0, 0, 0, 1, 1], |
| 574 | + ] |
| 575 | + incidence_2 = [ |
| 576 | + [1, 0, 0], |
| 577 | + [1, 0, 1], |
| 578 | + [1, 1, 0], |
| 579 | + [0, 1, 0], |
| 580 | + [0, 1, 0], |
| 581 | + [0, 0, 1], |
| 582 | + [0, 0, 1], |
| 583 | + ] |
| 584 | + |
| 585 | + y = [1] |
| 586 | + |
| 587 | + return torch_geometric.data.Data( |
| 588 | + x_0=torch.rand(len(one_cells), num_feats), |
| 589 | + x_1=torch.rand(len(two_cells), num_feats), |
| 590 | + x_2=torch.rand(len(three_cells), num_feats), |
| 591 | + incidence_0=torch.zeros((1, 5)).to_sparse(), |
| 592 | + adjacency_1=torch.zeros((len(one_cells), len(one_cells))).to_sparse(), |
| 593 | + adjacency_2=torch.zeros((len(two_cells), len(two_cells))).to_sparse(), |
| 594 | + adjacency_0=torch.zeros((5, 5)).to_sparse(), |
| 595 | + incidence_1=torch.tensor(incidence_1).to_sparse(), |
| 596 | + incidence_2=torch.tensor(incidence_2).to_sparse(), |
| 597 | + num_nodes=len(one_cells), |
| 598 | + y=torch.tensor(y), |
| 599 | + ) |
| 600 | + |
| 601 | + |
477 | 602 | def load_manual_graph():
|
478 | 603 | """Create a manual graph for testing purposes."""
|
479 | 604 | # Define the vertices (just 8 vertices)
|
|
0 commit comments