
Commit 2fe69eb

[MRG] Make gromov loss differentiable wrt matrices and weights (#302)
* gromov differentiable
* new stuff
* test gromov gradients
* fgw differentiable
* fgw tested
* correct name test
* add awesome example with gromov optimization
* pep8 + typos
* damn pep8
* thumbnail
* remove prints
1 parent 9c6ac88 commit 2fe69eb
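In short, the GW value returned by gromov_wasserstein2 can now be back-propagated through with the torch backend, with gradients flowing to both the structure matrices and the marginal weights. A minimal sketch of the new capability, assuming PyTorch is installed (the toy matrices below are illustrative, not taken from the commit):

import numpy as np
import torch
import ot

# toy symmetric similarity matrices (illustrative data only)
rng = np.random.RandomState(0)
C1 = rng.rand(10, 10)
C1 = (C1 + C1.T) / 2
C2 = rng.rand(5, 5)
C2 = (C2 + C2.T) / 2

C1_t = torch.tensor(C1, requires_grad=True)
C2_t = torch.tensor(C2, requires_grad=True)
a = torch.tensor(ot.unif(10), requires_grad=True)
b = torch.tensor(ot.unif(5), requires_grad=True)

# the GW loss is differentiable wrt the matrices and the weights
loss = ot.gromov_wasserstein2(C1_t, C2_t, a, b)
loss.backward()
print(C1_t.grad.shape, a.grad.shape)  # torch.Size([10, 10]) torch.Size([10])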

File tree

6 files changed: +460 −31 lines

README.md

Lines changed: 6 additions & 3 deletions
@@ -26,7 +26,7 @@ POT provides the following generic OT solvers (links to examples):
 * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
 * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
 * Non regularized [Wasserstein barycenters [16]](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale).
-* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
+* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from [38]
 * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
 * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
 * [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
@@ -295,5 +295,8 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
 Machine Learning (pp. 4104-4113). PMLR.

-[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
-Conference on Machine Learning, PMLR 119:4692-4701, 2020
+[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased Sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf). Proceedings of the 37th International
+Conference on Machine Learning, PMLR 119:4692-4701, 2020
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
+Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
New example file

Lines changed: 260 additions & 0 deletions
r"""
=======================================================
Optimizing the Gromov-Wasserstein distance with PyTorch
=======================================================

In this example we use the pytorch backend to optimize the Gromov-Wasserstein
(GW) loss between two graphs expressed as empirical distributions.

In the first example we optimize the weights on the nodes of a simple template
graph so that it minimizes the GW with a given Stochastic Block Model graph.
We can see that this actually recovers the proportion of classes in the SBM
and allows for an accurate clustering of the nodes using the GW optimal plan.

In a second example we optimize simultaneously the weights and the structure
of the template graph, which allows us to perform graph compression and to
recover other properties of the SBM.

The backend actually uses the gradients expressed in [38] to optimize the
weights.

[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
Dictionary Learning, International Conference on Machine Learning (ICML), 2021.

"""
# Author: Rémi Flamary <[email protected]>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

from sklearn.manifold import MDS
import numpy as np
import matplotlib.pylab as pl
import torch

import ot
from ot.gromov import gromov_wasserstein2

# %%
# Graph generation
# ----------------

rng = np.random.RandomState(42)


def get_sbm(n, nc, ratio, P):
    """Draw the adjacency matrix of a Stochastic Block Model with nc classes,
    class proportions given by ratio and link probabilities given by P."""
    nbpc = np.round(n * ratio).astype(int)
    n = np.sum(nbpc)
    C = np.zeros((n, n))
    for c1 in range(nc):
        for c2 in range(c1 + 1):
            if c1 == c2:
                for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
                    for j in range(np.sum(nbpc[:c2]), i):
                        if rng.rand() <= P[c1, c2]:
                            C[i, j] = 1
            else:
                for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
                    for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])):
                        if rng.rand() <= P[c1, c2]:
                            C[i, j] = 1
    return C + C.T


n = 100
nc = 3
ratio = np.array([.5, .3, .2])
P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3)))
C1 = get_sbm(n, nc, ratio, P)

# get 2d position for nodes with multidimensional scaling
x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1)


def plot_graph(x, C, color='C0', s=None):
    for j in range(C.shape[0]):
        for i in range(j):
            if C[i, j] > 0:
                pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
    pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)


pl.figure(1, (10, 5))
pl.clf()
pl.subplot(1, 2, 1)
plot_graph(x1, C1, color='C0')
pl.title("SBM Graph")
pl.axis("off")
pl.subplot(1, 2, 2)
pl.imshow(C1, interpolation='nearest')
pl.title("Adjacency matrix")
pl.axis("off")

# %%
# Optimizing the weights of a simple template C0=eye(3) to fit Graph 1
# ---------------------------------------------------------------------
# The adjacency matrix C1 is block diagonal with 3 blocks. We want to
# optimize the weights of a simple template C0=eye(3) and see if we can
# recover the proportion of classes from the SBM (up to a permutation).

C0 = np.eye(3)


def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2):
    """Solve min_a GW(C1, C2, a, a2) by projected gradient descent."""

    # use pyTorch for our data
    C1_torch = torch.tensor(C1)
    C2_torch = torch.tensor(C2)

    a0 = rng.rand(C1.shape[0])  # random init
    a0 /= a0.sum()  # on simplex
    a1_torch = torch.tensor(a0).requires_grad_(True)
    a2_torch = torch.tensor(a2)

    loss_iter = []

    for i in range(nb_iter_max):

        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)

        loss_iter.append(loss.clone().detach().cpu().numpy())
        loss.backward()

        # print("{:03d} | {}".format(i, loss_iter[-1]))

        # performs a step of projected gradient descent
        with torch.no_grad():
            grad = a1_torch.grad
            a1_torch -= grad * lr  # step
            a1_torch.grad.zero_()
            # project the weights back onto the probability simplex
            a1_torch.data = ot.utils.proj_simplex(a1_torch)

    a1 = a1_torch.clone().detach().cpu().numpy()

    return a1, loss_iter


a0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2)

pl.figure(2)
pl.plot(loss_iter0)
pl.title("Loss along iterations")

print("Estimated weights : ", a0_est)
print("True proportions : ", ratio)

# %%
# It is clear that the optimization has converged and that we recover the
# ratio of the different classes in the SBM graph up to a permutation.

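# %%
# A quick sanity check (a sketch added for this writeup, not part of the
# original example): since the weights are recovered only up to a permutation
# of the classes, sorting both vectors makes them directly comparable.

print("Sorted estimated weights : ", np.sort(a0_est))
print("Sorted true proportions  : ", np.sort(ratio))
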
# %%
# Community clustering with uniform and estimated weights
# --------------------------------------------------------
# The GW OT plan can be used to perform a clustering of the nodes of a graph
# when computing the GW with a simple template like C0, by labeling nodes in
# the original graph with the index of the node in the template receiving
# the most mass.
#
# We show here the result of such a clustering when using uniform weights on
# the template C0 and when using the optimal weights previously estimated.


T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3))
label_unif = T_unif.argmax(1)

T_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est)
label_est = T_est.argmax(1)

pl.figure(3, (10, 5))
pl.clf()
pl.subplot(1, 2, 1)
plot_graph(x1, C1, color=label_unif)
pl.title("Graph clustering unif. weights")
pl.axis("off")
pl.subplot(1, 2, 2)
plot_graph(x1, C1, color=label_est)
pl.title("Graph clustering est. weights")
pl.axis("off")

# %%
# Graph compression with GW
# -------------------------
# Now we optimize both the weights and structure of a small graph that
# minimizes the GW distance wrt our data graph. This can be seen as graph
# compression, but it can also recover important properties of an SBM such
# as its class proportions and its matrix of link probabilities between
# classes.


def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2):
    """Solve min_{a, C1} GW(C1, C2, a, a2) by projected gradient descent."""

    # use pyTorch for our data
    C2_torch = torch.tensor(C2)
    a2_torch = torch.tensor(a2)

    a0 = rng.rand(nb_nodes)  # random init
    a0 /= a0.sum()  # on simplex
    a1_torch = torch.tensor(a0).requires_grad_(True)
    C0 = np.eye(nb_nodes)
    C1_torch = torch.tensor(C0).requires_grad_(True)

    loss_iter = []

    for i in range(nb_iter_max):

        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)

        loss_iter.append(loss.clone().detach().cpu().numpy())
        loss.backward()

        # print("{:03d} | {}".format(i, loss_iter[-1]))

        # performs a step of projected gradient descent
        with torch.no_grad():
            # update the weights and project them back onto the simplex
            grad = a1_torch.grad
            a1_torch -= grad * lr  # step
            a1_torch.grad.zero_()
            a1_torch.data = ot.utils.proj_simplex(a1_torch)

            # update the structure and clip its entries to [0, 1]
            grad = C1_torch.grad
            C1_torch -= grad * lr  # step
            C1_torch.grad.zero_()
            C1_torch.data = torch.clamp(C1_torch, 0, 1)

    a1 = a1_torch.clone().detach().cpu().numpy()
    C1 = C1_torch.clone().detach().cpu().numpy()

    return a1, C1, loss_iter


nb_nodes = 3
a0_est2, C0_est2, loss_iter2 = graph_compression_gw(nb_nodes, C1, ot.unif(n),
                                                    nb_iter_max=100, lr=5e-2)

pl.figure(4)
pl.plot(loss_iter2)
pl.title("Loss along iterations")


print("Estimated weights : ", a0_est2)
print("True proportions : ", ratio)

pl.figure(6, (10, 3.5))
pl.clf()
pl.subplot(1, 2, 1)
pl.imshow(P, vmin=0, vmax=1)
pl.title('True SBM P matrix')
pl.subplot(1, 2, 2)
pl.imshow(C0_est2, vmin=0, vmax=1)
pl.title('Estimated C0 matrix')
pl.colorbar()

ot/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@
     sinkhorn_unbalanced2)
 from .da import sinkhorn_lpl1_mm
 from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .gromov import (gromov_wasserstein, gromov_wasserstein2,
+                     gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)

 # utils functions
 from .utils import dist, unif, tic, toc, toq
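With these exports, the (fused) Gromov-Wasserstein solvers become callable directly from the top-level ot namespace. A small usage sketch (toy matrices, illustrative only):

import numpy as np
import ot

# toy symmetric structure matrices (illustrative data only)
C1 = np.eye(4)
C2 = np.eye(3)

T = ot.gromov_wasserstein(C1, C2, ot.unif(4), ot.unif(3))    # GW optimal plan
gw = ot.gromov_wasserstein2(C1, C2, ot.unif(4), ot.unif(3))  # GW loss value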
