|
| 1 | +from itertools import permutations |
| 2 | +from typing import ClassVar |
| 3 | + |
| 4 | +import networkx as nx |
| 5 | +import torch |
| 6 | +import torch_geometric |
| 7 | +from toponetx.classes import CellComplex |
| 8 | + |
| 9 | +from modules.transforms.liftings.graph2cell.base import Graph2CellLifting |
| 10 | +from modules.utils.utils import edge_cycle_to_vertex_cycle |
| 11 | + |
| 12 | +Vertex = int |
| 13 | +Edge = tuple[Vertex, Vertex] |
| 14 | +ConfigurationTuple = tuple[Vertex | Edge] |
| 15 | + |
| 16 | + |
| 17 | +class DiscreteConfigurationComplexLifting(Graph2CellLifting): |
| 18 | + r"""Lifts graphs to cell complexes by generating the k-th *discrete configuration complex* $D_k(G)$ of the graph. This is a cube complex, which is similar to a simplicial complex except each n-dimensional cell is homeomorphic to a n-dimensional cube rather than an n-dimensional simplex. |
| 19 | +
|
| 20 | + The discrete configuration complex of order k consists of all sets of k unique edges or vertices of $G$, with the additional constraint that if an edge e is in a cell, then neither of the endpoints of e are in the cell. For examples of different graphs and their configuration complexes, see the tutorial. |
| 21 | +
|
| 22 | + Note that since TopoNetx only supports cell complexes of dimension 2, if you generate a configuration complex of order k > 2 this will only produce the 2-skeleton. |
| 23 | +
|
| 24 | + Parameters |
| 25 | + ---------- |
| 26 | + k: int, |
| 27 | + The order of the configuration complex, i.e. the number of 'agents' in a single configuration. |
| 28 | + preserve_edge_attr : bool, optional |
| 29 | + Whether to preserve edge attributes. Default is True. |
| 30 | + feature_aggregation: str, optional |
| 31 | + For a k-agent configuration, the method by which the features are aggregated. Can be "mean", "sum", or "concat". Default is "concat". |
| 32 | + **kwargs : optional |
| 33 | + Additional arguments for the class. |
| 34 | + """ |
| 35 | + |
| 36 | + def __init__( |
| 37 | + self, |
| 38 | + k: int, |
| 39 | + preserve_edge_attr: bool = True, |
| 40 | + feature_aggregation="concat", |
| 41 | + **kwargs, |
| 42 | + ): |
| 43 | + self.k = k |
| 44 | + self.complex_dim = 2 |
| 45 | + if feature_aggregation not in ["mean", "sum", "concat"]: |
| 46 | + raise ValueError( |
| 47 | + "feature_aggregation must be one of 'mean', 'sum', 'concat'" |
| 48 | + ) |
| 49 | + self.feature_aggregation = feature_aggregation |
| 50 | + super().__init__(preserve_edge_attr=preserve_edge_attr, **kwargs) |
| 51 | + |
| 52 | + def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: |
| 53 | + r"""Applies the full lifting (topology + features) to the input data. |
| 54 | +
|
| 55 | + Parameters |
| 56 | + ---------- |
| 57 | + data : torch_geometric.data.Data |
| 58 | + The input data to be lifted. |
| 59 | +
|
| 60 | + Returns |
| 61 | + ------- |
| 62 | + torch_geometric.data.Data |
| 63 | + The lifted data. |
| 64 | + """ |
| 65 | + # Unlike the base class, we do not pass the initial data to the final data |
| 66 | + # This is because the configuration complex has a completely different 1-skeleton from the original graph |
| 67 | + lifted_topology = self.lift_topology(data) |
| 68 | + lifted_topology = self.feature_lifting(lifted_topology) |
| 69 | + return torch_geometric.data.Data(y=data.y, **lifted_topology) |
| 70 | + |
| 71 | + def lift_topology(self, data: torch_geometric.data.Data) -> dict: |
| 72 | + r"""Generates the cubical complex of discrete graph configurations. |
| 73 | +
|
| 74 | + Parameters |
| 75 | + ---------- |
| 76 | + data : torch_geometric.data.Data |
| 77 | + The input data to be lifted. |
| 78 | +
|
| 79 | + Returns |
| 80 | + ------- |
| 81 | + dict |
| 82 | + The lifted topology. |
| 83 | + """ |
| 84 | + G = self._generate_graph_from_data(data) |
| 85 | + if G.is_directed(): |
| 86 | + raise ValueError("Directed Graphs are not supported.") |
| 87 | + |
| 88 | + Configuration = generate_configuration_class( |
| 89 | + G, self.feature_aggregation, self.contains_edge_attr |
| 90 | + ) |
| 91 | + |
| 92 | + # The vertices of the configuration complex are just tuples of k vertices |
| 93 | + for dim_0_configuration_tuple in permutations(G, self.k): |
| 94 | + configuration = Configuration(dim_0_configuration_tuple) |
| 95 | + configuration.generate_upwards_neighbors() |
| 96 | + |
| 97 | + cells = {i: [] for i in range(self.k + 1)} |
| 98 | + for conf in Configuration.instances.values(): |
| 99 | + features = conf.features() |
| 100 | + attrs = {"features": features} if features is not None else {} |
| 101 | + cell = (conf.contents, attrs) |
| 102 | + cells[conf.dim].append(cell) |
| 103 | + |
| 104 | + # TopoNetX only supports cells of dimension <= 2 |
| 105 | + cc = CellComplex() |
| 106 | + for node, attrs in cells[0]: |
| 107 | + cc.add_node(node, **attrs) |
| 108 | + for edge, attrs in cells[1]: |
| 109 | + cc.add_edge(edge[0], edge[1], **attrs) |
| 110 | + for cell, attrs in cells[2]: |
| 111 | + cell_vertices = edge_cycle_to_vertex_cycle(cell) |
| 112 | + cc.add_cell(cell_vertices, rank=2, **attrs) |
| 113 | + |
| 114 | + return self._get_lifted_topology(cc, G) |
| 115 | + |
| 116 | + |
| 117 | +def generate_configuration_class( |
| 118 | + graph: nx.Graph, feature_aggregation: str, edge_features: bool |
| 119 | +): |
| 120 | + """Class factory for the Configuration class.""" |
| 121 | + |
| 122 | + class Configuration: |
| 123 | + """Represents a single legal configuration of k agents on a graph G. A legal configuration is a tuple of k edges and vertices of G where all the vertices and endpoints are **distinct** i.e. no two edges sharing an endpoint can simultaneously be in the configuration, and adjacent (edge, vertex) pair can be contained in the configuration. Each configuration corresponds to a cell, and the number of edges in the configuration is the dimension. |
| 124 | +
|
| 125 | + Parameters |
| 126 | + ---------- |
| 127 | + k : int, optional. |
| 128 | + The order of the configuration complex, or the number of 'points' in the configuration. |
| 129 | + graph: nx.Graph. |
| 130 | + The graph on which the configurations are defined. |
| 131 | + """ |
| 132 | + |
| 133 | + instances: ClassVar[dict[ConfigurationTuple, "Configuration"]] = {} |
| 134 | + |
| 135 | + def __new__(cls, configuration_tuple: ConfigurationTuple): |
| 136 | + # Ensure that a configuration tuple corresponds to a *unique* configuration object |
| 137 | + key = configuration_tuple |
| 138 | + if key not in cls.instances: |
| 139 | + cls.instances[key] = super().__new__(cls) |
| 140 | + |
| 141 | + return cls.instances[key] |
| 142 | + |
| 143 | + def __init__(self, configuration_tuple: ConfigurationTuple) -> None: |
| 144 | + # If this object was already initialized earlier, maintain current state |
| 145 | + if hasattr(self, "initialized"): |
| 146 | + return |
| 147 | + |
| 148 | + self.initialized = True |
| 149 | + self.configuration_tuple = configuration_tuple |
| 150 | + self.neighborhood = set() |
| 151 | + self.dim = 0 |
| 152 | + for agent in configuration_tuple: |
| 153 | + if isinstance(agent, Vertex): |
| 154 | + self.neighborhood.add(agent) |
| 155 | + else: |
| 156 | + self.neighborhood.update(set(agent)) |
| 157 | + self.dim += 1 |
| 158 | + |
| 159 | + if self.dim == 0: |
| 160 | + self.contents = configuration_tuple |
| 161 | + else: |
| 162 | + self.contents = [] |
| 163 | + |
| 164 | + self._upwards_neighbors_generated = False |
| 165 | + |
| 166 | + def features(self): |
| 167 | + """Generate the features for the configuration by combining the edge and vertex features.""" |
| 168 | + features = [] |
| 169 | + for agent in self.configuration_tuple: |
| 170 | + if isinstance(agent, Vertex): |
| 171 | + features.append(graph.nodes[agent]["features"]) |
| 172 | + elif edge_features: |
| 173 | + features.append(graph.edges[agent]["features"]) |
| 174 | + |
| 175 | + if not features: |
| 176 | + return None |
| 177 | + |
| 178 | + if feature_aggregation == "mean": |
| 179 | + try: |
| 180 | + return torch.stack(features, dim=0).mean(dim=0) |
| 181 | + except Exception as e: |
| 182 | + raise ValueError( |
| 183 | + "Failed to mean feature tensors. This may be because edge features and vertex features have different shapes. If this is the case, use feature_aggregation='concat', or disable edge features." |
| 184 | + ) from e |
| 185 | + elif feature_aggregation == "sum": |
| 186 | + try: |
| 187 | + return torch.stack(features, dim=0).sum(dim=0) |
| 188 | + except Exception as e: |
| 189 | + raise ValueError( |
| 190 | + "Failed to sum feature tensors. This may be because edge features and vertex features have different shapes. If this is the case, use feature_aggregation='concat', or disable edge features." |
| 191 | + ) from e |
| 192 | + elif feature_aggregation == "concat": |
| 193 | + return torch.concatenate(features, dim=-1) |
| 194 | + else: |
| 195 | + raise ValueError( |
| 196 | + f"Unrecognized feature_aggregation: {feature_aggregation}" |
| 197 | + ) |
| 198 | + |
| 199 | + def generate_upwards_neighbors(self): |
| 200 | + """For the configuration self of dimension d, generate the configurations of dimension d+1 containing it.""" |
| 201 | + if self._upwards_neighbors_generated: |
| 202 | + return |
| 203 | + self._upwards_neighbors_generated = True |
| 204 | + for i, agent in enumerate(self.configuration_tuple): |
| 205 | + if isinstance(agent, Vertex): |
| 206 | + for neighbor in graph[agent]: |
| 207 | + self._generate_single_neighbor(i, agent, neighbor) |
| 208 | + |
| 209 | + def _generate_single_neighbor( |
| 210 | + self, index: int, vertex_agent: int, neighbor: int |
| 211 | + ): |
| 212 | + """Generate a configuration containing self by moving an agent from a vertex onto an edge.""" |
| 213 | + # If adding the edge (vertex_agent, neighbor) would produce an illegal configuration, ignore it |
| 214 | + if neighbor in self.neighborhood: |
| 215 | + return |
| 216 | + |
| 217 | + # We always orient edges (min -> max) to maintain uniqueness of configuration tuples |
| 218 | + new_edge = (min(vertex_agent, neighbor), max(vertex_agent, neighbor)) |
| 219 | + |
| 220 | + # Remove the vertex at index and replace it with new edge |
| 221 | + new_configuration_tuple = ( |
| 222 | + *self.configuration_tuple[:index], |
| 223 | + new_edge, |
| 224 | + *self.configuration_tuple[index + 1 :], |
| 225 | + ) |
| 226 | + new_configuration = Configuration(new_configuration_tuple) |
| 227 | + new_configuration.contents.append(self.contents) |
| 228 | + new_configuration.generate_upwards_neighbors() |
| 229 | + |
| 230 | + return Configuration |
0 commit comments