Commit 80e3c23

[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW
* Tests + Example SSW_1
* Example Wasserstein Circle + Tests
* Wasserstein on the circle wrt Unif
* Example SSW unif
* pep8
* np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests
* np qr
* rm test python 3.11
* update names, tests, backend transpose
* Comment error batchs
* semidiscrete_wasserstein2_unif_circle example
* torch permute method instead of torch.permute for previous versions
* update comments and doc
* doc wasserstein circle model as [0,1[
* Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
1 parent 97feeb3 commit 80e3c23

16 files changed: +1852 -23 lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ The contributors to this library are:
 * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
 * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
 * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
+* [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein)
 
 ## Acknowledgments
 

README.md

Lines changed: 9 additions & 1 deletion
@@ -39,6 +39,8 @@ POT provides the following generic OT solvers (links to examples):
 * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
 formulations).
 * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
+* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
+* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
 * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
 * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
 
@@ -292,4 +294,10 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 
 [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.
 
-[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+
+[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+
+[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations.

RELEASES.md

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 
 #### New features
 
+- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434)
+- Added the Wasserstein distance on the circle in `ot.lp.solver_1d.wasserstein_circle` (PR #434)
+- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434)
+- Added the 2-Wasserstein distance on the circle w.r.t. a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434)
 - Added Bures Wasserstein distance in `ot.gaussian` (PR #428)
 - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
 - Added Free Support Sinkhorn Barycenter + example (PR #387)

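
The commit message notes that the circle is modelled as `[0,1[` and that a helper was added to obtain circle coordinates. As a minimal sketch of what such a parametrization looks like, the snippet below maps 2D unit vectors to coordinates in `[0,1[` and computes the geodesic distance there. The names `to_circle_coordinate` and `circle_distance` are hypothetical and written for illustration only; they are not POT functions, and the convention used by `ot.utils.get_coordinate_circle` may differ.

```python
import numpy as np


def to_circle_coordinate(points):
    """Map 2D points on the unit circle to coordinates in [0, 1[.

    Hypothetical helper for illustration; not part of POT.
    """
    points = np.asarray(points, dtype=float)
    angles = np.arctan2(points[:, 1], points[:, 0])  # angles in [-pi, pi]
    return np.mod(angles / (2 * np.pi), 1.0)  # wrap negative angles around


def circle_distance(u, v):
    """Geodesic distance between coordinates on a circle of length 1."""
    d = np.abs(np.asarray(u, dtype=float) - np.asarray(v, dtype=float)) % 1.0
    return np.minimum(d, 1.0 - d)  # take the shorter of the two arcs


# Four points separated by quarter turns.
pts = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
coords = to_circle_coordinate(pts)

# 0.05 and 0.95 are far apart on the interval but close on the circle.
wrapped = circle_distance(0.05, 0.95)
```

With this convention the largest possible distance between two points is 0.5, reached at antipodal points.
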
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+r"""
+================================================
+Spherical Sliced-Wasserstein Embedding on Sphere
+================================================
+
+Here, we aim at transforming samples into a uniform
+distribution on the sphere by minimizing SSW:
+
+.. math::
+    \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i})
+
+where :math:`\nu=\mathrm{Unif}(S^2)`.
+
+"""
+
+# Author: Clément Bonet <[email protected]>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import matplotlib.animation as animation
+import torch
+import torch.nn.functional as F
+
+import ot
+
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+N = 1000
+x0 = torch.rand(N, 3)
+x0 = F.normalize(x0, dim=-1)
+
+
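
Because `torch.rand` draws coordinates in `[0, 1)`, the normalized starting samples above all lie in a single octant of the sphere, which is what makes the flow towards the uniform distribution visible. The sketch below reproduces that construction in NumPy (so that it runs without PyTorch) and contrasts it with the usual way of drawing uniform samples on the sphere, namely normalizing isotropic Gaussian vectors. The variable names are illustrative and not taken from the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000

# Same construction as the example, in NumPy: coordinates in [0, 1),
# then projection onto the unit sphere. All points land in one octant.
x_orthant = rng.random((n, 3))
x_orthant /= np.linalg.norm(x_orthant, axis=1, keepdims=True)

# Uniform samples on S^2: normalize isotropic Gaussian vectors.
x_unif = rng.standard_normal((n, 3))
x_unif /= np.linalg.norm(x_unif, axis=1, keepdims=True)

print("smallest coordinate (octant samples):", x_orthant.min())
print("empirical mean (uniform samples):", x_unif.mean(axis=0))
```
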
+# %%
+# Plot data
+# ---------
+
+def plot_sphere(ax):
+    xlist = np.linspace(-1.0, 1.0, 50)
+    ylist = np.linspace(-1.0, 1.0, 50)
+    r = np.linspace(1.0, 1.0, 50)
+    X, Y = np.meshgrid(xlist, ylist)
+
+    Z = np.sqrt(r**2 - X**2 - Y**2)
+
+    ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3)
+    ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3)  # Now plot the bottom half
+
+
+# plot the distributions
+pl.figure(1)
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5)
+ax.set_title('Data distribution')
+ax.legend()
+
+
+# %%
+# Gradient descent
+# ----------------
+
+x = x0.clone()
+x.requires_grad_(True)
+
+n_iter = 500
+lr = 100
+
+losses = []
+xvisu = torch.zeros(n_iter, N, 3)
+
+for i in range(n_iter):
+    sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
+    grad_x = torch.autograd.grad(sw, x)[0]
+
+    x = x - lr * grad_x
+    x = F.normalize(x, p=2, dim=1)
+
+    losses.append(sw.item())
+    xvisu[i, :, :] = x.detach().clone()
+
+    if i % 100 == 0:
+        print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('SSW')
+pl.xlabel("Iterations")
+
+
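
The loop above alternates a Euclidean gradient step with a re-normalization onto the sphere, i.e. a projected gradient descent. The following sketch isolates that pattern with a toy objective instead of SSW, so that it runs with NumPy alone: it minimizes `-<x_i, target>` for a fixed unit vector `target`, whose minimizer on the sphere is `target` itself. The objective, the step size and the helper name `project_to_sphere` are assumptions made for illustration.

```python
import numpy as np


def project_to_sphere(x):
    """Renormalize each row to unit Euclidean norm."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


rng = np.random.default_rng(0)
target = np.array([0.0, 0.0, 1.0])  # toy minimizer (assumption)
x = project_to_sphere(rng.standard_normal((5, 3)))

lr = 0.5
for _ in range(200):
    # Gradient of the toy loss -<x_i, target> with respect to each x_i.
    grad = -np.tile(target, (x.shape[0], 1))
    # Euclidean step followed by projection back onto the sphere.
    x = project_to_sphere(x - lr * grad)

print(x)
```

The projection keeps every iterate on the sphere, which is the role `F.normalize` plays in the PyTorch loop.
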
+# %%
+# Plot trajectories of generated samples along iterations
+# -------------------------------------------------------
+
+ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
+
+fig = pl.figure(3, (10, 10))
+for i in range(9):
+    # pl.subplot(3, 3, i + 1)
+    # ax = pl.axes(projection='3d')
+    ax = fig.add_subplot(3, 3, i + 1, projection='3d')
+    plot_sphere(ax)
+    ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5)
+    ax.set_title('Iter. {}'.format(ivisu[i]))
+    # ax.axis("off")
+    if i == 0:
+        ax.legend()
+
+
+# %%
+# Animate trajectories of generated samples along iterations
+# ----------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+    i = 3 * i
+    pl.clf()
+    ax = pl.axes(projection='3d')
+    plot_sphere(ax)
+    ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples', alpha=0.5)
+    ax.axis("off")
+    ax.set_xlim((-1.5, 1.5))
+    ax.set_ylim((-1.5, 1.5))
+    ax.set_title('Iter. {}'.format(i))
+    return 1
+
+
+print(xvisu.shape)
+
+i = 0
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples', alpha=0.5)
+ax.axis("off")
+ax.set_xlim((-1.5, 1.5))
+ax.set_ylim((-1.5, 1.5))
+ax.set_title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
+# %%
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+"""
+=========================
+OT distance on the Circle
+=========================
+
+Shows how to compute the Wasserstein distance on the circle.
+
+
+"""
+
+# Author: Clément Bonet <[email protected]>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+from scipy.special import iv
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+
+def pdf_von_Mises(theta, mu, kappa):
+    pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa))
+    return pdf
+
+
+t = np.linspace(0, 2 * np.pi, 1000, endpoint=False)
+
+mu1 = 1
+kappa1 = 20
+
+mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10)
+
+
+pdf1 = pdf_von_Mises(t, mu1, kappa1)
+
+
+pl.figure(1)
+for k, mu in enumerate(mu_targets):
+    pdf_t = pdf_von_Mises(t, mu, kappa1)
+    if k == 0:
+        label = "Source distributions"
+    else:
+        label = None
+    pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label)
+
+pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution")
+pl.legend()
+
+mu2 = 0
+kappa2 = kappa1
+
+x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi
+x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi
+
+angles = np.linspace(0, 2 * np.pi, 150)
+
+pl.figure(2)
+pl.plot(np.cos(angles), np.sin(angles), c="k")
+pl.xlim(-1.25, 1.25)
+pl.ylim(-1.25, 1.25)
+pl.scatter(np.cos(x1), np.sin(x1), c="b")
+pl.scatter(np.cos(x2), np.sin(x2), c="r")
+
+#########################################################################################
+# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle
+# ---------------------------------------------------------------------------------------
+# This example illustrates the periodicity of the Wasserstein distance on the circle.
+# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}`
+# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution
+# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`.
+# The Wasserstein distance on the circle takes into account the periodicity
+# and attains its maximum at :math:`\mu_{\mathrm{target}}+\pi` (the antipodal point), contrary to the
+# Euclidean version.
+
+#%% Compute and plot distributions
+
+mu_targets = np.linspace(0, 2 * np.pi, 200)
+xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi
+
+n_try = 5
+
+xts = np.zeros((n_try, 200, 500))
+for i in range(n_try):
+    for k, mu in enumerate(mu_targets):
+        # np.random.vonmises deals with data on [-pi, pi[
+        xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi
+        xts[i, k] = xt
+
+# Put data on S^1=[0,1[
+xts2 = xts / (2 * np.pi)
+xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi)
+
+L_w2_circle = np.zeros((n_try, 200))
+L_w2 = np.zeros((n_try, 200))
+
+for i in range(n_try):
+    w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2)
+    w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2)
+
+    L_w2_circle[i] = w2_circle
+    L_w2[i] = w2
+
+m_w2_circle = np.mean(L_w2_circle, axis=0)
+std_w2_circle = np.std(L_w2_circle, axis=0)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(3)
+pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5)
+pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5)
+pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$")
+pl.legend()
+pl.xlabel(r"$\mu_{\mathrm{source}}$")
+pl.show()
+
+
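
The periodicity compared above can be checked by hand in the simplest case of two point masses, where the Wasserstein distance (for any p) reduces to the ground distance between the two locations. The sketch below sweeps a source location over `[0,1[` against a fixed target at 1 radian, rescaled to `[0,1[` as in the example, and compares the Euclidean distance on the interval with the geodesic distance on the circle. It uses only NumPy and does not call POT; the grid size and variable names are illustrative choices.

```python
import numpy as np

# Target location: 1 radian, rescaled to the [0, 1[ parametrization.
target = 1.0 / (2 * np.pi)
sources = np.linspace(0.0, 1.0, 1000, endpoint=False)

# Distance between two Diracs on the interval (no wrap-around).
euclid = np.abs(sources - target)

# Distance between two Diracs on the circle: the shorter of the two arcs.
circle = np.minimum(euclid, 1.0 - euclid)

print("circle distance peaks at", sources[np.argmax(circle)])
print("interval distance peaks at", sources[np.argmax(euclid)])
```

The circle distance peaks at the antipodal point of the target (half a turn away) and never exceeds 0.5, while the interval distance keeps growing up to the end of the interval.
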
+########################################################################
+# Wasserstein distance between von Mises and uniform for different kappa
+# ----------------------------------------------------------------------
+# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`.
+
+#%% Compute Wasserstein between Von Mises and uniform
+
+kappas = np.logspace(-5, 2, 100)
+n_try = 20
+
+xts = np.zeros((n_try, 100, 500))
+for i in range(n_try):
+    for k, kappa in enumerate(kappas):
+        # np.random.vonmises deals with data on [-pi, pi[
+        xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi
+        xts[i, k] = xt / (2 * np.pi)
+
+L_w2 = np.zeros((n_try, 100))
+for i in range(n_try):
+    L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(4)
+pl.plot(kappas, m_w2)
+pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5)
+pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$")
+pl.xlabel(r"$\kappa$")
+pl.show()
+
+# %%
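
For intuition about the quantity plotted in the last cell, the squared 2-Wasserstein distance between an empirical measure on the circle and the uniform distribution admits a closed form in the `[0,1[` parametrization. The sketch below is a NumPy transcription of that formula as I understand it from reference [46], assuming sorted coordinates in `[0,1[` and uniform weights; it is an illustration and not POT's implementation of `semidiscrete_wasserstein2_unif_circle`. The function name is hypothetical.

```python
import numpy as np


def w2_squared_to_uniform_circle(x):
    """Squared 2-Wasserstein distance between the empirical measure of
    `x` (coordinates on a circle of length 1) and the uniform law.

    Hypothetical helper based on the closed form assumed from [46].
    """
    x = np.sort(np.mod(np.asarray(x, dtype=float), 1.0))
    n = x.shape[0]
    i = np.arange(1, n + 1)
    return (np.mean(x**2) - np.mean(x)**2
            + np.sum((n + 1 - 2 * i) * x) / n**2 + 1.0 / 12.0)


# A single Dirac is as far from uniform as an empirical measure can get.
one_point = w2_squared_to_uniform_circle([0.3])

# n equally spaced points: each one serves an arc of length 1/n.
n = 4
equispaced = w2_squared_to_uniform_circle(np.arange(n) / n)

print(one_point, equispaced)
```

Concentrated samples (large kappa) behave like the single Dirac and give values near 1/12, whereas well-spread samples (kappa close to 0) give values close to 0, which is consistent with the trend the plot is meant to show.
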
