Commit a8acfc3

aboisbunon and Aurelie Boisbunon authored
[WIP] add introductory example of OT, EMD and Sinkhorn (#191)
* add introductory example of OT, EMD and Sinkhorn
* improve figure and try complying with pep8
* autopep8
* change markdown elements to rst
* try solving issue with images
* fix issue with images
* add a section on varying the sinkhorn hyperparameter
* add Sinkhorn algorithm and discussion for comparison between EMD and Sinkhorn
* autopep8 again
* add subsections and modify figure sizes/shapes
* fix bug with print
* correct some typos
* remove computational time comparison
* autopep8 again...

Co-authored-by: Aurelie Boisbunon <[email protected]>
1 parent 24a7a04 commit a8acfc3

4 files changed, +373 -0 lines changed

data/manhattan.npz

2.21 MB
Binary file not shown.
298 KB
36.3 KB

examples/plot_Intro_OT.py

+373
@@ -0,0 +1,373 @@
# coding: utf-8
"""
=============================================
Introduction to Optimal Transport with Python
=============================================

This example gives an introduction on how to use Optimal Transport in Python.

"""

# Author: Remi Flamary, Nicolas Courty, Aurelie Boisbunon
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 1

##############################################################################
# POT Python Optimal Transport Toolbox
# ------------------------------------
#
# POT installation
# ```````````````````
#
# * Install with pip::
#
#     pip install pot
# * Install with conda::
#
#     conda install -c conda-forge pot
#
# Import the toolbox
# ```````````````````
#

import numpy as np  # always need it
import pylab as pl  # do the plots

import ot  # ot

import time

##############################################################################
# Getting help
# `````````````
#
# Online documentation: `<https://pythonot.github.io/all.html>`_
#
# Or inline help:
#

help(ot.dist)


##############################################################################
# First OT Problem
# ----------------
#
# We will solve the Bakery/Cafés problem of transporting croissants from a
# number of Bakeries to Cafés in a City (in this case Manhattan). We did a
# quick Google Maps search in Manhattan for bakeries and Cafés:
#
# .. image:: images/bak.png
#     :align: center
#     :alt: bakery-cafe-manhattan
#     :width: 600px
#     :height: 280px
#
# We extracted from this search their positions and generated fictional
# production and sale numbers (that both sum to the same value).
#
# We have access to the positions of the Bakeries ``bakery_pos`` and their
# respective production ``bakery_prod``, which describe the source
# distribution. The Cafés where the croissants are sold are likewise defined
# by their positions ``cafe_pos`` and their sales ``cafe_prod``, which
# describe the target distribution. For fun we also provide a
# map ``Imap`` that will illustrate the position of these shops in the city.
#
#
# Now we load the data
#
#

data = np.load('../data/manhattan.npz')

bakery_pos = data['bakery_pos']
bakery_prod = data['bakery_prod']
cafe_pos = data['cafe_pos']
cafe_prod = data['cafe_prod']
Imap = data['Imap']

print('Bakery production: {}'.format(bakery_prod))
print('Cafe sale: {}'.format(cafe_prod))
print('Total croissants : {}'.format(cafe_prod.sum()))


##############################################################################
# Plotting bakeries in the city
# -----------------------------
#
# Next we plot the position of the bakeries and cafés on the map. The area of
# each circle is proportional to the production (bakeries) or sales (cafés).
#

pl.figure(1, (7, 6))
pl.clf()
pl.imshow(Imap, interpolation='bilinear')  # plot the map
pl.scatter(bakery_pos[:, 0], bakery_pos[:, 1], s=bakery_prod, c='r', ec='k', label='Bakeries')
pl.scatter(cafe_pos[:, 0], cafe_pos[:, 1], s=cafe_prod, c='b', ec='k', label='Cafés')
pl.legend()
pl.title('Manhattan Bakeries and Cafés')


##############################################################################
# Cost matrix
# -----------
#
#
# We can now compute the cost matrix between the bakeries and the cafés, which
# will be the transport cost matrix. This can be done using the
# `ot.dist <https://pythonot.github.io/all.html#ot.dist>`_ function, which
# defaults to the squared Euclidean distance but can also compute other
# metrics, such as the cityblock (or Manhattan) distance, through its
# ``metric`` argument.
#

C = ot.dist(bakery_pos, cafe_pos)

labels = [str(i) for i in range(len(bakery_prod))]
f = pl.figure(2, (14, 7))
pl.clf()
pl.subplot(121)
pl.imshow(Imap, interpolation='bilinear')  # plot the map
for i in range(len(cafe_pos)):
    pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b',
            fontsize=14, fontweight='bold', ha='center', va='center')
for i in range(len(bakery_pos)):
    pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r',
            fontsize=14, fontweight='bold', ha='center', va='center')
pl.title('Manhattan Bakeries and Cafés')

ax = pl.subplot(122)
im = pl.imshow(C, cmap="coolwarm")
pl.title('Cost matrix')
cbar = pl.colorbar(im, ax=ax, shrink=0.5, use_gridspec=True)
cbar.ax.set_ylabel("cost", rotation=-90, va="bottom")

pl.xlabel('Cafés')
pl.ylabel('Bakeries')
pl.tight_layout()


##############################################################################
# The red cells in the matrix image show the bakeries and cafés that are
# further away, and thus more costly to transport from one to the other, while
# the blue ones show those that are very close to each other, with respect to
# the squared Euclidean distance.


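##############################################################################
# As a rough illustration of the two metrics mentioned above, the sketch below
# builds both cost matrices with plain NumPy. The coordinates ``sources`` and
# ``targets`` are hypothetical toy values, not taken from ``manhattan.npz``.
# With POT, ``ot.dist(sources, targets)`` and
# ``ot.dist(sources, targets, metric='cityblock')`` should give the same
# matrices.

```python
import numpy as np

# Hypothetical toy coordinates (not taken from manhattan.npz).
sources = np.array([[0.0, 0.0], [1.0, 2.0]])
targets = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 1.0]])

# Pairwise differences via broadcasting: shape (n_sources, n_targets, 2).
diff = sources[:, None, :] - targets[None, :, :]

# Squared Euclidean cost: sum of squared coordinate differences.
C_sqeuclidean = np.sum(diff ** 2, axis=2)

# Cityblock (Manhattan) cost: sum of absolute coordinate differences.
C_cityblock = np.sum(np.abs(diff), axis=2)

print(C_sqeuclidean)
print(C_cityblock)
```

# Rows index the sources and columns the targets, which matches the layout of
# the cost matrix ``C`` used in this example.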
##############################################################################
# Solving the OT problem with `ot.emd <https://pythonot.github.io/all.html#ot.emd>`_
# -----------------------------------------------------------------------------------

start = time.time()
ot_emd = ot.emd(bakery_prod, cafe_prod, C)
time_emd = time.time() - start

##############################################################################
# The function returns the transport matrix, which we can then visualize (next section).

##############################################################################
# Transportation plan visualization
# `````````````````````````````````
#
# A good visualization of the OT matrix in the 2D plane is to denote the
# transportation of mass between a Bakery and a Café by a line. This can easily
# be done with a double ``for`` loop.
#
# In order to make it more interpretable one can also use the ``alpha``
# parameter of plot and set it to ``alpha=G[i,j]/G.max()``, where ``G`` is the
# transport matrix. The code below scales the line width ``lw`` in the same
# way.

# Plot the matrix and the map
f = pl.figure(3, (14, 7))
pl.clf()
pl.subplot(121)
pl.imshow(Imap, interpolation='bilinear')  # plot the map
for i in range(len(bakery_pos)):
    for j in range(len(cafe_pos)):
        pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]], [bakery_pos[i, 1], cafe_pos[j, 1]],
                '-k', lw=3. * ot_emd[i, j] / ot_emd.max())
for i in range(len(cafe_pos)):
    pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', fontsize=14,
            fontweight='bold', ha='center', va='center')
for i in range(len(bakery_pos)):
    pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', fontsize=14,
            fontweight='bold', ha='center', va='center')
pl.title('Manhattan Bakeries and Cafés')

ax = pl.subplot(122)
im = pl.imshow(ot_emd)
for i in range(len(bakery_prod)):
    for j in range(len(cafe_prod)):
        text = ax.text(j, i, '{0:g}'.format(ot_emd[i, j]),
                       ha="center", va="center", color="w")
pl.title('Transport matrix')

pl.xlabel('Cafés')
pl.ylabel('Bakeries')
pl.tight_layout()

##############################################################################
# The transport matrix gives the number of croissants that are transported
# from each bakery to each café. We can see that the bakeries only need to
# transport croissants to one or two cafés, the transport matrix being very
# sparse.

##############################################################################
# OT loss and dual variables
# --------------------------
#
# The resulting Wasserstein loss is of the form:
#
# .. math::
#     W=\sum_{i,j}\gamma_{i,j}C_{i,j}
#
# where :math:`\gamma` is the optimal transport matrix.
#

W = np.sum(ot_emd * C)
print('Wasserstein loss (EMD) = {0:.2f}'.format(W))

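##############################################################################
# To make the formula above concrete, here is a minimal sketch on a
# hypothetical toy problem (all ``*_toy`` names and values are made up for
# illustration). The plan ``gamma_toy`` is chosen by hand: it is feasible, but
# no claim is made that it is optimal.

```python
import numpy as np

# Hypothetical toy problem: 2 sources, 3 targets, total mass 10 on each side.
a_toy = np.array([6.0, 4.0])        # source masses (production)
b_toy = np.array([3.0, 5.0, 2.0])   # target masses (sales)
C_toy = np.array([[1.0, 2.0, 4.0],
                  [3.0, 1.0, 2.0]])  # cost of moving one unit from i to j

# A feasible (not necessarily optimal) transport plan chosen by hand.
gamma_toy = np.array([[3.0, 3.0, 0.0],
                      [0.0, 2.0, 2.0]])

# Feasibility: rows sum to the source masses, columns to the target masses.
assert np.allclose(gamma_toy.sum(axis=1), a_toy)
assert np.allclose(gamma_toy.sum(axis=0), b_toy)

# Transport cost W = sum_ij gamma_ij * C_ij.
W_toy = np.sum(gamma_toy * C_toy)
print('Transport cost of the toy plan = {0:.2f}'.format(W_toy))
```

# The same element-wise product and sum is what ``np.sum(ot_emd * C)`` computes
# for the Manhattan data.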
##############################################################################
# Regularized OT with Sinkhorn
# ----------------------------
#
# The Sinkhorn algorithm is very simple to code. You can implement it directly
# using the following pseudo-code:
#
# .. image:: images/sinkhorn.png
#     :align: center
#     :alt: Sinkhorn algorithm
#     :width: 440px
#     :height: 240px
#
# In this algorithm, :math:`\oslash` corresponds to the element-wise division.
#
# An alternative is to use the POT toolbox with
# `ot.sinkhorn <https://pythonot.github.io/all.html#ot.sinkhorn>`_.
#
# Be careful of numerical problems: the kernel ``exp(-C / reg)`` underflows to
# zero when the entries of ``C`` are large compared to ``reg``. A good
# pre-processing for Sinkhorn is to divide the cost matrix ``C`` by its
# maximum value.

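##############################################################################
# The numerical issue can be reproduced with a few lines of NumPy. This is a
# minimal sketch on hypothetical cost values (``C_big`` and ``reg_toy`` are
# made up, chosen to be on the scale of squared pixel distances); it only
# inspects the kernel and does not run Sinkhorn itself.

```python
import numpy as np

# Hypothetical cost values on the scale of squared pixel distances.
C_big = np.array([[0.0, 4.0e4],
                  [9.0e4, 1.0e3]])
reg_toy = 0.1

# Without rescaling, exp(-C / reg) underflows to exactly 0 for every nonzero
# cost here, so later divisions by K.T @ u or K @ v can produce inf or nan.
K_raw = np.exp(-C_big / reg_toy)

# Rescaling the cost to [0, 1] keeps every kernel entry strictly positive.
K_scaled = np.exp(-C_big / C_big.max() / reg_toy)

print('zeros in raw kernel    :', int(np.sum(K_raw == 0)))
print('zeros in scaled kernel :', int(np.sum(K_scaled == 0)))
```

# Note that rescaling ``C`` changes the meaning of ``reg``: the same value of
# ``reg`` acts relative to the largest cost rather than in the original units.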
##############################################################################
# Algorithm
# `````````

# Compute Sinkhorn transport matrix from algorithm
reg = 0.1
K = np.exp(-C / C.max() / reg)
nit = 100
u = np.ones((len(bakery_prod), ))
for _ in range(nit):  # run exactly nit iterations
    v = cafe_prod / np.dot(K.T, u)
    u = bakery_prod / (np.dot(K, v))
# Equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v)))
ot_sink_algo = u[:, None] * K * v[None, :]

# Compute Sinkhorn transport matrix with POT
ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg, M=C / C.max())

# Difference between the 2
print('Difference between algo and ot.sinkhorn = {0:.2g}'.format(np.sum(np.power(ot_sink_algo - ot_sinkhorn, 2))))

##############################################################################
# Plot the matrix and the map
# ```````````````````````````

print('Min. of Sinkhorn\'s transport matrix = {0:.2g}'.format(np.min(ot_sinkhorn)))

f = pl.figure(4, (13, 6))
pl.clf()
pl.subplot(121)
pl.imshow(Imap, interpolation='bilinear')  # plot the map
for i in range(len(bakery_pos)):
    for j in range(len(cafe_pos)):
        pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]],
                [bakery_pos[i, 1], cafe_pos[j, 1]],
                '-k', lw=3. * ot_sinkhorn[i, j] / ot_sinkhorn.max())
for i in range(len(cafe_pos)):
    pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b',
            fontsize=14, fontweight='bold', ha='center', va='center')
for i in range(len(bakery_pos)):
    pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r',
            fontsize=14, fontweight='bold', ha='center', va='center')
pl.title('Manhattan Bakeries and Cafés')

ax = pl.subplot(122)
im = pl.imshow(ot_sinkhorn)
for i in range(len(bakery_prod)):
    for j in range(len(cafe_prod)):
        text = ax.text(j, i, np.round(ot_sinkhorn[i, j], 1),
                       ha="center", va="center", color="w")
pl.title('Transport matrix')

pl.xlabel('Cafés')
pl.ylabel('Bakeries')
pl.tight_layout()


##############################################################################
# We notice right away that the matrix is not sparse at all with Sinkhorn,
# each bakery delivering croissants to all 5 cafés with that solution. Also,
# this solution gives a transport with fractions, which does not make sense
# in the case of croissants. This was not the case with EMD.

##############################################################################
# Varying the regularization parameter in Sinkhorn
# ````````````````````````````````````````````````
#

reg_parameter = np.logspace(-3, 0, 20)
W_sinkhorn_reg = np.zeros((len(reg_parameter), ))
time_sinkhorn_reg = np.zeros((len(reg_parameter), ))

f = pl.figure(5, (14, 5))
pl.clf()
max_ot = 100  # plot matrices with the same colorbar
for k in range(len(reg_parameter)):
    start = time.time()
    ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg_parameter[k], M=C / C.max())
    time_sinkhorn_reg[k] = time.time() - start

    if k % 4 == 0 and k > 0:  # we only plot a few
        # integer division: the subplot index must be an int
        ax = pl.subplot(1, 5, k // 4)
        im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot)
        pl.title('reg={0:.2g}'.format(reg_parameter[k]))
        pl.xlabel('Cafés')
        pl.ylabel('Bakeries')

    # Compute the Wasserstein loss for Sinkhorn, and compare with EMD
    W_sinkhorn_reg[k] = np.sum(ot_sinkhorn * C)
pl.tight_layout()


##############################################################################
# This series of plots shows that the solution of Sinkhorn starts with
# something very similar to EMD (although not sparse) for very small values of
# the regularization parameter, and tends to a more uniform solution as the
# regularization parameter increases.
#

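##############################################################################
# The behaviour described above can be checked on a small problem without POT.
# The sketch below is illustrative only: ``sinkhorn_toy`` is a hypothetical
# helper with a fixed number of iterations and no stopping test, and the
# marginals and cost (``a_toy``, ``b_toy``, ``M_toy``) are made-up values.

```python
import numpy as np


def sinkhorn_toy(a, b, M, reg, n_iter=1000):
    """Plain Sinkhorn iterations (illustrative sketch, no stopping test)."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]


# Hypothetical toy marginals (each sums to 1) and a cost scaled to [0, 1].
a_toy = np.array([0.5, 0.3, 0.2])
b_toy = np.array([0.4, 0.6])
M_toy = np.array([[0.0, 1.0],
                  [0.5, 0.5],
                  [1.0, 0.0]])

plan_small = sinkhorn_toy(a_toy, b_toy, M_toy, reg=0.05)
plan_large = sinkhorn_toy(a_toy, b_toy, M_toy, reg=100.0)

# With a large reg the plan approaches the independent coupling a b^T,
# i.e. every source spreads its mass in proportion to the target masses.
print(np.round(plan_small, 3))
print(np.round(plan_large, 3))
print(np.round(np.outer(a_toy, b_toy), 3))
```

# With the small ``reg`` most of the mass stays on the cheap routes, while
# with the large ``reg`` the plan is close to ``np.outer(a_toy, b_toy)``.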
##############################################################################
# Wasserstein loss and computational time
# ```````````````````````````````````````
#

# Plot the Wasserstein loss of Sinkhorn against that of EMD
f = pl.figure(6, (4, 4))
pl.clf()
pl.title("Comparison between Sinkhorn and EMD")

pl.plot(reg_parameter, W_sinkhorn_reg, 'o', label="Sinkhorn")
XLim = pl.xlim()
pl.plot(XLim, [W, W], '--k', label="EMD")
pl.legend()
pl.xlabel("reg")
pl.ylabel("Wasserstein loss")

##############################################################################
# In this last graph, we show the impact of the regularization parameter on
# the Wasserstein loss. We can see that higher
# values of ``reg`` lead to a much higher Wasserstein loss.
#
# The Wasserstein loss of EMD is displayed for
# comparison. The Wasserstein loss of Sinkhorn can be a little lower than that
# of EMD for low values of ``reg``, but it quickly gets much higher. Since EMD
# is optimal among the plans that satisfy both marginals exactly, a lower
# value can only occur when the Sinkhorn plan has not fully converged and
# violates the marginals slightly.
#
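##############################################################################
# The growth of the transport cost with ``reg`` can be reproduced on a
# hypothetical symmetric toy problem whose exact OT cost is known to be 0.
# This is a sketch under that assumption: all ``*_sym`` names and values are
# made up, and the Sinkhorn loop is the plain iteration without stopping test.

```python
import numpy as np

# Hypothetical symmetric toy problem whose exact OT cost is known to be 0:
# source i can ship all of its mass to target i at zero cost.
a_sym = np.array([0.5, 0.5])
b_sym = np.array([0.5, 0.5])
M_sym = np.array([[0.0, 1.0],
                  [1.0, 0.0]])
W_exact = 0.0

costs_sym = []
for reg_sym in [0.1, 1.0, 10.0]:
    K_sym = np.exp(-M_sym / reg_sym)
    u_sym = np.ones(2)
    for _ in range(200):  # plain Sinkhorn iterations, no stopping test
        v_sym = b_sym / (K_sym.T @ u_sym)
        u_sym = a_sym / (K_sym @ v_sym)
    plan_sym = u_sym[:, None] * K_sym * v_sym[None, :]
    costs_sym.append(np.sum(plan_sym * M_sym))
    print('reg={0:g}: transport cost = {1:.4f} (exact OT: {2:.4f})'.format(
        reg_sym, costs_sym[-1], W_exact))
```

# Here the Sinkhorn plans satisfy both marginals, so their cost never drops
# below the exact OT cost and it increases with ``reg``.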
