Skip to content

Commit dd885bc

Browse files
committed
add stuff
1 parent aa03aaf commit dd885bc

File tree

1 file changed

+263
-11
lines changed

1 file changed

+263
-11
lines changed

examples/plot_quickstart_guide.py

+263-11
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@
3131

3232

3333
# %%
34-
# Example data
34+
# Data generation
3535
# --------------
3636
#
37-
# Data generation
38-
# ~~~~~~~~~~~~~~~
37+
# We first generate two sets of samples in 2D that 25 and 50
38+
# samples respectively located on circles. The weights of the samples are
39+
# uniform.
3940

4041
# Problem size
4142
n1 = 25
@@ -64,9 +65,6 @@
6465
# sphinx_gallery_end_ignore
6566

6667
# %%
67-
# We illustrate above the simple example of two 2D distributions with 25 and 50
68-
# samples respectively located on circles. The weights of the samples are
69-
# uniform.
7068
#
7169
# Solving exact Optimal Transport
7270
# -------------------------------
@@ -189,15 +187,13 @@
189187
# P = ot.emd(a, b, C)
190188
# loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b
191189
#
192-
# .. minigallery:: ot.emd2 ot.emd ot.solve ot.solve_sample
193-
#
194190

195191

196192
# %%
197193
# Sinkhorn and Regularized OT
198194
# ---------------------------
199195
#
200-
# Solve Entropic Regularized OT with Sinkhorn algorithm
196+
# Entropic OT with Sinkhorn algorithm
201197
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
202198
#
203199

@@ -230,8 +226,8 @@
230226
# exact OT problem and the OT plan is not sparse.
231227

232228
# %%
233-
# Solve the Regularized OT problem with other regularizations
234-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
229+
# Quadratic Regularized OT
230+
# ~~~~~~~~~~~~~~~~~~~~~~~~~
235231
#
236232

237233
# Use quadratic regularization
@@ -266,3 +262,259 @@
266262
# quadratic regularization is another common choice for regularized OT and
267263
# preserves the sparsity of the OT plan.
268264
#
265+
# Solve the Regularized OT problem with user-defined regularization
266+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
267+
#
268+
269+
270+
# Define a custom regularization function
271+
def f(G):
272+
return 0.5 * np.sum(G**2)
273+
274+
275+
def df(G):
276+
return G
277+
278+
279+
P_reg = ot.solve_sample(x1, x2, a, b, reg=1e2, reg_type=(f, df)).plan
280+
281+
282+
# sphinx_gallery_start_ignore
283+
pl.figure(1, (3, 3))
284+
plot2D_samples_mat(x1, x2, P_reg)
285+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
286+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
287+
pl.title("Custom reg plan")
288+
pl.show()
289+
# sphinx_gallery_end_ignore
290+
# %%
291+
#
292+
# .. note::
293+
# The examples above use the new API of POT. The old API is still available
294+
# and and the entropic OT plan and loss can be computed with the
295+
# :func:`ot.sinkhorn` # and :func:`ot.sinkhorn2` functions as below:
296+
#
297+
# .. code-block:: python
298+
#
299+
# Gs = ot.sinkhorn(a, b, C, reg=1e-1)
300+
# loss_sink = ot.sinkhorn2(a, b, C, reg=1e-1)
301+
#
302+
# For quadratic regularization, the :func:`ot.smooth.smooth_ot_dual` function
303+
# can be used to compute the solution of the regularized OT problem. For
304+
# user-defined regularization, the :func:`ot.optim.cg` function can be used
305+
# to solve the regularized OT problem with Conditional Gradient algorithm.
306+
#
307+
# Unbalanced and Partial Optimal Transport
308+
# ----------------------------
309+
#
310+
# Solve the Unbalanced OT problem
311+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
312+
#
313+
# Unbalanced OT relaxes the marginal constraints and allows for the source and
314+
# target total weights to be different. The :func:`ot.solve_sample` function can be
315+
# used to solve the unbalanced OT problem by setting the marginal penalization
316+
# :code:`unbalanced` parameter to a positive value.
317+
#
318+
319+
# Solve the unbalanced OT problem with KL penalization
320+
P_unb_kl = ot.solve_sample(x1, x2, a, b, unbalanced=5e-2).plan
321+
322+
# Unbalanced with KL penalization ad KL regularization
323+
P_unb_kl_reg = ot.solve_sample(
324+
x1, x2, a, b, unbalanced=5e-2, reg=1e-1
325+
).plan # also regularized
326+
327+
# Unbalanced with L2 penalization
328+
P_unb_l2 = ot.solve_sample(x1, x2, a, b, unbalanced=7e1, unbalanced_type="L2").plan
329+
330+
# sphinx_gallery_start_ignore
331+
pl.figure(1, (9, 3))
332+
333+
pl.subplot(1, 3, 1)
334+
plot2D_samples_mat(x1, x2, P_unb_kl)
335+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
336+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
337+
pl.title("Unbalanced KL plan")
338+
339+
pl.subplot(1, 3, 2)
340+
plot2D_samples_mat(x1, x2, P_unb_kl_reg)
341+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
342+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
343+
pl.title("Unbalanced KL + reg plan")
344+
345+
pl.subplot(1, 3, 3)
346+
plot2D_samples_mat(x1, x2, P_unb_l2)
347+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
348+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
349+
pl.title("Unbalanced L2 plan")
350+
351+
pl.show()
352+
# sphinx_gallery_end_ignore
353+
# %%
354+
# .. note::
355+
# Solving the unbalanced OT problem with the old API can be done with the
356+
# :func:`ot.unbalanced.sinkhorn_unbalanced` function as below:
357+
#
358+
# .. code-block:: python
359+
#
360+
# G_unb_kl = ot.unbalanced.sinkhorn_unbalanced(a, b, C, eps=reg, alpha=unbalanced)
361+
#
362+
# Partial Optimal Transport
363+
# ~~~~~~~~~~~~~~~~~~~~~~~~~
364+
#
365+
366+
# Solve the Unbalanced OT problem with TV penalization (equivalent)
367+
P_part_pen = ot.solve_sample(x1, x2, a, b, unbalanced=3, unbalanced_type="TV").plan
368+
369+
# Solve the Partial OT problem with mass constraints (only old API)
370+
P_part_const = ot.partial.partial_wasserstein(a, b, C, m=0.5) # 50% mass transported
371+
372+
# sphinx_gallery_start_ignore
373+
pl.figure(1, (6, 3))
374+
375+
pl.subplot(1, 2, 1)
376+
plot2D_samples_mat(x1, x2, P_part_pen)
377+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
378+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
379+
pl.title("Partial (Unb. TV) plan")
380+
381+
pl.subplot(1, 2, 2)
382+
plot2D_samples_mat(x1, x2, P_part_const)
383+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
384+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
385+
pl.title("Partial 50% mass plan")
386+
pl.show()
387+
388+
# sphinx_gallery_end_ignore
389+
# %%
390+
#
391+
# Gromov-Wasserstein and Fused Gromov-Wasserstein
392+
# -----------------------------------------------
393+
#
394+
# Solve the Gromov-Wasserstein problem
395+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
396+
#
397+
# The Gromov-Wasserstein distance is a similarity measure between metric
398+
# measure spaces. So it does not require the samples to be in the same space.
399+
#
400+
401+
# Define the metric cost matrices in each spaces
402+
403+
C1 = ot.dist(x1, x1, metric="sqeuclidean")
404+
C2 = ot.dist(x2, x2, metric="sqeuclidean")
405+
406+
C1 /= C1.max()
407+
C2 /= C2.max()
408+
409+
# Solve the Gromov-Wasserstein problem
410+
sol_gw = ot.solve_gromov(C1, C2, a=a, b=b)
411+
P_gw = sol_gw.plan
412+
loss_gw = sol_gw.value
413+
loss_gw_linear = sol_gw.value_linear # linear part of loss
414+
loss_gw_quad = sol_gw.value_quad # quadratic part of loss
415+
416+
# Solve the Entropic Gromov-Wasserstein problem
417+
P_egw = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2).plan
418+
419+
# sphinx_gallery_start_ignore
420+
pl.figure(1, (6, 3))
421+
422+
pl.subplot(1, 2, 1)
423+
plot2D_samples_mat(x1, x2, P_gw)
424+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
425+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
426+
pl.title("GW plan")
427+
428+
pl.subplot(1, 2, 2)
429+
plot2D_samples_mat(x1, x2, P_egw)
430+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
431+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
432+
pl.title("Entropic GW plan")
433+
pl.show()
434+
# sphinx_gallery_end_ignore
435+
# %%
436+
# .. note::
437+
# The Gromov-Wasserstein problem can be solved with the old API using the
438+
# :func:`ot.gromov.gromov_wasserstein` function and the Entropic
439+
# Gromov-Wasserstein problem can be solved with the
440+
# :func:`ot.gromov.entropic_gromov_wasserstein` function.
441+
#
442+
# .. code-block:: python
443+
#
444+
# P_gw = ot.gromov.gromov_wasserstein(C1, C2, a, b)
445+
# P_egw = ot.gromov.entropic_gromov_wasserstein(C1, C2, a, b, epsilon=reg)
446+
#
447+
# loss_gw = ot.gromov.gromov_wasserstein2(C1, C2, a, b)
448+
# loss_egw = ot.gromov.entropic_gromov_wasserstein2(C1, C2, a, b, epsilon=reg)
449+
#
450+
# Fused Gromov-Wasserstein
451+
# ~~~~~~~~~~~~~~~~~~~~~~~~~
452+
#
453+
454+
# Cost matrix
455+
M = C / np.max(C)
456+
457+
# Solve FGW problem with alpha=0.1
458+
P_fgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1).plan # C is cost across spaces
459+
460+
# SOlve entropic FGW problem with alpha=0.1
461+
P_efgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1, reg=1e-3).plan
462+
463+
# sphinx_gallery_start_ignore
464+
pl.figure(1, (6, 3))
465+
466+
pl.subplot(1, 2, 1)
467+
plot2D_samples_mat(x1, x2, P_fgw)
468+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
469+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
470+
pl.title("FGW plan")
471+
472+
pl.subplot(1, 2, 2)
473+
plot2D_samples_mat(x1, x2, P_efgw)
474+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
475+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
476+
pl.title("Entropic FGW plan")
477+
pl.show()
478+
479+
# sphinx_gallery_end_ignore
480+
# %%
481+
# .. note::
482+
# The Fused Gromov-Wasserstein problem can be solved with the old API using
483+
# the :func:`ot.gromov.fused_gromov_wasserstein` function and the Entropic
484+
# Fused Gromov-Wasserstein problem can be solved with the
485+
# :func:`ot.gromov.entropic_fused_gromov_wasserstein` function.
486+
#
487+
# .. code-block:: python
488+
#
489+
# P_fgw = ot.gromov.fused_gromov_wasserstein(C1, C2, M, a, b, alpha=0.1)
490+
# P_efgw = ot.gromov.entropic_fused_gromov_wasserstein(C1, C2, M, a, b, alpha=0.1, epsilon=reg)
491+
#
492+
# loss_fgw = ot.gromov.fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1)
493+
# loss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1, epsilon=reg)
494+
#
495+
# Unbalanced Gromov-Wasserstein
496+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
497+
#
498+
499+
# # Solve the Unbalanced Gromov-Wasserstein problem
500+
# P_gw_unb = ot.solve_gromov(C1, C2, a=a, b=b, unbalanced=1e-2).plan
501+
502+
# # Solve the Unbalanced Entropic Gromov-Wasserstein problem
503+
# P_egw_unb = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2, reg_type='KL', unbalanced=1e-2).plan
504+
505+
# # sphinx_gallery_start_ignore
506+
# pl.figure(1, (6, 3))
507+
508+
# pl.subplot(1, 2, 1)
509+
# plot2D_samples_mat(x1, x2, P_gw_unb)
510+
# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
511+
# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
512+
# pl.title("Unbalanced GW plan")
513+
514+
# pl.subplot(1, 2, 2)
515+
# plot2D_samples_mat(x1, x2, P_egw_unb)
516+
# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
517+
# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
518+
# pl.title("Unbalanced Entropic GW plan")
519+
# pl.show()
520+
# # sphinx_gallery_end_ignore

0 commit comments

Comments
 (0)