|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +================================= |
| 4 | +Plot graphs' barycenter using FGW |
| 5 | +================================= |
| 6 | +
|
| 7 | +This example illustrates the computation barycenter of labeled graphs using FGW |
| 8 | +
|
| 9 | +Requires networkx >=2 |
| 10 | +
|
| 11 | +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain |
| 12 | + and Courty Nicolas |
| 13 | + "Optimal Transport for structured data with application on graphs" |
| 14 | + International Conference on Machine Learning (ICML). 2019. |
| 15 | +
|
| 16 | +""" |
| 17 | + |
| 18 | +# Author: Titouan Vayer <[email protected]> |
| 19 | +# |
| 20 | +# License: MIT License |
| 21 | + |
| 22 | +#%% load libraries |
| 23 | +import numpy as np |
| 24 | +import matplotlib.pyplot as plt |
| 25 | +import networkx as nx |
| 26 | +import math |
| 27 | +from scipy.sparse.csgraph import shortest_path |
| 28 | +import matplotlib.colors as mcol |
| 29 | +from matplotlib import cm |
| 30 | +from ot.gromov import fgw_barycenters |
| 31 | +#%% Graph functions |
| 32 | + |
| 33 | + |
| 34 | +def find_thresh(C, inf=0.5, sup=3, step=10): |
| 35 | + """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected |
| 36 | + Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. |
| 37 | + The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix |
| 38 | + and the original matrix. |
| 39 | + Parameters |
| 40 | + ---------- |
| 41 | + C : ndarray, shape (n_nodes,n_nodes) |
| 42 | + The structure matrix to threshold |
| 43 | + inf : float |
| 44 | + The beginning of the linesearch |
| 45 | + sup : float |
| 46 | + The end of the linesearch |
| 47 | + step : integer |
| 48 | + Number of thresholds tested |
| 49 | + """ |
| 50 | + dist = [] |
| 51 | + search = np.linspace(inf, sup, step) |
| 52 | + for thresh in search: |
| 53 | + Cprime = sp_to_adjency(C, 0, thresh) |
| 54 | + SC = shortest_path(Cprime, method='D') |
| 55 | + SC[SC == float('inf')] = 100 |
| 56 | + dist.append(np.linalg.norm(SC - C)) |
| 57 | + return search[np.argmin(dist)], dist |
| 58 | + |
| 59 | + |
| 60 | +def sp_to_adjency(C, threshinf=0.2, threshsup=1.8): |
| 61 | + """ Thresholds the structure matrix in order to compute an adjency matrix. |
| 62 | + All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 |
| 63 | + Parameters |
| 64 | + ---------- |
| 65 | + C : ndarray, shape (n_nodes,n_nodes) |
| 66 | + The structure matrix to threshold |
| 67 | + threshinf : float |
| 68 | + The minimum value of distance from which the new value is set to 1 |
| 69 | + threshsup : float |
| 70 | + The maximum value of distance from which the new value is set to 1 |
| 71 | + Returns |
| 72 | + ------- |
| 73 | + C : ndarray, shape (n_nodes,n_nodes) |
| 74 | + The threshold matrix. Each element is in {0,1} |
| 75 | + """ |
| 76 | + H = np.zeros_like(C) |
| 77 | + np.fill_diagonal(H, np.diagonal(C)) |
| 78 | + C = C - H |
| 79 | + C = np.minimum(np.maximum(C, threshinf), threshsup) |
| 80 | + C[C == threshsup] = 0 |
| 81 | + C[C != 0] = 1 |
| 82 | + |
| 83 | + return C |
| 84 | + |
| 85 | + |
| 86 | +def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None): |
| 87 | + """ Create a noisy circular graph |
| 88 | + """ |
| 89 | + g = nx.Graph() |
| 90 | + g.add_nodes_from(list(range(N))) |
| 91 | + for i in range(N): |
| 92 | + noise = float(np.random.normal(mu, sigma, 1)) |
| 93 | + if with_noise: |
| 94 | + g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) |
| 95 | + else: |
| 96 | + g.add_node(i, attr_name=math.sin(2 * i * math.pi / N)) |
| 97 | + g.add_edge(i, i + 1) |
| 98 | + if structure_noise: |
| 99 | + randomint = np.random.randint(0, p) |
| 100 | + if randomint == 0: |
| 101 | + if i <= N - 3: |
| 102 | + g.add_edge(i, i + 2) |
| 103 | + if i == N - 2: |
| 104 | + g.add_edge(i, 0) |
| 105 | + if i == N - 1: |
| 106 | + g.add_edge(i, 1) |
| 107 | + g.add_edge(N, 0) |
| 108 | + noise = float(np.random.normal(mu, sigma, 1)) |
| 109 | + if with_noise: |
| 110 | + g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) |
| 111 | + else: |
| 112 | + g.add_node(N, attr_name=math.sin(2 * N * math.pi / N)) |
| 113 | + return g |
| 114 | + |
| 115 | + |
| 116 | +def graph_colors(nx_graph, vmin=0, vmax=7): |
| 117 | + cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) |
| 118 | + cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis') |
| 119 | + cpick.set_array([]) |
| 120 | + val_map = {} |
| 121 | + for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items(): |
| 122 | + val_map[k] = cpick.to_rgba(v) |
| 123 | + colors = [] |
| 124 | + for node in nx_graph.nodes(): |
| 125 | + colors.append(val_map[node]) |
| 126 | + return colors |
| 127 | + |
| 128 | +############################################################################## |
| 129 | +# Generate data |
| 130 | +# ------------- |
| 131 | + |
| 132 | +#%% circular dataset |
| 133 | +# We build a dataset of noisy circular graphs. |
| 134 | +# Noise is added on the structures by random connections and on the features by gaussian noise. |
| 135 | + |
| 136 | + |
| 137 | +np.random.seed(30) |
| 138 | +X0 = [] |
| 139 | +for k in range(9): |
| 140 | + X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) |
| 141 | + |
| 142 | +############################################################################## |
| 143 | +# Plot data |
| 144 | +# --------- |
| 145 | + |
| 146 | +#%% Plot graphs |
| 147 | + |
| 148 | +plt.figure(figsize=(8, 10)) |
| 149 | +for i in range(len(X0)): |
| 150 | + plt.subplot(3, 3, i + 1) |
| 151 | + g = X0[i] |
| 152 | + pos = nx.kamada_kawai_layout(g) |
| 153 | + nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100) |
| 154 | +plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) |
| 155 | +plt.show() |
| 156 | + |
| 157 | +############################################################################## |
| 158 | +# Barycenter computation |
| 159 | +# ---------------------- |
| 160 | + |
| 161 | +#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph |
| 162 | +# Features distances are the euclidean distances |
| 163 | +Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] |
| 164 | +ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] |
| 165 | +Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0] |
| 166 | +lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() |
| 167 | +sizebary = 15 # we choose a barycenter with 15 nodes |
| 168 | + |
| 169 | +A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95) |
| 170 | + |
| 171 | +############################################################################## |
| 172 | +# Plot Barycenter |
| 173 | +# ------------------------- |
| 174 | + |
| 175 | +#%% Create the barycenter |
| 176 | +bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) |
| 177 | +for i, v in enumerate(A.ravel()): |
| 178 | + bary.add_node(i, attr_name=v) |
| 179 | + |
| 180 | +#%% |
| 181 | +pos = nx.kamada_kawai_layout(bary) |
| 182 | +nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False) |
| 183 | +plt.suptitle('Barycenter', fontsize=20) |
| 184 | +plt.show() |
0 commit comments