Skip to content

Commit e0f2c36

Browse files
committed
Merge branch 'main' of github.com:ManuelLecha/challenge-icml-2024 into ManuelLecha-main
2 parents c052610 + 71054c9 commit e0f2c36

File tree

12 files changed

+1940
-48
lines changed

12 files changed

+1940
-48
lines changed

configs/models/combinatorial/spcc.yaml

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Defines the default parameters of the implemented transform.
2+
transform_type: 'lifting'
3+
transform_name: "SimplicialPathsLifting"
4+
d1: 2
5+
d2: 2
6+
q: 1
7+
i: 1
8+
j: 2
9+
complex_dim: 2
10+
chunk_size: 1024
11+
threshold: 1

modules/models/combinatorial/spcc.py

+287
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
import torch
2+
from topomodelx.base.aggregation import Aggregation
3+
from topomodelx.nn.combinatorial.hmc_layer import HBNS, HBS
4+
5+
6+
class SPCCLayer(torch.nn.Module):
7+
r"""Simplicial Paths Combinatorial Complex Layer
8+
9+
We aim to exploit the inherent directed graph to induce higher-order motifs
10+
in the form of simplicial paths.
11+
12+
This simple layer is only build for testing purposes: We consider a
13+
combinatorial complex with 0-dimensional cells (vertices), 1-dimensional
14+
cells (edges), 2-dimensional cells (collections of nodes contained in
15+
simplicial paths)
16+
17+
Message passing: 0-dimensional cells (vertices) receive messages from
18+
0-dimensional cells (vertices) and from 2-dimensional cells (collections
19+
of nodes contained in simplicial paths).In the first case, adjacency
20+
matrices are used. In the second case, the incidence matrix from
21+
dimension 1 to dimension 2 is used.
22+
23+
Notes
24+
-----
25+
This is a simple layer only build for testing purposes.
26+
27+
Parameters
28+
----------
29+
in_channels : list of int
30+
Dimension of input features on vertices (0-cells), and simplicial
31+
paths (3-cells). The length of the list must be 2.
32+
33+
out_channels : list of int
34+
Dimension of output features on vertices (0-cells) and simplicial
35+
paths (3-cells). The length of the list must be 2.
36+
37+
negative_slope : float
38+
Negative slope of LeakyReLU used to compute the attention
39+
coefficients.
40+
41+
softmax_attention : bool, optional
42+
Whether to use softmax attention. If True, the attention
43+
coefficients are normalized by rows using softmax over all the
44+
columns that are not zero in the associated neighborhood
45+
matrix. If False, the normalization is done by dividing by the
46+
sum of the values of the coefficients in its row whose columns
47+
are not zero in the associated neighborhood matrix. Default is
48+
False.
49+
50+
update_func_attention : string, optional
51+
Activation function used in the attention block. If None,
52+
no activation function is applied. Default is None.
53+
54+
update_func_aggregation : string, optional
55+
Function used to aggregate the messages computed in each
56+
attention block. If None, the messages are aggregated by summing
57+
them. Default is None.
58+
59+
initialization : {'xavier_uniform', 'xavier_normal'}, optional
60+
Initialization method for the weights of the attention layers.
61+
Default is 'xavier_uniform'.
62+
"""
63+
64+
def __init__(
65+
self,
66+
in_channels: list[int],
67+
out_channels: list[int],
68+
negative_slope: float,
69+
softmax_attention=False,
70+
update_func_attention=None,
71+
update_func_aggregation=None,
72+
initialization="xavier_uniform",
73+
):
74+
super().__init__()
75+
super().__init__()
76+
77+
assert len(in_channels) == 2 and len(out_channels) == 2
78+
79+
in_channels_0, in_channels_2 = in_channels
80+
out_channels_0, out_channels_2 = out_channels
81+
82+
self.hbs_0 = HBS(
83+
source_in_channels=in_channels_0,
84+
source_out_channels=out_channels_0,
85+
negative_slope=negative_slope,
86+
softmax=softmax_attention,
87+
update_func=update_func_attention,
88+
initialization=initialization,
89+
)
90+
91+
self.hbns_0_2 = HBNS(
92+
source_in_channels=in_channels_2,
93+
source_out_channels=out_channels_2,
94+
target_in_channels=in_channels_0,
95+
target_out_channels=out_channels_0,
96+
negative_slope=negative_slope,
97+
softmax=softmax_attention,
98+
update_func=update_func_attention,
99+
initialization=initialization,
100+
)
101+
102+
self.aggr = Aggregation(aggr_func="sum", update_func=update_func_aggregation)
103+
104+
def forward(self, x_0, x_2, adjacency_0, incidence_0_2):
105+
r"""Forward pass.
106+
107+
In both message passing levels, :math:`\phi_u` and :math:`\phi_a`
108+
represent common activation functions within and between neighborhood
109+
aggregations. Both are passed to the constructor of the class as
110+
arguments update_func_attention and update_func_aggregation,
111+
respectively.
112+
113+
Parameters
114+
----------
115+
x_0 : torch.Tensor, shape=[n_0_cells, in_channels[0]]
116+
Input features on the 0-cells (vertices) of the combinatorial
117+
complex.
118+
x_2 : torch.Tensor, shape=[n_3_cells, in_channels[3]]
119+
Input features on the 3-cells (simplicial paths) of the combinatorial
120+
complex.
121+
122+
adjacency_0 : torch.sparse
123+
shape=[n_0_cells, n_0_cells]
124+
Neighborhood matrix mapping 0-cells to 0-cells (A_0_up).
125+
126+
incidence_0_2 : torch.sparse
127+
shape=[n_0_cells, n_3_cells]
128+
Neighborhood matrix mapping 3-cells to 0-cells (B_3).
129+
130+
Returns
131+
-------
132+
_ : torch.Tensor, shape=[1, num_classes]
133+
Output prediction on the entire cell complex.
134+
"""
135+
136+
# Computing messages from the Simplicial Path Attention Block
137+
138+
x_0_to_0 = self.hbs_0(x_0, adjacency_0)
139+
x_0_to_2, x_2_to_0 = self.hbns_0_2(x_2, x_0, incidence_0_2)
140+
141+
x_0 = self.aggr([x_0_to_0, x_2_to_0])
142+
x_2 = self.aggr([x_0_to_2])
143+
144+
return x_0, x_2
145+
146+
147+
class SPCC(torch.nn.Module):
148+
"""Simplicial Paths Combinatorial Complex Attention Network.
149+
150+
Parameters
151+
----------
152+
channels_per_layer : list of list of list of int
153+
Number of input, and output channels for each
154+
Simplicial Paths Combinatorial Complex Attention Layer.
155+
The length of the list corresponds to the number of layers.
156+
Each element k of the list is a list consisting of other 2
157+
lists. The first list contains the number of input channels for
158+
each input signal (nodes, sp_cells) for the k-th layer.
159+
The second list contains the number of output channels for
160+
each input signal (nodes, sp_cells) for the k-th layer.
161+
negative_slope : float
162+
Negative slope for the LeakyReLU activation.
163+
update_func_attention : str
164+
Update function for the attention mechanism. Default is "relu".
165+
update_func_aggregation : str
166+
Update function for the aggregation mechanism. Default is "relu".
167+
"""
168+
169+
def __init__(
170+
self,
171+
channels_per_layer,
172+
negative_slope=0.2,
173+
update_func_attention="relu",
174+
update_func_aggregation="relu",
175+
) -> None:
176+
def check_channels_consistency():
177+
"""Check that the number of input, and output
178+
channels is consistent."""
179+
assert len(channels_per_layer) > 0
180+
for i in range(len(channels_per_layer) - 1):
181+
assert channels_per_layer[i][2][0] == channels_per_layer[i + 1][0][0]
182+
assert channels_per_layer[i][2][1] == channels_per_layer[i + 1][0][1]
183+
184+
super().__init__()
185+
check_channels_consistency()
186+
self.layers = torch.nn.ModuleList(
187+
[
188+
SPCCLayer(
189+
in_channels=in_channels,
190+
out_channels=out_channels,
191+
negative_slope=negative_slope,
192+
softmax_attention=True,
193+
update_func_attention=update_func_attention,
194+
update_func_aggregation=update_func_aggregation,
195+
)
196+
for in_channels, out_channels in channels_per_layer
197+
]
198+
)
199+
200+
def forward(
201+
self,
202+
x_0,
203+
x_2,
204+
neighborhood_0_to_0,
205+
neighborhood_0_to_2,
206+
) -> tuple[torch.Tensor, torch.Tensor]:
207+
"""Forward pass.
208+
209+
Parameters
210+
----------
211+
x_0 : torch.Tensor
212+
Input features on nodes.
213+
x_2 : torch.Tensor
214+
Input features on simplicial paths.
215+
neighborhood_0_to_0 : torch.Tensor
216+
Adjacency matrix from nodes to nodes.
217+
neighborhood_0_to_2 : torch.Tensor
218+
Incidence matrix from nodes to simplicial path cells.
219+
220+
Returns
221+
-------
222+
torch.Tensor, shape = (n_nodes, out_channels_0)
223+
Final hidden states of the nodes (0-cells).
224+
torch.Tensor, shape = (n_spcells, out_channels_3)
225+
Final hidden states of the faces (3-cells).
226+
"""
227+
for layer in self.layers:
228+
x_0, x_2 = layer(x_0, x_2, neighborhood_0_to_0, neighborhood_0_to_2)
229+
230+
return x_0, x_2
231+
232+
233+
class SPCCNN(torch.nn.Module):
234+
"""Simplicial Paths Combinatorial Complex Attention Network Model For
235+
Node Classification.
236+
237+
Parameters
238+
----------
239+
channels_per_layer : list of list of list of int
240+
Number of input, and output channels for each
241+
Simplicial Paths Combinatorial Complex Attention Layer.
242+
The length of the list corresponds to the number of layers.
243+
Each element k of the list is a list consisting of other 2
244+
lists. The first list contains the number of input channels for
245+
each input signal (nodes, sp_cells) for the k-th layer.
246+
The second list contains the number of output channels for
247+
each input signal (nodes, sp_cells) for the k-th layer.
248+
out_channels_0 : int
249+
Number of output channels for the 0-cells (classes)
250+
negative_slope : float
251+
Negative slope for the LeakyReLU activation.
252+
253+
Returns
254+
-------
255+
torch.Tensor, shape = (n_nodes, out_channels_0)
256+
Final probability states of the nodes (0-cells).
257+
"""
258+
259+
def __init__(
260+
self,
261+
channels_per_layer,
262+
out_channels_0,
263+
negative_slope=0.2,
264+
):
265+
super().__init__()
266+
self.base_model = SPCC(
267+
channels_per_layer,
268+
negative_slope,
269+
)
270+
271+
self.linear = torch.nn.Linear(channels_per_layer[-1][1][0], out_channels_0)
272+
273+
def forward(self, data):
274+
x_0 = data["x_0"]
275+
x_2 = data["x_2"]
276+
neighborhood_0_to_0 = data["adjacency_0_1"]
277+
neighborhood_0_to_2 = data["incidence_0_2"]
278+
279+
x_0, _ = self.base_model(
280+
x_0,
281+
x_2,
282+
neighborhood_0_to_0,
283+
neighborhood_0_to_2,
284+
)
285+
286+
x_0 = self.linear(x_0)
287+
return torch.softmax(x_0, dim=1)

modules/transforms/data_transform.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from modules.transforms.liftings.graph2combinatorial.ring_close_atoms_lifting import (
1818
CombinatorialRingCloseAtomsLifting,
1919
)
20+
from modules.transforms.liftings.graph2combinatorial.sp_lifting import (
21+
SimplicialPathsLifting,
22+
)
2023
from modules.transforms.liftings.graph2hypergraph.expander_graph_lifting import (
2124
ExpanderGraphLifting,
2225
)
@@ -88,6 +91,7 @@
8891
# Graph -> Combinatorial Complex
8992
"CombinatorialRingCloseAtomsLifting": CombinatorialRingCloseAtomsLifting,
9093
"CurveLifting": CurveLifting,
94+
"SimplicialPathsLifting": SimplicialPathsLifting,
9195
# Point Cloud -> Simplicial Complex,
9296
"AlphaComplexLifting": AlphaComplexLifting,
9397
# Point-cloud -> Simplicial Complex

0 commit comments

Comments
 (0)