|
5 | 5 | =============================================
|
6 | 6 |
|
7 | 7 |
|
8 |
| -This is a quickstart guide to the Python Optimal Transport (POT) toolbox. We use |
9 |
| -here the new API of POT which is more flexible and allows to solve a wider range |
10 |
| -of problems with just a few functions. The old API is still available (the new |
11 |
| -one is a convenient wrapper around the old one) and we provide pointers to the |
12 |
| -old API when needed. |
| 8 | +Quickstart guide to the POT toolbox. |
| 9 | +
|
| 10 | +For better readability, only the use of POT is provided and the plotting code |
| 11 | +with matplotlib is hidden (but is available in the source file of the example). |
| 12 | +
|
| 13 | +.. note:: |
| 14 | + We use here the new API of POT, which is more flexible and allows solving a wider range of problems with just a few functions. The old API is still available (the new |
| 15 | + one is a convenient wrapper around the old one) and we provide pointers to the |
| 16 | + old API when needed. |
13 | 17 |
|
14 | 18 | """
|
15 | 19 |
|
16 | 20 | # Author: Remi Flamary
|
17 | 21 | #
|
18 | 22 | # License: MIT License
|
19 |
| -# sphinx_gallery_thumbnail_number = 1 |
| 23 | +# sphinx_gallery_thumbnail_number = 4 |
20 | 24 |
|
21 | 25 | # Import necessary libraries
|
22 | 26 |
|
|
43 | 47 | b = ot.utils.unif(n2) # weights of points in the target domain
|
44 | 48 |
|
45 | 49 | x1 = np.random.randn(n1, 2)
|
46 |
| -x1 /= ( |
47 |
| - np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2 |
48 |
| -) # project on the unit circle and scale |
49 |
| -x2 = np.random.randn(n2, 2) |
50 |
| -x2 /= ( |
51 |
| - np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 |
52 |
| -) # project on the unit circle and scale |
| 50 | +x1 /= np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2 |
53 | 51 |
|
54 |
| -# %% |
55 |
| -# Plot data |
56 |
| -# ~~~~~~~~~ |
| 52 | +x2 = np.random.randn(n2, 2) |
| 53 | +x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 |
57 | 54 |
|
| 55 | +# sphinx_gallery_start_ignore |
58 | 56 | style = {"markeredgecolor": "k"}
|
59 | 57 |
|
60 | 58 | pl.figure(1, (4, 4))
|
|
63 | 61 | pl.legend(loc=0)
|
64 | 62 | pl.title("Source and target distributions")
|
65 | 63 | pl.show()
|
| 64 | +# sphinx_gallery_end_ignore |
66 | 65 |
|
67 | 66 | # %%
|
| 67 | +# We illustrate above the simple example of two 2D distributions with 25 and 50 |
| 68 | +# samples respectively located on circles. The weights of the samples are |
| 69 | +# uniform. |
| 70 | +# |
68 | 71 | # Solving exact Optimal Transport
|
69 | 72 | # -------------------------------
|
70 | 73 | # Solve the Optimal Transport problem between the samples
|
|
88 | 91 |
|
89 | 92 | print(f"OT loss = {loss:1.3f}")
|
90 | 93 |
|
91 |
| -# %% |
92 |
| -# We provide |
93 |
| -# the weights of the samples in the source and target domains :code:`a` and |
94 |
| -# :code:`b`. If not provided, the weights are assumed to be uniform. |
95 |
| -# |
96 |
| -# The :class:`ot.utils.OTResult` object contains the following attributes: |
97 |
| -# |
98 |
| -# - :code:`value`: the value of the OT problem |
99 |
| -# - :code:`plan`: the OT matrix |
100 |
| -# - :code:`potentials`: Dual potentials of the OT problem |
101 |
| -# - :code:`log`: log dictionary of the solver |
102 |
| -# |
103 |
| -# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where |
104 |
| -# :code:`P[i,j]` is the amount of mass |
105 |
| -# transported from :code:`x1[i]` to :code:`x2[j]`. |
106 |
| -# |
107 |
| -# The OT loss is the sum of the element-wise product of the OT matrix and the |
108 |
| -# cost matrix taken by default as the Squared Euclidean distance. |
109 |
| -# |
110 |
| - |
111 |
| -# %% |
112 |
| -# Plot the OT plan and dual potentials |
113 |
| -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
114 |
| -# |
115 |
| - |
| 94 | +# sphinx_gallery_start_ignore |
116 | 95 | from ot.plot import plot2D_samples_mat
|
117 | 96 |
|
118 | 97 | pl.figure(1, (8, 4))
|
|
134 | 113 | pl.imshow(P, cmap="Greys")
|
135 | 114 | pl.title("OT plan")
|
136 | 115 | pl.show()
|
| 116 | +# sphinx_gallery_end_ignore |
| 117 | + |
| 118 | +# %% |
| 119 | +# The figure above shows the Optimal Transport plan between the source and target |
| 120 | +# samples. The color intensity represents the amount of mass transported |
| 121 | +# between the samples. The dual potentials of the OT problem are also shown. |
| 122 | +# |
| 123 | +# The weights of the samples in the source and target domains :code:`a` and |
| 124 | +# :code:`b` are given to the function. If not provided, the weights are assumed |
| 125 | +# to be uniform. See :func:`ot.solve_sample` for more details. |
| 126 | +# |
| 127 | +# The :class:`ot.utils.OTResult` object contains the following attributes: |
| 128 | +# |
| 129 | +# - :code:`value`: the value of the OT problem |
| 130 | +# - :code:`plan`: the OT matrix |
| 131 | +# - :code:`potentials`: Dual potentials of the OT problem |
| 132 | +# - :code:`log`: log dictionary of the solver |
| 133 | +# |
| 134 | +# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where |
| 135 | +# :code:`P[i,j]` is the amount of mass |
| 136 | +# transported from :code:`x1[i]` to :code:`x2[j]`. |
| 137 | +# |
| 138 | +# The OT loss is the sum of the element-wise product of the OT matrix and the |
| 139 | +# cost matrix taken by default as the Squared Euclidean distance. |
| 140 | +# |
| 141 | + |
137 | 142 |
|
138 | 143 | # %%
|
139 | 144 | # Solve the Optimal Transport problem with a custom cost matrix
|
|
155 | 160 | # Compute the OT loss (equivalent to ot.solve(C).value)
|
156 | 161 | loss_city = np.sum(P_city * C)
|
157 | 162 |
|
| 163 | +# sphinx_gallery_start_ignore |
| 164 | +pl.figure(1, (3, 3)) |
| 165 | +plot2D_samples_mat(x1, x2, P) |
| 166 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 167 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 168 | +pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city)) |
| 169 | + |
| 170 | +pl.figure(2, (3, 1.7)) |
| 171 | +pl.imshow(P_city, cmap="Greys") |
| 172 | +pl.title("OT plan (Citybloc)") |
| 173 | +pl.show() |
| 174 | +# sphinx_gallery_end_ignore |
| 175 | + |
158 | 176 | # %%
|
159 |
| -# Note that we show here how to sole the OT problem with a custom cost matrix |
| 177 | +# Note that we show here how to solve the OT problem with a custom cost matrix |
160 | 178 | # with the more general :func:`ot.solve` function.
|
161 | 179 | # But the same can be done with the :func:`ot.solve_sample` function by passing
|
162 | 180 | # :code:`metric='cityblock'` as argument.
|
|
171 | 189 | # P = ot.emd(a, b, C)
|
172 | 190 | # loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b
|
173 | 191 | #
|
174 |
| -# Plot the OT plan and dual potentials for other loss |
175 |
| -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 192 | +# .. minigallery:: ot.emd2 ot.emd ot.solve ot.solve_sample |
176 | 193 | #
|
177 | 194 |
|
178 |
| -pl.figure(1, (3, 3)) |
179 |
| -plot2D_samples_mat(x1, x2, P) |
180 |
| -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
181 |
| -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
182 |
| -pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city)) |
183 |
| - |
184 |
| -pl.figure(2, (3, 1.7)) |
185 |
| -pl.imshow(P_city, cmap="Greys") |
186 |
| -pl.title("OT plan (Citybloc)") |
187 |
| -pl.show() |
188 | 195 |
|
189 | 196 | # %%
|
190 | 197 | # Sinkhorn and Regularized OT
|
|
202 | 209 | loss_sink = sol.value # objective value of the Sinkhorn problem (incl. entropy)
|
203 | 210 | loss_sink_linear = sol.value_linear # np.sum(P_sink * C) linear part of loss
|
204 | 211 |
|
| 212 | +# sphinx_gallery_start_ignore |
| 213 | +pl.figure(1, (3, 3)) |
| 214 | +plot2D_samples_mat(x1, x2, P_sink) |
| 215 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 216 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 217 | +pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink)) |
| 218 | +pl.show() |
| 219 | + |
| 220 | +pl.figure(2, (3, 1.7)) |
| 221 | +pl.imshow(P_sink, cmap="Greys") |
| 222 | +pl.title("Sinkhorn OT plan") |
| 223 | +pl.show() |
| 224 | +# sphinx_gallery_end_ignore |
205 | 225 | # %%
|
206 | 226 | # The Sinkhorn algorithm solves the Entropic Regularized OT problem. The
|
207 | 227 | # regularization strength can be controlled with the :code:`reg` parameter.
|
208 | 228 | # The Sinkhorn algorithm can be faster than the exact OT solver for large
|
209 | 229 | # regularization strength but the solution is only an approximation of the
|
210 | 230 | # exact OT problem and the OT plan is not sparse.
|
211 |
| -# |
212 |
| -# Plot the OT plan and dual potentials for Sinkhorn |
213 |
| -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 231 | + |
| 232 | +# %% |
| 233 | +# Solve the Regularized OT problem with other regularizations |
| 234 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
214 | 235 | #
|
215 | 236 |
|
216 |
| -pl.figure(1, (3, 3)) |
| 237 | +# Use quadratic regularization |
| 238 | +P_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type="L2").plan |
| 239 | + |
| 240 | +loss_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type="L2").value |
| 241 | + |
| 242 | +# sphinx_gallery_start_ignore |
| 243 | +pl.figure(1, (9, 3)) |
| 244 | + |
| 245 | +pl.subplot(1, 3, 1) |
| 246 | +plot2D_samples_mat(x1, x2, P) |
| 247 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 248 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 249 | +pl.title("OT plan loss={:.3f}".format(loss)) |
| 250 | + |
| 251 | +pl.subplot(1, 3, 2) |
217 | 252 | plot2D_samples_mat(x1, x2, P_sink)
|
218 | 253 | pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
|
219 | 254 | pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
|
220 |
| -pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink)) |
221 |
| -pl.show() |
| 255 | +pl.title("Sinkhorn plan loss={:.3f}".format(loss_sink)) |
222 | 256 |
|
223 |
| -pl.figure(2, (3, 1.7)) |
224 |
| -pl.imshow(P_sink, cmap="Greys") |
225 |
| -pl.title("Sinkhorn OT plan") |
| 257 | +pl.subplot(1, 3, 3) |
| 258 | +plot2D_samples_mat(x1, x2, P_quad) |
| 259 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 260 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 261 | +pl.title("Quadratic plan loss={:.3f}".format(loss_quad)) |
226 | 262 | pl.show()
|
| 263 | +# sphinx_gallery_end_ignore |
| 264 | +# %% |
| 265 | +# We plot above the OT plans obtained with different regularizations. The |
| 266 | +# quadratic regularization is another common choice for regularized OT and |
| 267 | +# preserves the sparsity of the OT plan. |
| 268 | +# |
0 commit comments