Skip to content

Commit 2364d56

Browse files
authored
Merge pull request #87 from hichamjanati/unbalanced-ot
[MRG] Add Unbalanced KL Wasserstein distance + barycenter
2 parents 5a6b226 + c9df246 commit 2364d56

File tree

8 files changed

+919
-3
lines changed

8 files changed

+919
-3
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ It provides the following solvers:
2727
* Gromov-Wasserstein distances and barycenters ([13] and regularized [12])
2828
* Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
2929
* Non regularized free support Wasserstein barycenters [20].
30+
* Unbalanced OT with KL relaxation distance and barycenter [10, 25].
3031

3132
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
3233

@@ -165,6 +166,7 @@ The contributors to this library are:
165166
* [Kilian Fatras](https://kilianfatras.github.io/)
166167
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
167168
* [Vayer Titouan](https://tvayer.github.io/)
169+
* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
168170

169171
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
170172

@@ -236,3 +238,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
236238
[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
237239

238240
[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).
241+
242+
[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2019). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS).

docs/source/all.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ ot.da
4343

4444
.. automodule:: ot.da
4545
:members:
46-
46+
4747
ot.gpu
4848
--------
4949

@@ -80,3 +80,9 @@ ot.stochastic
8080

8181
.. automodule:: ot.stochastic
8282
:members:
83+
84+
ot.unbalanced
85+
-------------
86+
87+
.. automodule:: ot.unbalanced
88+
:members:

examples/plot_UOT_1D.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
===============================
4+
1D Unbalanced optimal transport
5+
===============================
6+
7+
This example illustrates the computation of Unbalanced Optimal transport
8+
using a Kullback-Leibler relaxation.
9+
"""
10+
11+
# Author: Hicham Janati <[email protected]>
12+
#
13+
# License: MIT License
14+
15+
import numpy as np
16+
import matplotlib.pylab as pl
17+
import ot
18+
import ot.plot
19+
from ot.datasets import make_1D_gauss as gauss
20+
21+
##############################################################################
22+
# Generate data
23+
# -------------
24+
25+
26+
#%% parameters
27+
28+
n = 100 # nb bins
29+
30+
# bin positions
31+
x = np.arange(n, dtype=np.float64)
32+
33+
# Gaussian distributions
34+
a = gauss(n, m=20, s=5) # m= mean, s= std
35+
b = gauss(n, m=60, s=10)
36+
37+
# make distributions unbalanced
38+
b *= 5.
39+
40+
# loss matrix
41+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
42+
M /= M.max()
43+
44+
45+
##############################################################################
46+
# Plot distributions and loss matrix
47+
# ----------------------------------
48+
49+
#%% plot the distributions
50+
51+
pl.figure(1, figsize=(6.4, 3))
52+
pl.plot(x, a, 'b', label='Source distribution')
53+
pl.plot(x, b, 'r', label='Target distribution')
54+
pl.legend()
55+
56+
# plot distributions and loss matrix
57+
58+
pl.figure(2, figsize=(5, 5))
59+
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
60+
61+
62+
##############################################################################
63+
# Solve Unbalanced Sinkhorn
64+
# --------------
65+
66+
67+
# Sinkhorn
68+
69+
epsilon = 0.1 # entropy parameter
70+
alpha = 1. # Unbalanced KL relaxation parameter
71+
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
72+
73+
pl.figure(4, figsize=(5, 5))
74+
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
75+
76+
pl.show()

examples/plot_UOT_barycenter_1D.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
===========================================================
4+
1D Wasserstein barycenter demo for Unbalanced distributions
5+
===========================================================
6+
7+
This example illustrates the computation of regularized Wassersyein Barycenter
8+
as proposed in [10] for Unbalanced inputs.
9+
10+
11+
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
12+
13+
"""
14+
15+
# Author: Hicham Janati <[email protected]>
16+
#
17+
# License: MIT License
18+
19+
import numpy as np
20+
import matplotlib.pylab as pl
21+
import ot
22+
# necessary for 3d plot even if not used
23+
from mpl_toolkits.mplot3d import Axes3D # noqa
24+
from matplotlib.collections import PolyCollection
25+
26+
##############################################################################
27+
# Generate data
28+
# -------------
29+
30+
# parameters
31+
32+
n = 100 # nb bins
33+
34+
# bin positions
35+
x = np.arange(n, dtype=np.float64)
36+
37+
# Gaussian distributions
38+
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
39+
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
40+
41+
# make unbalanced dists
42+
a2 *= 3.
43+
44+
# creating matrix A containing all distributions
45+
A = np.vstack((a1, a2)).T
46+
n_distributions = A.shape[1]
47+
48+
# loss matrix + normalization
49+
M = ot.utils.dist0(n)
50+
M /= M.max()
51+
52+
##############################################################################
53+
# Plot data
54+
# ---------
55+
56+
# plot the distributions
57+
58+
pl.figure(1, figsize=(6.4, 3))
59+
for i in range(n_distributions):
60+
pl.plot(x, A[:, i])
61+
pl.title('Distributions')
62+
pl.tight_layout()
63+
64+
##############################################################################
65+
# Barycenter computation
66+
# ----------------------
67+
68+
# non weighted barycenter computation
69+
70+
weight = 0.5 # 0<=weight<=1
71+
weights = np.array([1 - weight, weight])
72+
73+
# l2bary
74+
bary_l2 = A.dot(weights)
75+
76+
# wasserstein
77+
reg = 1e-3
78+
alpha = 1.
79+
80+
bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
81+
82+
pl.figure(2)
83+
pl.clf()
84+
pl.subplot(2, 1, 1)
85+
for i in range(n_distributions):
86+
pl.plot(x, A[:, i])
87+
pl.title('Distributions')
88+
89+
pl.subplot(2, 1, 2)
90+
pl.plot(x, bary_l2, 'r', label='l2')
91+
pl.plot(x, bary_wass, 'g', label='Wasserstein')
92+
pl.legend()
93+
pl.title('Barycenters')
94+
pl.tight_layout()
95+
96+
##############################################################################
97+
# Barycentric interpolation
98+
# -------------------------
99+
100+
# barycenter interpolation
101+
102+
n_weight = 11
103+
weight_list = np.linspace(0, 1, n_weight)
104+
105+
106+
B_l2 = np.zeros((n, n_weight))
107+
108+
B_wass = np.copy(B_l2)
109+
110+
for i in range(0, n_weight):
111+
weight = weight_list[i]
112+
weights = np.array([1 - weight, weight])
113+
B_l2[:, i] = A.dot(weights)
114+
B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
115+
116+
117+
# plot interpolation
118+
119+
pl.figure(3)
120+
121+
cmap = pl.cm.get_cmap('viridis')
122+
verts = []
123+
zs = weight_list
124+
for i, z in enumerate(zs):
125+
ys = B_l2[:, i]
126+
verts.append(list(zip(x, ys)))
127+
128+
ax = pl.gcf().gca(projection='3d')
129+
130+
poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
131+
poly.set_alpha(0.7)
132+
ax.add_collection3d(poly, zs=zs, zdir='y')
133+
ax.set_xlabel('x')
134+
ax.set_xlim3d(0, n)
135+
ax.set_ylabel(r'$\alpha$')
136+
ax.set_ylim3d(0, 1)
137+
ax.set_zlabel('')
138+
ax.set_zlim3d(0, B_l2.max() * 1.01)
139+
pl.title('Barycenter interpolation with l2')
140+
pl.tight_layout()
141+
142+
pl.figure(4)
143+
cmap = pl.cm.get_cmap('viridis')
144+
verts = []
145+
zs = weight_list
146+
for i, z in enumerate(zs):
147+
ys = B_wass[:, i]
148+
verts.append(list(zip(x, ys)))
149+
150+
ax = pl.gcf().gca(projection='3d')
151+
152+
poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
153+
poly.set_alpha(0.7)
154+
ax.add_collection3d(poly, zs=zs, zdir='y')
155+
ax.set_xlabel('x')
156+
ax.set_xlim3d(0, n)
157+
ax.set_ylabel(r'$\alpha$')
158+
ax.set_ylim3d(0, 1)
159+
ax.set_zlabel('')
160+
ax.set_zlim3d(0, B_l2.max() * 1.01)
161+
pl.title('Barycenter interpolation with Wasserstein')
162+
pl.tight_layout()
163+
164+
pl.show()

ot/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
from . import gromov
2121
from . import smooth
2222
from . import stochastic
23+
from . import unbalanced
2324

2425
# OT functions
2526
from .lp import emd, emd2
2627
from .bregman import sinkhorn, sinkhorn2, barycenter
28+
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
2729
from .da import sinkhorn_lpl1_mm
2830

2931
# utils functions
@@ -33,4 +35,5 @@
3335

3436
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
3537
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
36-
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
38+
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
39+
'sinkhorn_unbalanced', "barycenter_unbalanced"]

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def sink():
241241

242242
b = np.asarray(b, dtype=np.float64)
243243
if len(b.shape) < 2:
244-
b = b.reshape((-1, 1))
244+
b = b[:, None]
245245

246246
return sink()
247247

0 commit comments

Comments
 (0)