Skip to content

Commit 5a6b226

Browse files
authored
Merge pull request #86 from tvayer/master
[MRG] Gromov-Wasserstein closed form for linesearch and integration of Fused Gromov-Wasserstein This PR closes #82 Thank you @tvayer for all the work.
2 parents f66ab58 + 788a650 commit 5a6b226

File tree

9 files changed

+972
-51
lines changed

9 files changed

+972
-51
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ The contributors to this library are:
164164
* Erwan Vautier (Gromov-Wasserstein)
165165
* [Kilian Fatras](https://kilianfatras.github.io/)
166166
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
167+
* [Vayer Titouan](https://tvayer.github.io/)
167168

168169
This toolbox benefits a lot from open source research, and we would like to thank the following people for providing some code (in various languages):
169170

@@ -233,3 +234,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
233234
[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
234235

235236
[23] Genevay, A., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
237+
238+
[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).

examples/plot_barycenter_fgw.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=================================
4+
Plot graphs' barycenter using FGW
5+
=================================
6+
7+
This example illustrates the computation of the barycenter of labeled graphs using FGW
8+
9+
Requires networkx >=2
10+
11+
.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
12+
and Courty Nicolas
13+
"Optimal Transport for structured data with application on graphs"
14+
International Conference on Machine Learning (ICML). 2019.
15+
16+
"""
17+
18+
# Author: Titouan Vayer <[email protected]>
19+
#
20+
# License: MIT License
21+
22+
#%% load libraries
23+
import numpy as np
24+
import matplotlib.pyplot as plt
25+
import networkx as nx
26+
import math
27+
from scipy.sparse.csgraph import shortest_path
28+
import matplotlib.colors as mcol
29+
from matplotlib import cm
30+
from ot.gromov import fgw_barycenters
31+
#%% Graph functions
32+
33+
34+
def find_thresh(C, inf=0.5, sup=3, step=10):
    """Find the threshold that best turns a structure matrix into a graph.

    ``step`` candidate thresholds, evenly spaced between ``inf`` and ``sup``,
    are tested. For each candidate the matrix is thresholded into an
    adjacency matrix with ``sp_to_adjency(C, 0, thresh)``, the shortest-path
    matrix of that graph is computed, and its distance to ``C`` (as given by
    ``np.linalg.norm``) is recorded. The candidate with the smallest
    reconstruction error is kept.

    Parameters
    ----------
    C : ndarray, shape (n_nodes, n_nodes)
        The structure matrix to threshold
    inf : float
        The beginning of the linesearch
    sup : float
        The end of the linesearch
    step : integer
        Number of thresholds tested

    Returns
    -------
    thresh : float
        The tested threshold with the smallest reconstruction error
    dist : list of float
        The reconstruction error of every tested threshold, in search order
    """
    candidates = np.linspace(inf, sup, step)
    errors = []
    for upper in candidates:
        adjacency = sp_to_adjency(C, 0, upper)
        paths = shortest_path(adjacency, method='D')
        # Pairs with no connecting path come back as +inf; they are replaced
        # by the fixed value 100 so that the norm below stays finite.
        paths[paths == float('inf')] = 100
        errors.append(np.linalg.norm(paths - C))
    best = np.argmin(errors)
    return candidates[best], errors
58+
59+
60+
def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
    """Threshold a structure matrix into a binary adjacency matrix.

    Off-diagonal entries ``c`` with ``threshinf <= c < threshsup`` and
    ``c != 0`` are considered to represent connected nodes and are set to 1.
    Every other entry, including the whole diagonal, is set to 0.
    The input matrix is not modified.

    Parameters
    ----------
    C : ndarray, shape (n_nodes, n_nodes)
        The structure matrix to threshold
    threshinf : float
        The smallest distance that still counts as a connection
    threshsup : float
        Distances greater than or equal to this value are not connections

    Returns
    -------
    C : ndarray, shape (n_nodes, n_nodes)
        The thresholded matrix, with the dtype of the input.
        Each element is in {0, 1}
    """
    # Work on a copy so the caller's matrix is left untouched.
    dist = np.array(C)
    # Self-distances never create an edge.
    np.fill_diagonal(dist, 0)
    # The former clip-based code raised values below `threshinf` up to
    # `threshinf`, so with a non-zero `threshinf` they (and the diagonal) were
    # wrongly reported as edges. The explicit mask follows the documented
    # contract and gives the same result as before when `threshinf` is 0.
    connected = (dist >= threshinf) & (dist < threshsup) & (dist != 0)
    return connected.astype(dist.dtype)
84+
85+
86+
def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
    """Create a (possibly noisy) circular graph with sine-valued node features.

    The graph has ``N + 1`` nodes labelled ``0..N`` joined in a ring
    (``i -- i + 1`` for every ``i < N``, plus ``N -- 0``). Node ``i`` stores
    ``sin(2 * pi * i / N)`` in its ``attr_name`` attribute.

    Parameters
    ----------
    N : int
        Index of the last node (the graph has ``N + 1`` nodes)
    mu : float
        Mean of the gaussian feature noise
    sigma : float
        Standard deviation of the gaussian feature noise
    with_noise : bool
        If True, gaussian noise is added to every node feature
    structure_noise : bool
        If True, each node ``i < N`` gets one extra "skip" edge whenever
        ``np.random.randint(0, p)`` returns 0
    p : int, optional
        Upper bound (exclusive) of the random draw used for the structure
        noise. Required when ``structure_noise`` is True

    Returns
    -------
    g : networkx.Graph
        The generated graph

    Raises
    ------
    ValueError
        If ``structure_noise`` is True and ``p`` is None
    """
    if structure_noise and p is None:
        raise ValueError("p must be given when structure_noise is True")

    g = nx.Graph()
    g.add_nodes_from(list(range(N)))
    for i in range(N):
        # The sample is always drawn (even if unused) so the consumption of
        # random numbers is unchanged. `.item()` replaces `float()` on a
        # 1-element array, which is deprecated in recent NumPy versions.
        noise = np.random.normal(mu, sigma, 1).item()
        if with_noise:
            g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
        else:
            g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
        g.add_edge(i, i + 1)
        if structure_noise:
            randomint = np.random.randint(0, p)
            if randomint == 0:
                # Exactly one of these cases applies for i in 0..N-1.
                if i <= N - 3:
                    g.add_edge(i, i + 2)
                elif i == N - 2:
                    g.add_edge(i, 0)
                elif i == N - 1:
                    g.add_edge(i, 1)
    # Close the ring and set the feature of the last node.
    g.add_edge(N, 0)
    noise = np.random.normal(mu, sigma, 1).item()
    if with_noise:
        g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
    else:
        g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
    return g
114+
115+
116+
def graph_colors(nx_graph, vmin=0, vmax=7):
    """Return one RGBA color per node, derived from its ``attr_name`` value.

    Values are normalized with ``Normalize(vmin, vmax)`` and mapped through
    the 'viridis' colormap. Colors are returned in the order of
    ``nx_graph.nodes()``.
    """
    mapper = cm.ScalarMappable(norm=mcol.Normalize(vmin=vmin, vmax=vmax), cmap='viridis')
    mapper.set_array([])
    features = nx.get_node_attributes(nx_graph, 'attr_name')
    rgba_by_node = {node: mapper.to_rgba(value) for node, value in features.items()}
    return [rgba_by_node[node] for node in nx_graph.nodes()]
127+
128+
##############################################################################
129+
# Generate data
130+
# -------------
131+
132+
#%% circular dataset
133+
# We build a dataset of noisy circular graphs.
134+
# Noise is added on the structures by random connections and on the features by gaussian noise.
135+
136+
137+
np.random.seed(30)
# Nine noisy circular graphs; the size argument of each one is drawn in [15, 25).
X0 = [
    build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)
    for _ in range(9)
]
141+
142+
##############################################################################
143+
# Plot data
144+
# ---------
145+
146+
#%% Plot graphs
147+
148+
plt.figure(figsize=(8, 10))
for i, g in enumerate(X0):
    # One panel per graph on a 3 x 3 grid (subplot indices start at 1).
    plt.subplot(3, 3, i + 1)
    pos = nx.kamada_kawai_layout(g)
    nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
plt.show()
156+
157+
##############################################################################
158+
# Barycenter computation
159+
# ----------------------
160+
161+
#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
162+
# Feature distances are the Euclidean distances
163+
# One shortest-path structure matrix per graph.
Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
# Uniform weights over the nodes of each graph.
ps = [np.ones(len(x)) / len(x) for x in X0]
# Node features of each graph as a column vector.
Ys = [np.array(list(nx.get_node_attributes(x, 'attr_name').values())).reshape(-1, 1) for x in X0]
# Every graph contributes equally to the barycenter.
lambdas = np.ones(len(Ys)) / len(Ys)
sizebary = 15  # we choose a barycenter with 15 nodes

A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95)
170+
171+
##############################################################################
172+
# Plot Barycenter
173+
# -------------------------
174+
175+
#%% Create the barycenter
176+
# Threshold the barycenter structure matrix into an adjacency matrix, using
# the cut-off selected by find_thresh among 100 candidates up to 100.
thresh = find_thresh(C, sup=100, step=100)[0]
# nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
# available from NetworkX 2.0 on and accepts the same ndarray input.
bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=thresh))
# Attach the barycenter features to the nodes.
for i, v in enumerate(A.ravel()):
    bary.add_node(i, attr_name=v)

#%%
pos = nx.kamada_kawai_layout(bary)
nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
plt.suptitle('Barycenter', fontsize=20)
plt.show()

examples/plot_fgw.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
==============================
4+
Plot Fused-Gromov-Wasserstein
5+
==============================
6+
7+
This example illustrates the computation of FGW for 1D measures[18].
8+
9+
.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
10+
and Courty Nicolas
11+
"Optimal Transport for structured data with application on graphs"
12+
International Conference on Machine Learning (ICML). 2019.
13+
14+
"""
15+
16+
# Author: Titouan Vayer <[email protected]>
17+
#
18+
# License: MIT License
19+
20+
import matplotlib.pyplot as pl
21+
import numpy as np
22+
import ot
23+
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
24+
25+
##############################################################################
26+
# Generate data
27+
# -------------
28+
29+
#%% parameters
30+
# We create two 1D random measures
31+
n = 20  # number of points in the first distribution
n2 = 30  # number of points in the second distribution
sig = 1  # std of first distribution
sig2 = 0.1  # std of second distribution

np.random.seed(0)

# Positions: a regular grid perturbed by gaussian noise.
# Features: a step function (ones, then zeros) with a small gaussian noise.
# The second half is sized `n - n // 2` so that the stacked features always
# have n rows; with two `n // 2` halves an odd n gave n - 1 rows and the
# addition with the (n, 1) noise failed.
phi = np.arange(n)[:, None]
xs = phi + sig * np.random.randn(n, 1)
ys = np.vstack((np.ones((n // 2, 1)), np.zeros((n - n // 2, 1)))) + sig2 * np.random.randn(n, 1)

phi2 = np.arange(n2)[:, None]
xt = phi2 + sig * np.random.randn(n2, 1)
yt = np.vstack((np.ones((n2 // 2, 1)), np.zeros((n2 - n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)
# Reverse the order of the features of the second measure.
yt = yt[::-1, :]

p = ot.unif(n)
q = ot.unif(n2)
49+
50+
##############################################################################
51+
# Plot data
52+
# ---------
53+
54+
#%% plot the distributions
55+
56+
pl.close(10)
pl.figure(10, (7, 7))

# Raw strings are used for the LaTeX titles: sequences such as "\m", "\s" and
# "\d" are invalid escapes in a normal string literal and trigger a warning
# on recent Python versions. The resulting strings are unchanged.
pl.subplot(2, 1, 1)
pl.scatter(ys, xs, c=phi, s=70)
pl.ylabel('Feature value a', fontsize=20)
pl.title(r'$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
pl.xticks(())
pl.yticks(())

pl.subplot(2, 1, 2)
pl.scatter(yt, xt, c=phi2, s=70)
pl.xlabel('coordinates x/y', fontsize=25)
pl.ylabel('Feature value b', fontsize=20)
pl.title(r'$\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
pl.yticks(())

pl.tight_layout()
pl.show()
74+
75+
##############################################################################
76+
# Create structure matrices and across-feature distance matrix
77+
# ------------------------------------------------------------
78+
79+
#%% Structure matrices and across-features distance matrix
80+
# Structure matrices: pairwise distances between the positions of each measure.
C1, C2 = ot.dist(xs), ot.dist(xt)
# Across-features distance matrix between the two measures.
M = ot.dist(ys, yt)
w1, w2 = ot.unif(C1.shape[0]), ot.unif(C2.shape[0])
# Transport plan based on the feature cost M only.
# NOTE(review): the empty weight lists presumably select uniform marginals - confirm against ot.emd.
Got = ot.emd([], [], M)
86+
87+
##############################################################################
88+
# Plot matrices
89+
# -------------
90+
91+
#%%
92+
cmap = 'Reds'
fs = 15
l_x = [0, 5, 10, 15]
l_y = [0, 5, 10, 15, 20, 25]

pl.close(10)
pl.figure(10, (5, 5))
gs = pl.GridSpec(5, 5)

# Bottom-left panel: structure matrix of the first measure.
ax1 = pl.subplot(gs[3:, :2])
pl.imshow(C1, cmap=cmap, interpolation='nearest')
pl.title("$C_1$", fontsize=fs)
pl.xlabel("$k$", fontsize=fs)
pl.ylabel("$i$", fontsize=fs)
pl.xticks(l_x)
pl.yticks(l_x)

# Top-right panel: structure matrix of the second measure.
ax2 = pl.subplot(gs[:3, 2:])
pl.imshow(C2, cmap=cmap, interpolation='nearest')
pl.title("$C_2$", fontsize=fs)
pl.ylabel("$l$", fontsize=fs)
pl.xticks(())
pl.yticks(l_y)
ax2.set_aspect('auto')

# Bottom-right panel: feature cost matrix, sharing its axes with the two others.
ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)
pl.imshow(M, cmap=cmap, interpolation='nearest')
pl.yticks(l_x)
pl.xticks(l_y)
pl.ylabel("$i$", fontsize=fs)
pl.title("$M_{AB}$", fontsize=fs)
pl.xlabel("$j$", fontsize=fs)

pl.tight_layout()
ax3.set_aspect('auto')
pl.show()
129+
130+
##############################################################################
131+
# Compute FGW/GW
132+
# --------------
133+
134+
#%% Computing FGW and GW
135+
alpha = 1e-3

# Fused Gromov-Wasserstein plan (features and structures), timed with ot.tic/ot.toc.
ot.tic()
Gwg, logw = fused_gromov_wasserstein(
    M, C1, C2, p, q,
    loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
ot.toc()

# Gromov-Wasserstein plan (structures only).
Gg, log = gromov_wasserstein(
    C1, C2, p, q,
    loss_fun='square_loss', verbose=True, log=True)
143+
144+
##############################################################################
145+
# Visualize transport matrices
146+
# ----------------------------
147+
148+
#%% visu OT matrix
149+
cmap = 'Blues'
fs = 15
pl.figure(2, (13, 5))
pl.clf()

# Left: plan obtained from the feature cost only.
pl.subplot(1, 3, 1)
pl.imshow(Got, cmap=cmap, interpolation='nearest')
pl.ylabel("$i$", fontsize=fs)
pl.xticks(())
pl.title('Wasserstein ($M$ only)')

# Middle: plan obtained from the structure matrices only.
pl.subplot(1, 3, 2)
pl.imshow(Gg, cmap=cmap, interpolation='nearest')
pl.title('Gromov ($C_1,C_2$ only)')
pl.xticks(())

# Right: plan obtained from both features and structures.
pl.subplot(1, 3, 3)
pl.imshow(Gwg, cmap=cmap, interpolation='nearest')
pl.title('FGW ($M+C_1,C_2$)')
pl.xlabel("$j$", fontsize=fs)
pl.ylabel("$i$", fontsize=fs)

pl.tight_layout()
pl.show()

ot/bregman.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Author: Remi Flamary <[email protected]>
77
# Nicolas Courty <[email protected]>
88
# Kilian Fatras <[email protected]>
9+
# Titouan Vayer <[email protected]>
910
#
1011
# License: MIT License
1112

0 commit comments

Comments
 (0)