Skip to content

Commit 3749263

Browse files
committed
premier jet quckstart guide
1 parent 3c7ca56 commit 3749263

File tree

1 file changed

+206
-5
lines changed

1 file changed

+206
-5
lines changed

examples/plot_quickstart_guide.py

Lines changed: 206 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
=============================================
66
77
8-
This is a quickstart guide to the Python Optimal Transport (POT) toolbox.
8+
This is a quickstart guide to the Python Optimal Transport (POT) toolbox. We use
9+
here the new API of POT which is more flexible and allows to solve a wider range
10+
of problems with just a few functions. The old API is still available (the new
11+
one is a convenient wrapper around the old one) and we provide pointers to the
12+
old API when needed.
913
1014
"""
1115

@@ -14,12 +18,209 @@
1418
# License: MIT License
1519
# sphinx_gallery_thumbnail_number = 1
1620

21+
# Import necessary libraries
22+
23+
import numpy as np
24+
import pylab as pl
25+
26+
import ot
27+
28+
1729
# %%
18-
# Simple example
30+
# Example data
1931
# --------------
2032
#
33+
# Data generation
34+
# ~~~~~~~~~~~~~~~
2135

22-
import numpy as np # always need it
23-
import pylab as pl # for the plots
36+
# Problem size
37+
n1 = 25
38+
n2 = 50
2439

25-
import ot
40+
# Generate random data
41+
np.random.seed(0)
42+
a = ot.utils.unif(n1) # weights of points in the source domain
43+
b = ot.utils.unif(n2) # weights of points in the target domain
44+
45+
x1 = np.random.randn(n1, 2)
46+
x1 /= (
47+
np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2
48+
) # project on the unit circle and scale
49+
x2 = np.random.randn(n2, 2)
50+
x2 /= (
51+
np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4
52+
) # project on the unit circle and scale
53+
54+
# %%
55+
# Plot data
56+
# ~~~~~~~~~
57+
58+
style = {"markeredgecolor": "k"}
59+
60+
pl.figure(1, (4, 4))
61+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
62+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
63+
pl.legend(loc=0)
64+
pl.title("Source and target distributions")
65+
pl.show()
66+
67+
# %%
68+
# Solving exact Optimal Transport
69+
# -------------------------------
70+
# Solve the Optimal Transport problem between the samples
71+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72+
#
73+
# The :func:`ot.solve_sample` function can be used to solve the Optimal Transport problem
74+
# between two sets of samples. The function takes as its two first arguments the
75+
# positions of the source and target samples, and returns an :class:`ot.utils.OTResult` object.
76+
77+
# Solve the OT problem
78+
sol = ot.solve_sample(x1, x2, a, b)
79+
80+
# get the OT plan
81+
P = sol.plan
82+
83+
# get the OT loss
84+
loss = sol.value
85+
86+
# get the dual potentials
87+
alpha, beta = sol.potentials
88+
89+
print(f"OT loss = {loss:1.3f}")
90+
91+
# %%
92+
# We provide
93+
# the weights of the samples in the source and target domains :code:`a` and
94+
# :code:`b`. If not provided, the weights are assumed to be uniform.
95+
#
96+
# The :class:`ot.utils.OTResult` object contains the following attributes:
97+
#
98+
# - :code:`value`: the value of the OT problem
99+
# - :code:`plan`: the OT matrix
100+
# - :code:`potentials`: Dual potentials of the OT problem
101+
# - :code:`log`: log dictionary of the solver
102+
#
103+
# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where
104+
# :code:`P[i,j]` is the amount of mass
105+
# transported from :code:`x1[i]` to :code:`x2[j]`.
106+
#
107+
# The OT loss is the sum of the element-wise product of the OT matrix and the
108+
# cost matrix taken by default as the Squared Euclidean distance.
109+
#
110+
111+
# %%
112+
# Plot the OT plan and dual potentials
113+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
114+
#
115+
116+
from ot.plot import plot2D_samples_mat
117+
118+
pl.figure(1, (8, 4))
119+
120+
pl.subplot(1, 2, 1)
121+
plot2D_samples_mat(x1, x2, P)
122+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
123+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
124+
pl.title("OT plan P loss={:.3f}".format(loss))
125+
126+
pl.subplot(1, 2, 2)
127+
pl.scatter(x1[:, 0], x1[:, 1], c=alpha, cmap="viridis", edgecolors="k")
128+
pl.scatter(x2[:, 0], x2[:, 1], c=beta, cmap="plasma", edgecolors="k")
129+
pl.title("Dual potentials")
130+
pl.show()
131+
132+
133+
pl.figure(2, (3, 1.7))
134+
pl.imshow(P, cmap="Greys")
135+
pl.title("OT plan")
136+
pl.show()
137+
138+
# %%
139+
# Solve the Optimal Transport problem with a custom cost matrix
140+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
141+
#
142+
# The cost matrix can be customized by passing it to the more general
143+
# :func:`ot.solve` function. The cost matrix should be a matrix of size
144+
# :code:`(n1, n2)` where :code:`C[i,j]` is the cost of transporting mass from
145+
# :code:`x1[i]` to :code:`x2[j]`.
146+
#
147+
# In this example, we use the Citybloc distance as the cost matrix.
148+
149+
# Compute the cost matrix
150+
C = ot.dist(x1, x2, metric="cityblock")
151+
152+
# Solve the OT problem with the custom cost matrix
153+
P_city = ot.solve(C).plan
154+
155+
# Compute the OT loss (equivalent to ot.solve(C).value)
156+
loss_city = np.sum(P_city * C)
157+
158+
# %%
159+
# Note that we show here how to sole the OT problem with a custom cost matrix
160+
# with the more general :func:`ot.solve` function.
161+
# But the same can be done with the :func:`ot.solve_sample` function by passing
162+
# :code:`metric='cityblock'` as argument.
163+
#
164+
# .. note::
165+
# The examples above use the new API of POT. The old API is still available
166+
# and and OT plan and loss can be computed with the :func:`ot.emd` and
167+
# the :func:`ot.emd2` functions as below:
168+
#
169+
# .. code-block:: python
170+
#
171+
# P = ot.emd(a, b, C)
172+
# loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b
173+
#
174+
# Plot the OT plan and dual potentials for other loss
175+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
176+
#
177+
178+
pl.figure(1, (3, 3))
179+
plot2D_samples_mat(x1, x2, P)
180+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
181+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
182+
pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city))
183+
184+
pl.figure(2, (3, 1.7))
185+
pl.imshow(P_city, cmap="Greys")
186+
pl.title("OT plan (Citybloc)")
187+
pl.show()
188+
189+
# %%
190+
# Sinkhorn and Regularized OT
191+
# ---------------------------
192+
#
193+
# Solve Entropic Regularized OT with Sinkhorn algorithm
194+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
195+
#
196+
197+
# Solve the Sinkhorn problem (just add reg parameter value)
198+
sol = ot.solve_sample(x1, x2, a, b, reg=1e-1)
199+
200+
# get the OT plan and loss
201+
P_sink = sol.plan
202+
loss_sink = sol.value # objective value of the Sinkhorn problem (incl. entropy)
203+
loss_sink_linear = sol.value_linear # np.sum(P_sink * C) linear part of loss
204+
205+
# %%
206+
# The Sinkhorn algorithm solves the Entropic Regularized OT problem. The
207+
# regularization strength can be controlled with the :code:`reg` parameter.
208+
# The Sinkhorn algorithm can be faster than the exact OT solver for large
209+
# regularization strength but the solution is only an approximation of the
210+
# exact OT problem and the OT plan is not sparse.
211+
#
212+
# Plot the OT plan and dual potentials for Sinkhorn
213+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
214+
#
215+
216+
pl.figure(1, (3, 3))
217+
plot2D_samples_mat(x1, x2, P_sink)
218+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
219+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
220+
pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink))
221+
pl.show()
222+
223+
pl.figure(2, (3, 1.7))
224+
pl.imshow(P_sink, cmap="Greys")
225+
pl.title("Sinkhorn OT plan")
226+
pl.show()

0 commit comments

Comments
 (0)