Skip to content

Commit 1a6c790

Browse files
authored
[MRG] Translation Invariant Sinkhorn for Unbalanced OT (#676)
* uot sinkhorn translation invariant * correct log sinkhorn_ti * fix log sinkhorn_ti * test infinite reg sinkhorn unbalanced * fix doc translation invariant sinkhorn * fix pep8 * avoid nan in loop ti sinkhorn * Add test multiple hists, log False * up test multiple input with reg_type='entropy' * up test multiple inputs * correct number ref * correct number ref * jax vmap searchsorted * jax vmap searchsorted
1 parent 791137b commit 1a6c790

File tree

7 files changed

+476
-18
lines changed

7 files changed

+476
-18
lines changed

README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -381,4 +381,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
381381
[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on
382382
Artificial Intelligence.
383383

384-
[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).
384+
[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).
385+
386+
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- Restructured `ot.unbalanced` module (PR #658)
1414
- Added `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
1515
- Implemented Fused unbalanced Gromov-Wasserstein and unbalanced Co-Optimal Transport (PR #677)
16+
- Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
1617

1718
#### Closed issues
1819
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
===============================================================
4+
Translation Invariant Sinkhorn for Unbalanced Optimal Transport
5+
===============================================================
6+
7+
This examples illustrates the better convergence of the translation
8+
invariance Sinkhorn algorithm proposed in [73] compared to the classical
9+
Sinkhorn algorithm.
10+
11+
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
12+
Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
13+
In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
14+
15+
"""
16+
17+
# Author: Clément Bonet <[email protected]>
18+
# License: MIT License
19+
20+
import numpy as np
21+
import matplotlib.pylab as pl
22+
import ot
23+
24+
##############################################################################
25+
# Setting parameters
26+
# -------------
27+
28+
# %% parameters
29+
30+
n_iter = 50 # nb iters
31+
n = 40 # nb samples
32+
33+
num_iter_max = 100
34+
n_noise = 10
35+
36+
reg = 0.005
37+
reg_m_kl = 0.05
38+
39+
mu_s = np.array([-1, -1])
40+
cov_s = np.array([[1, 0], [0, 1]])
41+
42+
mu_t = np.array([4, 4])
43+
cov_t = np.array([[1, -.8], [-.8, 1]])
44+
45+
46+
##############################################################################
47+
# Compute entropic kl-regularized UOT with Sinkhorn and Translation Invariant Sinkhorn
48+
# -----------
49+
50+
err_sinkhorn_uot = np.empty((n_iter, num_iter_max))
51+
err_sinkhorn_uot_ti = np.empty((n_iter, num_iter_max))
52+
53+
54+
for seed in range(n_iter):
55+
np.random.seed(seed)
56+
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
57+
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
58+
59+
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)
60+
xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)
61+
62+
n = n + n_noise
63+
64+
a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
65+
66+
# loss matrix
67+
M = ot.dist(xs, xt)
68+
M /= M.max()
69+
70+
entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl", log=True, numItermax=num_iter_max, stopThr=0)
71+
entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl",
72+
method="sinkhorn_translation_invariant", log=True,
73+
numItermax=num_iter_max, stopThr=0)
74+
75+
err_sinkhorn_uot[seed] = log_uot["err"]
76+
err_sinkhorn_uot_ti[seed] = log_uot_ti["err"]
77+
78+
##############################################################################
79+
# Plot the results
80+
# ----------------
81+
82+
mean_sinkh = np.mean(err_sinkhorn_uot, axis=0)
83+
std_sinkh = np.std(err_sinkhorn_uot, axis=0)
84+
85+
mean_sinkh_ti = np.mean(err_sinkhorn_uot_ti, axis=0)
86+
std_sinkh_ti = np.std(err_sinkhorn_uot_ti, axis=0)
87+
88+
absc = list(range(num_iter_max))
89+
90+
pl.plot(absc, mean_sinkh, label="Sinkhorn")
91+
pl.fill_between(absc, mean_sinkh - 2 * std_sinkh, mean_sinkh + 2 * std_sinkh, alpha=0.5)
92+
93+
pl.plot(absc, mean_sinkh_ti, label="Translation Invariant Sinkhorn")
94+
pl.fill_between(absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5)
95+
96+
pl.yscale("log")
97+
pl.legend()
98+
pl.xlabel("Number of Iterations")
99+
pl.ylabel(r"$\|u-v\|_\infty$")
100+
pl.grid(True)
101+
pl.show()

ot/backend.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1590,9 +1590,7 @@ def searchsorted(self, a, v, side='left'):
15901590
if a.ndim == 1:
15911591
return jnp.searchsorted(a, v, side)
15921592
else:
1593-
# this is a not very efficient way to make jax numpy
1594-
# searchsorted work on 2d arrays
1595-
return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])])
1593+
return jax.vmap(lambda b, u: jnp.searchsorted(b, u, side))(a, v)
15961594

15971595
def flip(self, a, axis=None):
15981596
return jnp.flip(a, axis)

ot/unbalanced/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ._sinkhorn import (sinkhorn_knopp_unbalanced,
1313
sinkhorn_unbalanced,
1414
sinkhorn_stabilized_unbalanced,
15+
sinkhorn_unbalanced_translation_invariant,
1516
sinkhorn_unbalanced2,
1617
barycenter_unbalanced_sinkhorn,
1718
barycenter_unbalanced_stabilized,
@@ -22,6 +23,7 @@
2223
from ._lbfgs import (lbfgsb_unbalanced, lbfgsb_unbalanced2)
2324

2425
__all__ = ['sinkhorn_knopp_unbalanced', 'sinkhorn_unbalanced', 'sinkhorn_stabilized_unbalanced',
25-
'sinkhorn_unbalanced2', 'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized',
26+
'sinkhorn_unbalanced_translation_invariant', 'sinkhorn_unbalanced2',
27+
'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized',
2628
'barycenter_unbalanced', 'mm_unbalanced', 'mm_unbalanced2', '_get_loss_unbalanced',
2729
'lbfgsb_unbalanced', 'lbfgsb_unbalanced2']

0 commit comments

Comments
 (0)