Skip to content

Commit aa03aaf

Browse files
committed
working on the guide
1 parent 3749263 commit aa03aaf

File tree

3 files changed

+107
-65
lines changed

3 files changed

+107
-65
lines changed

docs/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def __getattr__(cls, name):
347347
}
348348

349349
sphinx_gallery_conf = {
350-
"examples_dirs": ["../../examples", "../../examples/da"],
350+
"examples_dirs": ["../../examples"],
351351
"gallery_dirs": "auto_examples",
352352
"filename_pattern": "plot_", # (?!barycenter_fgw)
353353
"nested_sections": False,

examples/plot_OT_2D_samples.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565

6666
# %% EMD
6767

68-
G0 = ot.emd(a, b, M)
68+
G0 = ot.solve(M, a, b).plan
6969

7070
pl.figure(3)
7171
pl.imshow(G0, interpolation="nearest")

examples/plot_quickstart_guide.py

+105-63
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@
55
=============================================
66
77
8-
This is a quickstart guide to the Python Optimal Transport (POT) toolbox. We use
9-
here the new API of POT which is more flexible and allows to solve a wider range
10-
of problems with just a few functions. The old API is still available (the new
11-
one is a convenient wrapper around the old one) and we provide pointers to the
12-
old API when needed.
8+
Quickstart guide to the POT toolbox.
9+
10+
For better readability, only the use of POT is provided and the plotting code
11+
with matplotlib is hidden (but is available in the source file of the example).
12+
13+
.. note::
14+
We use here the new API of POT which is more flexible and allows to solve a wider range of problems with just a few functions. The old API is still available (the new
15+
one is a convenient wrapper around the old one) and we provide pointers to the
16+
old API when needed.
1317
1418
"""
1519

1620
# Author: Remi Flamary
1721
#
1822
# License: MIT License
19-
# sphinx_gallery_thumbnail_number = 1
23+
# sphinx_gallery_thumbnail_number = 4
2024

2125
# Import necessary libraries
2226

@@ -43,18 +47,12 @@
4347
b = ot.utils.unif(n2) # weights of points in the target domain
4448

4549
x1 = np.random.randn(n1, 2)
46-
x1 /= (
47-
np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2
48-
) # project on the unit circle and scale
49-
x2 = np.random.randn(n2, 2)
50-
x2 /= (
51-
np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4
52-
) # project on the unit circle and scale
50+
x1 /= np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2
5351

54-
# %%
55-
# Plot data
56-
# ~~~~~~~~~
52+
x2 = np.random.randn(n2, 2)
53+
x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4
5754

55+
# sphinx_gallery_start_ignore
5856
style = {"markeredgecolor": "k"}
5957

6058
pl.figure(1, (4, 4))
@@ -63,8 +61,13 @@
6361
pl.legend(loc=0)
6462
pl.title("Source and target distributions")
6563
pl.show()
64+
# sphinx_gallery_end_ignore
6665

6766
# %%
67+
# We illustrate above the simple example of two 2D distributions with 25 and 50
68+
# samples respectively located on circles. The weights of the samples are
69+
# uniform.
70+
#
6871
# Solving exact Optimal Transport
6972
# -------------------------------
7073
# Solve the Optimal Transport problem between the samples
@@ -88,31 +91,7 @@
8891

8992
print(f"OT loss = {loss:1.3f}")
9093

91-
# %%
92-
# We provide
93-
# the weights of the samples in the source and target domains :code:`a` and
94-
# :code:`b`. If not provided, the weights are assumed to be uniform.
95-
#
96-
# The :class:`ot.utils.OTResult` object contains the following attributes:
97-
#
98-
# - :code:`value`: the value of the OT problem
99-
# - :code:`plan`: the OT matrix
100-
# - :code:`potentials`: Dual potentials of the OT problem
101-
# - :code:`log`: log dictionary of the solver
102-
#
103-
# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where
104-
# :code:`P[i,j]` is the amount of mass
105-
# transported from :code:`x1[i]` to :code:`x2[j]`.
106-
#
107-
# The OT loss is the sum of the element-wise product of the OT matrix and the
108-
# cost matrix taken by default as the Squared Euclidean distance.
109-
#
110-
111-
# %%
112-
# Plot the OT plan and dual potentials
113-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
114-
#
115-
94+
# sphinx_gallery_start_ignore
11695
from ot.plot import plot2D_samples_mat
11796

11897
pl.figure(1, (8, 4))
@@ -134,6 +113,32 @@
134113
pl.imshow(P, cmap="Greys")
135114
pl.title("OT plan")
136115
pl.show()
116+
# sphinx_gallery_end_ignore
117+
118+
# %%
119+
# The figure above shows the Optimal Transport plan between the source and target
120+
# samples. The color intensity represents the amount of mass transported
121+
# between the samples. The dual potentials of the OT problem are also shown.
122+
#
123+
# The weights of the samples in the source and target domains :code:`a` and
124+
# :code:`b` are given to the function. If not provided, the weights are assumed
125+
# to be uniform See :func:`ot.solve_sample` for more details.
126+
#
127+
# The :class:`ot.utils.OTResult` object contains the following attributes:
128+
#
129+
# - :code:`value`: the value of the OT problem
130+
# - :code:`plan`: the OT matrix
131+
# - :code:`potentials`: Dual potentials of the OT problem
132+
# - :code:`log`: log dictionary of the solver
133+
#
134+
# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where
135+
# :code:`P[i,j]` is the amount of mass
136+
# transported from :code:`x1[i]` to :code:`x2[j]`.
137+
#
138+
# The OT loss is the sum of the element-wise product of the OT matrix and the
139+
# cost matrix taken by default as the Squared Euclidean distance.
140+
#
141+
137142

138143
# %%
139144
# Solve the Optimal Transport problem with a custom cost matrix
@@ -155,8 +160,21 @@
155160
# Compute the OT loss (equivalent to ot.solve(C).value)
156161
loss_city = np.sum(P_city * C)
157162

163+
# sphinx_gallery_start_ignore
164+
pl.figure(1, (3, 3))
165+
plot2D_samples_mat(x1, x2, P)
166+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
167+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
168+
pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city))
169+
170+
pl.figure(2, (3, 1.7))
171+
pl.imshow(P_city, cmap="Greys")
172+
pl.title("OT plan (Citybloc)")
173+
pl.show()
174+
# sphinx_gallery_end_ignore
175+
158176
# %%
159-
# Note that we show here how to sole the OT problem with a custom cost matrix
177+
# Note that we show here how to solve the OT problem with a custom cost matrix
160178
# with the more general :func:`ot.solve` function.
161179
# But the same can be done with the :func:`ot.solve_sample` function by passing
162180
# :code:`metric='cityblock'` as argument.
@@ -171,20 +189,9 @@
171189
# P = ot.emd(a, b, C)
172190
# loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b
173191
#
174-
# Plot the OT plan and dual potentials for other loss
175-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
192+
# .. minigallery:: ot.emd2 ot.emd ot.solve ot.solve_sample
176193
#
177194

178-
pl.figure(1, (3, 3))
179-
plot2D_samples_mat(x1, x2, P)
180-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
181-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
182-
pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city))
183-
184-
pl.figure(2, (3, 1.7))
185-
pl.imshow(P_city, cmap="Greys")
186-
pl.title("OT plan (Citybloc)")
187-
pl.show()
188195

189196
# %%
190197
# Sinkhorn and Regularized OT
@@ -202,25 +209,60 @@
202209
loss_sink = sol.value # objective value of the Sinkhorn problem (incl. entropy)
203210
loss_sink_linear = sol.value_linear # np.sum(P_sink * C) linear part of loss
204211

212+
# sphinx_gallery_start_ignore
213+
pl.figure(1, (3, 3))
214+
plot2D_samples_mat(x1, x2, P_sink)
215+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
216+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
217+
pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink))
218+
pl.show()
219+
220+
pl.figure(2, (3, 1.7))
221+
pl.imshow(P_sink, cmap="Greys")
222+
pl.title("Sinkhorn OT plan")
223+
pl.show()
224+
# sphinx_gallery_end_ignore
205225
# %%
206226
# The Sinkhorn algorithm solves the Entropic Regularized OT problem. The
207227
# regularization strength can be controlled with the :code:`reg` parameter.
208228
# The Sinkhorn algorithm can be faster than the exact OT solver for large
209229
# regularization strength but the solution is only an approximation of the
210230
# exact OT problem and the OT plan is not sparse.
211-
#
212-
# Plot the OT plan and dual potentials for Sinkhorn
213-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231+
232+
# %%
233+
# Solve the Regularized OT problem with other regularizations
234+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
214235
#
215236

216-
pl.figure(1, (3, 3))
237+
# Use quadratic regularization
238+
P_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type="L2").plan
239+
240+
loss_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type="L2").value
241+
242+
# sphinx_gallery_start_ignore
243+
pl.figure(1, (9, 3))
244+
245+
pl.subplot(1, 3, 1)
246+
plot2D_samples_mat(x1, x2, P)
247+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
248+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
249+
pl.title("OT plan loss={:.3f}".format(loss))
250+
251+
pl.subplot(1, 3, 2)
217252
plot2D_samples_mat(x1, x2, P_sink)
218253
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
219254
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
220-
pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink))
221-
pl.show()
255+
pl.title("Sinkhorn plan loss={:.3f}".format(loss_sink))
222256

223-
pl.figure(2, (3, 1.7))
224-
pl.imshow(P_sink, cmap="Greys")
225-
pl.title("Sinkhorn OT plan")
257+
pl.subplot(1, 3, 3)
258+
plot2D_samples_mat(x1, x2, P_quad)
259+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
260+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
261+
pl.title("Quadratic plan loss={:.3f}".format(loss_quad))
226262
pl.show()
263+
# sphinx_gallery_end_ignore
264+
# %%
265+
# We plot above the OT plans obtained with different regularizations. The
266+
# quadratic regularization is another common choice for regularized OT and
267+
# preserves the sparsity of the OT plan.
268+
#

0 commit comments

Comments
 (0)