|
5 | 5 | =============================================
|
6 | 6 |
|
7 | 7 |
|
8 |
| -This is a quickstart guide to the Python Optimal Transport (POT) toolbox. |
| 8 | +This is a quickstart guide to the Python Optimal Transport (POT) toolbox. We use |
| 9 | +here the new API of POT which is more flexible and allows to solve a wider range |
| 10 | +of problems with just a few functions. The old API is still available (the new |
| 11 | +one is a convenient wrapper around the old one) and we provide pointers to the |
| 12 | +old API when needed. |
9 | 13 |
|
10 | 14 | """
|
11 | 15 |
|
|
14 | 18 | # License: MIT License
|
15 | 19 | # sphinx_gallery_thumbnail_number = 1
|
16 | 20 |
|
| 21 | +# Import necessary libraries |
| 22 | + |
| 23 | +import numpy as np |
| 24 | +import pylab as pl |
| 25 | + |
| 26 | +import ot |
| 27 | + |
| 28 | + |
17 | 29 | # %%
|
18 |
| -# Simple example |
| 30 | +# Example data |
19 | 31 | # --------------
|
20 | 32 | #
|
| 33 | +# Data generation |
| 34 | +# ~~~~~~~~~~~~~~~ |
21 | 35 |
|
22 |
| -import numpy as np # always need it |
23 |
| -import pylab as pl # for the plots |
| 36 | +# Problem size |
| 37 | +n1 = 25 |
| 38 | +n2 = 50 |
24 | 39 |
|
25 |
| -import ot |
| 40 | +# Generate random data |
| 41 | +np.random.seed(0) |
| 42 | +a = ot.utils.unif(n1) # weights of points in the source domain |
| 43 | +b = ot.utils.unif(n2) # weights of points in the target domain |
| 44 | + |
| 45 | +x1 = np.random.randn(n1, 2) |
| 46 | +x1 /= ( |
| 47 | + np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2 |
| 48 | +) # project on the unit circle and scale |
| 49 | +x2 = np.random.randn(n2, 2) |
| 50 | +x2 /= ( |
| 51 | + np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 |
| 52 | +) # project on the unit circle and scale |
| 53 | + |
| 54 | +# %% |
| 55 | +# Plot data |
| 56 | +# ~~~~~~~~~ |
| 57 | + |
| 58 | +style = {"markeredgecolor": "k"} |
| 59 | + |
| 60 | +pl.figure(1, (4, 4)) |
| 61 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 62 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 63 | +pl.legend(loc=0) |
| 64 | +pl.title("Source and target distributions") |
| 65 | +pl.show() |
| 66 | + |
| 67 | +# %% |
| 68 | +# Solving exact Optimal Transport |
| 69 | +# ------------------------------- |
| 70 | +# Solve the Optimal Transport problem between the samples |
| 71 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 72 | +# |
| 73 | +# The :func:`ot.solve_sample` function can be used to solve the Optimal Transport problem |
| 74 | +# between two sets of samples. The function takes as its two first arguments the |
| 75 | +# positions of the source and target samples, and returns an :class:`ot.utils.OTResult` object. |
| 76 | + |
| 77 | +# Solve the OT problem |
| 78 | +sol = ot.solve_sample(x1, x2, a, b) |
| 79 | + |
| 80 | +# get the OT plan |
| 81 | +P = sol.plan |
| 82 | + |
| 83 | +# get the OT loss |
| 84 | +loss = sol.value |
| 85 | + |
| 86 | +# get the dual potentials |
| 87 | +alpha, beta = sol.potentials |
| 88 | + |
| 89 | +print(f"OT loss = {loss:1.3f}") |
| 90 | + |
| 91 | +# %% |
| 92 | +# We provide |
| 93 | +# the weights of the samples in the source and target domains :code:`a` and |
| 94 | +# :code:`b`. If not provided, the weights are assumed to be uniform. |
| 95 | +# |
| 96 | +# The :class:`ot.utils.OTResult` object contains the following attributes: |
| 97 | +# |
| 98 | +# - :code:`value`: the value of the OT problem |
| 99 | +# - :code:`plan`: the OT matrix |
| 100 | +# - :code:`potentials`: Dual potentials of the OT problem |
| 101 | +# - :code:`log`: log dictionary of the solver |
| 102 | +# |
| 103 | +# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where |
| 104 | +# :code:`P[i,j]` is the amount of mass |
| 105 | +# transported from :code:`x1[i]` to :code:`x2[j]`. |
| 106 | +# |
| 107 | +# The OT loss is the sum of the element-wise product of the OT matrix and the |
| 108 | +# cost matrix taken by default as the Squared Euclidean distance. |
| 109 | +# |
| 110 | + |
| 111 | +# %% |
| 112 | +# Plot the OT plan and dual potentials |
| 113 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 114 | +# |
| 115 | + |
| 116 | +from ot.plot import plot2D_samples_mat |
| 117 | + |
| 118 | +pl.figure(1, (8, 4)) |
| 119 | + |
| 120 | +pl.subplot(1, 2, 1) |
| 121 | +plot2D_samples_mat(x1, x2, P) |
| 122 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 123 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 124 | +pl.title("OT plan P loss={:.3f}".format(loss)) |
| 125 | + |
| 126 | +pl.subplot(1, 2, 2) |
| 127 | +pl.scatter(x1[:, 0], x1[:, 1], c=alpha, cmap="viridis", edgecolors="k") |
| 128 | +pl.scatter(x2[:, 0], x2[:, 1], c=beta, cmap="plasma", edgecolors="k") |
| 129 | +pl.title("Dual potentials") |
| 130 | +pl.show() |
| 131 | + |
| 132 | + |
| 133 | +pl.figure(2, (3, 1.7)) |
| 134 | +pl.imshow(P, cmap="Greys") |
| 135 | +pl.title("OT plan") |
| 136 | +pl.show() |
| 137 | + |
| 138 | +# %% |
| 139 | +# Solve the Optimal Transport problem with a custom cost matrix |
| 140 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 141 | +# |
| 142 | +# The cost matrix can be customized by passing it to the more general |
| 143 | +# :func:`ot.solve` function. The cost matrix should be a matrix of size |
| 144 | +# :code:`(n1, n2)` where :code:`C[i,j]` is the cost of transporting mass from |
| 145 | +# :code:`x1[i]` to :code:`x2[j]`. |
| 146 | +# |
| 147 | +# In this example, we use the Citybloc distance as the cost matrix. |
| 148 | + |
| 149 | +# Compute the cost matrix |
| 150 | +C = ot.dist(x1, x2, metric="cityblock") |
| 151 | + |
| 152 | +# Solve the OT problem with the custom cost matrix |
| 153 | +P_city = ot.solve(C).plan |
| 154 | + |
| 155 | +# Compute the OT loss (equivalent to ot.solve(C).value) |
| 156 | +loss_city = np.sum(P_city * C) |
| 157 | + |
| 158 | +# %% |
| 159 | +# Note that we show here how to sole the OT problem with a custom cost matrix |
| 160 | +# with the more general :func:`ot.solve` function. |
| 161 | +# But the same can be done with the :func:`ot.solve_sample` function by passing |
| 162 | +# :code:`metric='cityblock'` as argument. |
| 163 | +# |
| 164 | +# .. note:: |
| 165 | +# The examples above use the new API of POT. The old API is still available |
| 166 | +# and and OT plan and loss can be computed with the :func:`ot.emd` and |
| 167 | +# the :func:`ot.emd2` functions as below: |
| 168 | +# |
| 169 | +# .. code-block:: python |
| 170 | +# |
| 171 | +# P = ot.emd(a, b, C) |
| 172 | +# loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b |
| 173 | +# |
| 174 | +# Plot the OT plan and dual potentials for other loss |
| 175 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 176 | +# |
| 177 | + |
| 178 | +pl.figure(1, (3, 3)) |
| 179 | +plot2D_samples_mat(x1, x2, P) |
| 180 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 181 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 182 | +pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city)) |
| 183 | + |
| 184 | +pl.figure(2, (3, 1.7)) |
| 185 | +pl.imshow(P_city, cmap="Greys") |
| 186 | +pl.title("OT plan (Citybloc)") |
| 187 | +pl.show() |
| 188 | + |
| 189 | +# %% |
| 190 | +# Sinkhorn and Regularized OT |
| 191 | +# --------------------------- |
| 192 | +# |
| 193 | +# Solve Entropic Regularized OT with Sinkhorn algorithm |
| 194 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 195 | +# |
| 196 | + |
| 197 | +# Solve the Sinkhorn problem (just add reg parameter value) |
| 198 | +sol = ot.solve_sample(x1, x2, a, b, reg=1e-1) |
| 199 | + |
| 200 | +# get the OT plan and loss |
| 201 | +P_sink = sol.plan |
| 202 | +loss_sink = sol.value # objective value of the Sinkhorn problem (incl. entropy) |
| 203 | +loss_sink_linear = sol.value_linear # np.sum(P_sink * C) linear part of loss |
| 204 | + |
| 205 | +# %% |
| 206 | +# The Sinkhorn algorithm solves the Entropic Regularized OT problem. The |
| 207 | +# regularization strength can be controlled with the :code:`reg` parameter. |
| 208 | +# The Sinkhorn algorithm can be faster than the exact OT solver for large |
| 209 | +# regularization strength but the solution is only an approximation of the |
| 210 | +# exact OT problem and the OT plan is not sparse. |
| 211 | +# |
| 212 | +# Plot the OT plan and dual potentials for Sinkhorn |
| 213 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 214 | +# |
| 215 | + |
| 216 | +pl.figure(1, (3, 3)) |
| 217 | +plot2D_samples_mat(x1, x2, P_sink) |
| 218 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 219 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 220 | +pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink)) |
| 221 | +pl.show() |
| 222 | + |
| 223 | +pl.figure(2, (3, 1.7)) |
| 224 | +pl.imshow(P_sink, cmap="Greys") |
| 225 | +pl.title("Sinkhorn OT plan") |
| 226 | +pl.show() |
0 commit comments