
Commit 897026e

Authored by 6Ulm, rflamary, and agramfort
[MRG] CO-Optimal Transport solver (#447)
* Allow warmstart in sinkhorn and sinkhorn_log
* Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman`
* Add the number of the PR
* [WIP] CO-Optimal Transport
* Revert "[WIP] CO-Optimal Transport" (this reverts commit f3d36b2)
* Reformat with PEP8
* Fix W291 trailing whitespace error in pep8 test
* Rearrange position of warmstart argument and edit its description
* Implementation of CO-Optimal Transport
* Optimize code and edit documentation
* Fix backend bug in test cases
* Fix backend bug
* Fix backend bug
* Add examples on COOT
* Modify API and edit example
* Edit API
* Minor edit of examples and release
* Fix bug in coot
* Fix doc examples
* More fixes of doc
* Restart CI
* Reorder refs
* Add more tests
* Add more tests
* Add test verbose
* Fix PEP8 bug
* Fix PEP8 bug
* Fix PEP8 bug
* Fix pytest bug
* Edit doc for better display

Co-authored-by: Rémi Flamary <[email protected]>
Co-authored-by: Alexandre Gramfort <[email protected]>
1 parent: b9ed7b1 · commit: 897026e

File tree

7 files changed: +1052 −7 lines changed

README.md

+7 −5
```diff
@@ -276,15 +276,15 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 
 [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
 
-[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
-(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
-via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
 Machine Learning (pp. 4104-4113). PMLR.
 
 [37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International
 Conference on Machine Learning, PMLR 119:4692-4701, 2020
 
-[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
 Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
 
 [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
@@ -305,4 +305,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 
 [47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787.
 
-[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
+
+[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
```

RELEASES.md

+4 −2
```diff
@@ -16,8 +16,10 @@
 - New API for OT solver using function `ot.solve` (PR #388)
 - Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
 - Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
-- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
-- `ot.dr` now uses the new Pymanopt API and POT is compatible with current Pymanopt (PR #443)
+- Added parameters method in `ot.da.SinkhornTransport` (PR #440)
+- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
+  Pymanopt (PR #443)
+- Added CO-Optimal Transport solver + examples (PR # 447)
 - Remove the redundant `nx.abs()` at the end of `wasserstein_1d()` (PR #448)
 
 #### Closed issues
```
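As a quick orientation for the solver announced in the changelog above, here is a minimal sketch of the `ot.coot` API added by this PR. It is distilled from the example files later in this commit; the random input data and the reliance on uniform default weights are illustrative assumptions, not part of the diff.

```python
import numpy as np
from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2

# COOT compares matrices of different sizes: 20 samples x 16 features
# versus 10 samples x 8 features (illustrative random data)
X1 = np.random.randn(20, 16)
X2 = np.random.randn(10, 8)

# coot returns a sample coupling (20 x 10) and a feature coupling (16 x 8);
# row/column weights default to uniform when none are passed
pi_sample, pi_feature = coot(X1, X2)

# coot2 returns the scalar CO-Optimal Transport distance
distance = coot2(X1, X2)
```

Passing `log=True` to `coot` additionally returns a third `log` output, as the first example below does.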

docs/source/all.rst

+1
```diff
@@ -16,6 +16,7 @@ API and modules
 
 backend
 bregman
+coot
 da
 datasets
 dr
```

examples/others/plot_COOT.py

+97
@@ -0,0 +1,97 @@
```python
# -*- coding: utf-8 -*-
r"""
===================================================
Row and column alignments with CO-Optimal Transport
===================================================

This example is designed to show how to use CO-Optimal Transport [49]_ in POT.
CO-Optimal Transport makes it possible to compute the distance between two
**arbitrary-size** matrices and to align their rows and columns. In this
example, we consider two random matrices :math:`X_1` and :math:`X_2` defined by
:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)`
and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`.

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
   `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
   Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <[email protected]>
#         Quang Huy Tran <[email protected]>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import matplotlib.pylab as pl
import numpy as np
from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2

# %%
# Generating two random matrices

n1 = 20
n2 = 10
d1 = 16
d2 = 8
sigma = 0.2

X1 = (
    np.cos(np.arange(n1) * np.pi / n1)[:, None] +
    np.cos(np.arange(d1) * np.pi / d1)[None, :] +
    sigma * np.random.randn(n1, d1)
)
X2 = (
    np.cos(np.arange(n2) * np.pi / n2)[:, None] +
    np.cos(np.arange(d2) * np.pi / d2)[None, :] +
    sigma * np.random.randn(n2, d2)
)

# %%
# Visualizing the matrices

pl.figure(1, (8, 5))
pl.subplot(1, 2, 1)
pl.imshow(X1)
pl.title('$X_1$')

pl.subplot(1, 2, 2)
pl.imshow(X2)
pl.title("$X_2$")

pl.tight_layout()

# %%
# Visualizing the row and column alignments and computing the
# CO-Optimal Transport distance

pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True)
coot_distance = coot2(X1, X2)
print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance))

fig = pl.figure(4, (9, 7))
pl.clf()

ax1 = pl.subplot(2, 2, 3)
pl.imshow(X1)
pl.xlabel('$X_1$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(X2))
pl.title("Transpose($X_2$)")
ax2.xaxis.tick_top()

# black lines: connect each row of X1 to its most aligned row of X2
for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

# blue lines: connect each column of X1 to its most aligned column of X2
for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData,
        color="blue")
    fig.add_artist(con)
```
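One possible sanity check for the couplings returned by `coot` is to verify their marginals. The sketch below is standalone with random data; the marginal property follows from COOT's coupling constraints with the uniform weights `coot` defaults to when none are passed, assuming the solver converges to feasible plans.

```python
import numpy as np
from ot.coot import co_optimal_transport as coot

n1, d1, n2, d2 = 20, 16, 10, 8
X1, X2 = np.random.randn(n1, d1), np.random.randn(n2, d2)
pi_sample, pi_feature = coot(X1, X2)

# with uniform weights, pi_sample couples n1 rows with n2 rows, so its
# row sums are all 1/n1 and its column sums are all 1/n2; pi_feature
# does the same for the d1 and d2 columns
assert np.allclose(pi_sample.sum(axis=1), 1 / n1)
assert np.allclose(pi_sample.sum(axis=0), 1 / n2)
assert np.allclose(pi_feature.sum(axis=1), 1 / d1)
assert np.allclose(pi_feature.sum(axis=0), 1 / d2)
```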
examples/others/plot_learning_weights_with_COOT.py

+150

@@ -0,0 +1,150 @@
```python
# -*- coding: utf-8 -*-
r"""
===============================================================
Learning sample marginal distribution with CO-Optimal Transport
===============================================================

In this example, we illustrate how to estimate the sample marginal distribution
that minimizes the CO-Optimal Transport distance [49]_ between two matrices.
More precisely, given source data :math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a
target matrix :math:`Y` associated with a fixed histogram on features
:math:`\mu_y^{(f)}`, we want to solve the following problem

.. math::
    \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)

where :math:`\Delta` is the probability simplex. This minimization is done with a
simple projected gradient descent in PyTorch. We use the automatic backend of POT that
allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2`
with differentiable losses.

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
   `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
   Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <[email protected]>
#         Quang Huy Tran <[email protected]>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import torch
import numpy as np

import matplotlib.pyplot as pl
import ot

from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2


# %%
# Generate data
# -------------
# The source and clean target matrices are generated by
# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
# The target matrix is then contaminated by adding 5 row outliers.
# Intuitively, we expect that the estimated sample distribution should ignore
# these outliers, i.e. their weights should be zero.

np.random.seed(182)

n1, d1 = 20, 16
n2, d2 = 10, 8
n = 15

X = (
    torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
    torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
)

# Generate clean target data mixed with outliers
Y_noisy = torch.randn((n, d2)) * 10.0
Y_noisy[:n2, :] = (
    torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
    torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
)
Y = Y_noisy[:n2, :]

X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()

fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
axes[0].imshow(X, vmin=-2, vmax=2)
axes[0].set_title('$X$')

axes[1].imshow(Y, vmin=-2, vmax=2)
axes[1].set_title('Clean $Y$')

axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
axes[2].set_title('Noisy $Y$')

pl.tight_layout()

# %%
# Optimize the COOT distance with respect to the sample marginal distribution
# ---------------------------------------------------------------------------

losses = []
lr = 1e-3
niter = 1000

b = torch.tensor(ot.unif(n), requires_grad=True)

for i in range(niter):

    loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
    losses.append(float(loss))

    loss.backward()

    with torch.no_grad():
        b -= lr * b.grad  # gradient step
        b[:] = ot.utils.proj_simplex(b)  # projection on the simplex

    b.grad.zero_()

# Estimated sample marginal distribution and training loss curve
pl.plot(losses[10:])
pl.title('CO-Optimal Transport distance')

print(f"Marginal distribution = {b.detach().numpy()}")

# %%
# Visualizing the row and column alignments with the estimated sample marginal
# distribution
# -----------------------------------------------------------------------------
#
# Clearly, the learned marginal distribution successfully ignores the 5 outliers.

X, Y_noisy = X.numpy(), Y_noisy.numpy()
b = b.detach().numpy()

pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)

fig = pl.figure(4, (9, 7))
pl.clf()

ax1 = pl.subplot(2, 2, 3)
pl.imshow(X, vmin=-2, vmax=2)
pl.xlabel('$X$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
pl.title("Transpose(Noisy $Y$)")
ax2.xaxis.tick_top()

# black lines: connect each row of X to its most aligned row of noisy Y
for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

# blue lines: connect each column of X to its most aligned column of noisy Y
for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData,
        color="blue")
    fig.add_artist(con)
```
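A closing note on the training loop above: `ot.utils.proj_simplex` is what keeps `b` a valid histogram after each gradient step. A minimal sketch of what the projection does, on made-up values:

```python
import torch
import ot

# an arbitrary vector, possibly with negative entries
v = torch.tensor([0.5, -0.2, 0.8, 0.1])

# Euclidean projection onto the probability simplex:
# the result is non-negative and sums to 1
p = ot.utils.proj_simplex(v)
print(p, p.sum())
```

Projected gradient descent over the simplex avoids reparametrizations such as a softmax and keeps the weights directly interpretable as probabilities, which is why zero weights on the outlier rows can be read off the learned `b` directly.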
