|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +========================= |
| 4 | +OT distance on the Circle |
| 5 | +========================= |
| 6 | +
|
| 7 | +Shows how to compute the Wasserstein distance on the circle |
| 8 | +
|
| 9 | +
|
| 10 | +""" |
| 11 | + |
| 12 | +# Author: Clément Bonet <[email protected]> |
| 13 | +# |
| 14 | +# License: MIT License |
| 15 | + |
| 16 | +# sphinx_gallery_thumbnail_number = 2 |
| 17 | + |
| 18 | +import numpy as np |
| 19 | +import matplotlib.pylab as pl |
| 20 | +import ot |
| 21 | + |
| 22 | +from scipy.special import iv |
| 23 | + |
| 24 | +############################################################################## |
| 25 | +# Plot data |
| 26 | +# --------- |
| 27 | + |
| 28 | +#%% plot the distributions |
| 29 | + |
| 30 | + |
| 31 | +def pdf_von_Mises(theta, mu, kappa): |
| 32 | + pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa)) |
| 33 | + return pdf |
| 34 | + |
| 35 | + |
| 36 | +t = np.linspace(0, 2 * np.pi, 1000, endpoint=False) |
| 37 | + |
| 38 | +mu1 = 1 |
| 39 | +kappa1 = 20 |
| 40 | + |
| 41 | +mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10) |
| 42 | + |
| 43 | + |
| 44 | +pdf1 = pdf_von_Mises(t, mu1, kappa1) |
| 45 | + |
| 46 | + |
| 47 | +pl.figure(1) |
| 48 | +for k, mu in enumerate(mu_targets): |
| 49 | + pdf_t = pdf_von_Mises(t, mu, kappa1) |
| 50 | + if k == 0: |
| 51 | + label = "Source distributions" |
| 52 | + else: |
| 53 | + label = None |
| 54 | + pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label) |
| 55 | + |
| 56 | +pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution") |
| 57 | +pl.legend() |
| 58 | + |
| 59 | +mu2 = 0 |
| 60 | +kappa2 = kappa1 |
| 61 | + |
| 62 | +x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi |
| 63 | +x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi |
| 64 | + |
| 65 | +angles = np.linspace(0, 2 * np.pi, 150) |
| 66 | + |
| 67 | +pl.figure(2) |
| 68 | +pl.plot(np.cos(angles), np.sin(angles), c="k") |
| 69 | +pl.xlim(-1.25, 1.25) |
| 70 | +pl.ylim(-1.25, 1.25) |
| 71 | +pl.scatter(np.cos(x1), np.sin(x1), c="b") |
| 72 | +pl.scatter(np.cos(x2), np.sin(x2), c="r") |
| 73 | + |
| 74 | +######################################################################################### |
| 75 | +# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle |
| 76 | +# --------------------------------------------------------------------------------------- |
| 77 | +# This examples illustrates the periodicity of the Wasserstein distance on the circle. |
| 78 | +# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}` |
| 79 | +# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution |
| 80 | +# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`. |
| 81 | +# The Wasserstein distance on the circle takes into account the periodicity |
| 82 | +# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the |
| 83 | +# Euclidean version. |
| 84 | + |
| 85 | +#%% Compute and plot distributions |
| 86 | + |
| 87 | +mu_targets = np.linspace(0, 2 * np.pi, 200) |
| 88 | +xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi |
| 89 | + |
| 90 | +n_try = 5 |
| 91 | + |
| 92 | +xts = np.zeros((n_try, 200, 500)) |
| 93 | +for i in range(n_try): |
| 94 | + for k, mu in enumerate(mu_targets): |
| 95 | + # np.random.vonmises deals with data on [-pi, pi[ |
| 96 | + xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi |
| 97 | + xts[i, k] = xt |
| 98 | + |
| 99 | +# Put data on S^1=[0,1[ |
| 100 | +xts2 = xts / (2 * np.pi) |
| 101 | +xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi) |
| 102 | + |
| 103 | +L_w2_circle = np.zeros((n_try, 200)) |
| 104 | +L_w2 = np.zeros((n_try, 200)) |
| 105 | + |
| 106 | +for i in range(n_try): |
| 107 | + w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2) |
| 108 | + w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2) |
| 109 | + |
| 110 | + L_w2_circle[i] = w2_circle |
| 111 | + L_w2[i] = w2 |
| 112 | + |
| 113 | +m_w2_circle = np.mean(L_w2_circle, axis=0) |
| 114 | +std_w2_circle = np.std(L_w2_circle, axis=0) |
| 115 | + |
| 116 | +m_w2 = np.mean(L_w2, axis=0) |
| 117 | +std_w2 = np.std(L_w2, axis=0) |
| 118 | + |
| 119 | +pl.figure(1) |
| 120 | +pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle") |
| 121 | +pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5) |
| 122 | +pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein") |
| 123 | +pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5) |
| 124 | +pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$") |
| 125 | +pl.legend() |
| 126 | +pl.xlabel(r"$\mu_{\mathrm{source}}$") |
| 127 | +pl.show() |
| 128 | + |
| 129 | + |
| 130 | +######################################################################## |
| 131 | +# Wasserstein distance between von Mises and uniform for different kappa |
| 132 | +# ---------------------------------------------------------------------- |
| 133 | +# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`. |
| 134 | + |
| 135 | +#%% Compute Wasserstein between Von Mises and uniform |
| 136 | + |
| 137 | +kappas = np.logspace(-5, 2, 100) |
| 138 | +n_try = 20 |
| 139 | + |
| 140 | +xts = np.zeros((n_try, 100, 500)) |
| 141 | +for i in range(n_try): |
| 142 | + for k, kappa in enumerate(kappas): |
| 143 | + # np.random.vonmises deals with data on [-pi, pi[ |
| 144 | + xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi |
| 145 | + xts[i, k] = xt / (2 * np.pi) |
| 146 | + |
| 147 | +L_w2 = np.zeros((n_try, 100)) |
| 148 | +for i in range(n_try): |
| 149 | + L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T) |
| 150 | + |
| 151 | +m_w2 = np.mean(L_w2, axis=0) |
| 152 | +std_w2 = np.std(L_w2, axis=0) |
| 153 | + |
| 154 | +pl.figure(1) |
| 155 | +pl.plot(kappas, m_w2) |
| 156 | +pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5) |
| 157 | +pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$") |
| 158 | +pl.xlabel(r"$\kappa$") |
| 159 | +pl.show() |
| 160 | + |
| 161 | +# %% |
0 commit comments