|
31 | 31 |
|
32 | 32 |
|
33 | 33 | # %%
|
34 |
| -# Example data |
| 34 | +# Data generation |
35 | 35 | # --------------
|
36 | 36 | #
|
37 |
| -# Data generation |
38 |
| -# ~~~~~~~~~~~~~~~ |
| 37 | +# We first generate two sets of samples in 2D that 25 and 50 |
| 38 | +# samples respectively located on circles. The weights of the samples are |
| 39 | +# uniform. |
39 | 40 |
|
40 | 41 | # Problem size
|
41 | 42 | n1 = 25
|
|
64 | 65 | # sphinx_gallery_end_ignore
|
65 | 66 |
|
66 | 67 | # %%
|
67 |
| -# We illustrate above the simple example of two 2D distributions with 25 and 50 |
68 |
| -# samples respectively located on circles. The weights of the samples are |
69 |
| -# uniform. |
70 | 68 | #
|
71 | 69 | # Solving exact Optimal Transport
|
72 | 70 | # -------------------------------
|
|
189 | 187 | # P = ot.emd(a, b, C)
|
190 | 188 | # loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b
|
191 | 189 | #
|
192 |
| -# .. minigallery:: ot.emd2 ot.emd ot.solve ot.solve_sample |
193 |
| -# |
194 | 190 |
|
195 | 191 |
|
196 | 192 | # %%
|
197 | 193 | # Sinkhorn and Regularized OT
|
198 | 194 | # ---------------------------
|
199 | 195 | #
|
200 |
| -# Solve Entropic Regularized OT with Sinkhorn algorithm |
| 196 | +# Entropic OT with Sinkhorn algorithm |
201 | 197 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
202 | 198 | #
|
203 | 199 |
|
|
230 | 226 | # exact OT problem and the OT plan is not sparse.
|
231 | 227 |
|
232 | 228 | # %%
|
233 |
| -# Solve the Regularized OT problem with other regularizations |
234 |
| -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 229 | +# Quadratic Regularized OT |
| 230 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~ |
235 | 231 | #
|
236 | 232 |
|
237 | 233 | # Use quadratic regularization
|
|
266 | 262 | # quadratic regularization is another common choice for regularized OT and
|
267 | 263 | # preserves the sparsity of the OT plan.
|
268 | 264 | #
|
| 265 | +# Solve the Regularized OT problem with user-defined regularization |
| 266 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 267 | +# |
| 268 | + |
| 269 | + |
| 270 | +# Define a custom regularization function |
| 271 | +def f(G): |
| 272 | + return 0.5 * np.sum(G**2) |
| 273 | + |
| 274 | + |
| 275 | +def df(G): |
| 276 | + return G |
| 277 | + |
| 278 | + |
| 279 | +P_reg = ot.solve_sample(x1, x2, a, b, reg=1e2, reg_type=(f, df)).plan |
| 280 | + |
| 281 | + |
| 282 | +# sphinx_gallery_start_ignore |
| 283 | +pl.figure(1, (3, 3)) |
| 284 | +plot2D_samples_mat(x1, x2, P_reg) |
| 285 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 286 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 287 | +pl.title("Custom reg plan") |
| 288 | +pl.show() |
| 289 | +# sphinx_gallery_end_ignore |
| 290 | +# %% |
| 291 | +# |
| 292 | +# .. note:: |
| 293 | +# The examples above use the new API of POT. The old API is still available |
| 294 | +# and and the entropic OT plan and loss can be computed with the |
| 295 | +# :func:`ot.sinkhorn` # and :func:`ot.sinkhorn2` functions as below: |
| 296 | +# |
| 297 | +# .. code-block:: python |
| 298 | +# |
| 299 | +# Gs = ot.sinkhorn(a, b, C, reg=1e-1) |
| 300 | +# loss_sink = ot.sinkhorn2(a, b, C, reg=1e-1) |
| 301 | +# |
| 302 | +# For quadratic regularization, the :func:`ot.smooth.smooth_ot_dual` function |
| 303 | +# can be used to compute the solution of the regularized OT problem. For |
| 304 | +# user-defined regularization, the :func:`ot.optim.cg` function can be used |
| 305 | +# to solve the regularized OT problem with Conditional Gradient algorithm. |
| 306 | +# |
| 307 | +# Unbalanced and Partial Optimal Transport |
| 308 | +# ---------------------------- |
| 309 | +# |
| 310 | +# Solve the Unbalanced OT problem |
| 311 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 312 | +# |
| 313 | +# Unbalanced OT relaxes the marginal constraints and allows for the source and |
| 314 | +# target total weights to be different. The :func:`ot.solve_sample` function can be |
| 315 | +# used to solve the unbalanced OT problem by setting the marginal penalization |
| 316 | +# :code:`unbalanced` parameter to a positive value. |
| 317 | +# |
| 318 | + |
| 319 | +# Solve the unbalanced OT problem with KL penalization |
| 320 | +P_unb_kl = ot.solve_sample(x1, x2, a, b, unbalanced=5e-2).plan |
| 321 | + |
| 322 | +# Unbalanced with KL penalization ad KL regularization |
| 323 | +P_unb_kl_reg = ot.solve_sample( |
| 324 | + x1, x2, a, b, unbalanced=5e-2, reg=1e-1 |
| 325 | +).plan # also regularized |
| 326 | + |
| 327 | +# Unbalanced with L2 penalization |
| 328 | +P_unb_l2 = ot.solve_sample(x1, x2, a, b, unbalanced=7e1, unbalanced_type="L2").plan |
| 329 | + |
| 330 | +# sphinx_gallery_start_ignore |
| 331 | +pl.figure(1, (9, 3)) |
| 332 | + |
| 333 | +pl.subplot(1, 3, 1) |
| 334 | +plot2D_samples_mat(x1, x2, P_unb_kl) |
| 335 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 336 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 337 | +pl.title("Unbalanced KL plan") |
| 338 | + |
| 339 | +pl.subplot(1, 3, 2) |
| 340 | +plot2D_samples_mat(x1, x2, P_unb_kl_reg) |
| 341 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 342 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 343 | +pl.title("Unbalanced KL + reg plan") |
| 344 | + |
| 345 | +pl.subplot(1, 3, 3) |
| 346 | +plot2D_samples_mat(x1, x2, P_unb_l2) |
| 347 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 348 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 349 | +pl.title("Unbalanced L2 plan") |
| 350 | + |
| 351 | +pl.show() |
| 352 | +# sphinx_gallery_end_ignore |
| 353 | +# %% |
| 354 | +# .. note:: |
| 355 | +# Solving the unbalanced OT problem with the old API can be done with the |
| 356 | +# :func:`ot.unbalanced.sinkhorn_unbalanced` function as below: |
| 357 | +# |
| 358 | +# .. code-block:: python |
| 359 | +# |
| 360 | +# G_unb_kl = ot.unbalanced.sinkhorn_unbalanced(a, b, C, eps=reg, alpha=unbalanced) |
| 361 | +# |
| 362 | +# Partial Optimal Transport |
| 363 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 364 | +# |
| 365 | + |
| 366 | +# Solve the Unbalanced OT problem with TV penalization (equivalent) |
| 367 | +P_part_pen = ot.solve_sample(x1, x2, a, b, unbalanced=3, unbalanced_type="TV").plan |
| 368 | + |
| 369 | +# Solve the Partial OT problem with mass constraints (only old API) |
| 370 | +P_part_const = ot.partial.partial_wasserstein(a, b, C, m=0.5) # 50% mass transported |
| 371 | + |
| 372 | +# sphinx_gallery_start_ignore |
| 373 | +pl.figure(1, (6, 3)) |
| 374 | + |
| 375 | +pl.subplot(1, 2, 1) |
| 376 | +plot2D_samples_mat(x1, x2, P_part_pen) |
| 377 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 378 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 379 | +pl.title("Partial (Unb. TV) plan") |
| 380 | + |
| 381 | +pl.subplot(1, 2, 2) |
| 382 | +plot2D_samples_mat(x1, x2, P_part_const) |
| 383 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 384 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 385 | +pl.title("Partial 50% mass plan") |
| 386 | +pl.show() |
| 387 | + |
| 388 | +# sphinx_gallery_end_ignore |
| 389 | +# %% |
| 390 | +# |
| 391 | +# Gromov-Wasserstein and Fused Gromov-Wasserstein |
| 392 | +# ----------------------------------------------- |
| 393 | +# |
| 394 | +# Solve the Gromov-Wasserstein problem |
| 395 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 396 | +# |
| 397 | +# The Gromov-Wasserstein distance is a similarity measure between metric |
| 398 | +# measure spaces. So it does not require the samples to be in the same space. |
| 399 | +# |
| 400 | + |
| 401 | +# Define the metric cost matrices in each spaces |
| 402 | + |
| 403 | +C1 = ot.dist(x1, x1, metric="sqeuclidean") |
| 404 | +C2 = ot.dist(x2, x2, metric="sqeuclidean") |
| 405 | + |
| 406 | +C1 /= C1.max() |
| 407 | +C2 /= C2.max() |
| 408 | + |
| 409 | +# Solve the Gromov-Wasserstein problem |
| 410 | +sol_gw = ot.solve_gromov(C1, C2, a=a, b=b) |
| 411 | +P_gw = sol_gw.plan |
| 412 | +loss_gw = sol_gw.value |
| 413 | +loss_gw_linear = sol_gw.value_linear # linear part of loss |
| 414 | +loss_gw_quad = sol_gw.value_quad # quadratic part of loss |
| 415 | + |
| 416 | +# Solve the Entropic Gromov-Wasserstein problem |
| 417 | +P_egw = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2).plan |
| 418 | + |
| 419 | +# sphinx_gallery_start_ignore |
| 420 | +pl.figure(1, (6, 3)) |
| 421 | + |
| 422 | +pl.subplot(1, 2, 1) |
| 423 | +plot2D_samples_mat(x1, x2, P_gw) |
| 424 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 425 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 426 | +pl.title("GW plan") |
| 427 | + |
| 428 | +pl.subplot(1, 2, 2) |
| 429 | +plot2D_samples_mat(x1, x2, P_egw) |
| 430 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 431 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 432 | +pl.title("Entropic GW plan") |
| 433 | +pl.show() |
| 434 | +# sphinx_gallery_end_ignore |
| 435 | +# %% |
| 436 | +# .. note:: |
| 437 | +# The Gromov-Wasserstein problem can be solved with the old API using the |
| 438 | +# :func:`ot.gromov.gromov_wasserstein` function and the Entropic |
| 439 | +# Gromov-Wasserstein problem can be solved with the |
| 440 | +# :func:`ot.gromov.entropic_gromov_wasserstein` function. |
| 441 | +# |
| 442 | +# .. code-block:: python |
| 443 | +# |
| 444 | +# P_gw = ot.gromov.gromov_wasserstein(C1, C2, a, b) |
| 445 | +# P_egw = ot.gromov.entropic_gromov_wasserstein(C1, C2, a, b, epsilon=reg) |
| 446 | +# |
| 447 | +# loss_gw = ot.gromov.gromov_wasserstein2(C1, C2, a, b) |
| 448 | +# loss_egw = ot.gromov.entropic_gromov_wasserstein2(C1, C2, a, b, epsilon=reg) |
| 449 | +# |
| 450 | +# Fused Gromov-Wasserstein |
| 451 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 452 | +# |
| 453 | + |
| 454 | +# Cost matrix |
| 455 | +M = C / np.max(C) |
| 456 | + |
| 457 | +# Solve FGW problem with alpha=0.1 |
| 458 | +P_fgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1).plan # C is cost across spaces |
| 459 | + |
| 460 | +# SOlve entropic FGW problem with alpha=0.1 |
| 461 | +P_efgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1, reg=1e-3).plan |
| 462 | + |
| 463 | +# sphinx_gallery_start_ignore |
| 464 | +pl.figure(1, (6, 3)) |
| 465 | + |
| 466 | +pl.subplot(1, 2, 1) |
| 467 | +plot2D_samples_mat(x1, x2, P_fgw) |
| 468 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 469 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 470 | +pl.title("FGW plan") |
| 471 | + |
| 472 | +pl.subplot(1, 2, 2) |
| 473 | +plot2D_samples_mat(x1, x2, P_efgw) |
| 474 | +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 475 | +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 476 | +pl.title("Entropic FGW plan") |
| 477 | +pl.show() |
| 478 | + |
| 479 | +# sphinx_gallery_end_ignore |
| 480 | +# %% |
| 481 | +# .. note:: |
| 482 | +# The Fused Gromov-Wasserstein problem can be solved with the old API using |
| 483 | +# the :func:`ot.gromov.fused_gromov_wasserstein` function and the Entropic |
| 484 | +# Fused Gromov-Wasserstein problem can be solved with the |
| 485 | +# :func:`ot.gromov.entropic_fused_gromov_wasserstein` function. |
| 486 | +# |
| 487 | +# .. code-block:: python |
| 488 | +# |
| 489 | +# P_fgw = ot.gromov.fused_gromov_wasserstein(C1, C2, M, a, b, alpha=0.1) |
| 490 | +# P_efgw = ot.gromov.entropic_fused_gromov_wasserstein(C1, C2, M, a, b, alpha=0.1, epsilon=reg) |
| 491 | +# |
| 492 | +# loss_fgw = ot.gromov.fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1) |
| 493 | +# loss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1, epsilon=reg) |
| 494 | +# |
| 495 | +# Unbalanced Gromov-Wasserstein |
| 496 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 497 | +# |
| 498 | + |
| 499 | +# # Solve the Unbalanced Gromov-Wasserstein problem |
| 500 | +# P_gw_unb = ot.solve_gromov(C1, C2, a=a, b=b, unbalanced=1e-2).plan |
| 501 | + |
| 502 | +# # Solve the Unbalanced Entropic Gromov-Wasserstein problem |
| 503 | +# P_egw_unb = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2, reg_type='KL', unbalanced=1e-2).plan |
| 504 | + |
| 505 | +# # sphinx_gallery_start_ignore |
| 506 | +# pl.figure(1, (6, 3)) |
| 507 | + |
| 508 | +# pl.subplot(1, 2, 1) |
| 509 | +# plot2D_samples_mat(x1, x2, P_gw_unb) |
| 510 | +# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 511 | +# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 512 | +# pl.title("Unbalanced GW plan") |
| 513 | + |
| 514 | +# pl.subplot(1, 2, 2) |
| 515 | +# plot2D_samples_mat(x1, x2, P_egw_unb) |
| 516 | +# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) |
| 517 | +# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) |
| 518 | +# pl.title("Unbalanced Entropic GW plan") |
| 519 | +# pl.show() |
| 520 | +# # sphinx_gallery_end_ignore |
0 commit comments