examples/barycenters/plot_barycenter_generic_cost.py (6 additions, 4 deletions)
@@ -6,9 +6,9 @@
6
6
7
7
This example illustrates the computation of an Optimal Transport for a ground
8
8
cost that is not a power of a norm. We take the example of ground costs
9
-
:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear)
9
+
:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear)
10
10
projection onto a circle k. This is an example of the fixed-point barycenter
11
-
solver introduced in [74] which generalises [20].
11
+
solver introduced in [74] which generalises [20] and [43].
12
12
13
13
The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
14
14
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
@@ -22,6 +22,8 @@
22
22
Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
23
23
Conference on Machine Learning
24
24
25
+
[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
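To make the ground cost :math:`c_k(x, y) = \|P_k(x)-y\|_2^2` from the example docstring concrete, here is a minimal NumPy sketch of a pairwise cost matrix for a projection onto a circle. The helper names `project_to_circle` and `circle_cost`, and the `center` and `radius` parameters, are hypothetical and chosen for illustration; they are not taken from the example file.

```python
import numpy as np


def project_to_circle(x, center, radius):
    """Project each row of x (n, 2) onto the circle with the given center and radius."""
    diff = x - center
    norms = np.linalg.norm(diff, axis=1, keepdims=True)
    # Assumption: no point coincides with the center, where the projection is undefined.
    return center + radius * diff / norms


def circle_cost(x, y, center, radius):
    """Pairwise cost matrix C[i, j] = ||P(x_i) - y_j||_2^2, of shape (n, m)."""
    px = project_to_circle(x, center, radius)
    return ((px[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


x = np.array([[2.0, 0.0], [0.0, 3.0]])
y = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
C = circle_cost(x, y, center=np.zeros(2), radius=1.0)
```

A function with this shape convention, (n, d) and (m_k, d_k) in, (n, m_k) out, is what an entry of `cost_list` is documented to provide.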
ot/lp/_barycenter_solvers.py (81 additions, 19 deletions)
@@ -430,33 +430,78 @@ class StoppingCriterionReached(Exception):
430
430
431
431
def free_support_barycenter_generic_costs(
432
432
X_init,
433
-
Y_list,
434
-
b_list,
433
+
measure_locations,
434
+
measure_weights,
435
435
cost_list,
436
436
B,
437
-
max_its=5,
438
-
stop_threshold=1e-5,
437
+
numItermax=5,
438
+
stopThr=1e-5,
439
439
log=False,
440
440
):
441
-
"""
442
-
Solves the OT barycenter problem using the fixed point algorithm, iterating
443
-
the function B on plans between the current barycentre and the measures.
441
+
r"""
442
+
Solves the OT barycenter problem for generic costs using the fixed point
443
+
algorithm, iterating the ground barycenter function B on transport plans
444
+
between the current barycentre and the measures.
445
+
446
+
The problem finds an optimal barycenter support `X` of given size (n, d)
447
+
(enforced by the initialisation), minimising a sum of pairwise transport
448
+
costs for the costs :math:`c_k`:
449
+
450
+
.. math::
451
+
\min_{X} \sum_{k=1}^K \mathcal{T}_{c_k}(X, a, Y_k, b_k),
452
+
453
+
where:
454
+
455
+
- :math:`X` (n, d) is the barycentre support,
456
+
- :math:`a` (n) is the (fixed) barycentre weights,
457
+
- :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`),
458
+
- :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`),
459
+
- :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix)
460
+
- :math:`\mathcal{T}_{c_k}(X, a, Y_k, b_k)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`:
where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points
482
+
:math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times
483
+
\cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to
484
+
this function, and for certain costs it can be computed explicitly or
485
+
through a numerical solver.
486
+
487
+
This function implements [74] Algorithm 2, which generalises [20] and [43]
488
+
to general costs and comes with convergence guarantees, including for discrete measures.
444
489
445
490
Parameters
446
491
----------
447
492
X_init : array-like
448
493
Array of shape (n, d) representing initial barycentre points.
449
-
Y_list : list of array-like
494
+
measure_locations : list of array-like
450
495
List of K arrays of measure positions, each of shape (m_k, d_k).
451
-
b_list : list of array-like
496
+
measure_weights : list of array-like
452
497
List of K arrays of measure weights, each of shape (m_k).
453
498
cost_list : list of callable
454
-
List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k).
499
+
List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`.
455
500
B : callable
456
-
Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre.
457
-
max_its : int, optional
501
+
Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays, the k-th of shape (n, d_k), and computing the ground barycentre.
502
+
numItermax : int, optional
458
503
Maximum number of iterations (default is 5).
459
-
stop_threshold : float, optional
504
+
stopThr : float, optional
460
505
If the iterations move less than this, terminate (default is 1e-5).
461
506
log : bool, optional
462
507
Whether to return the log dictionary (default is False).
.. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
522
+
523
+
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
524
+
525
+
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
526
+
527
+
See Also
528
+
--------
529
+
ot.lp.free_support_barycenter : Free support solver for the case where
530
+
:math:`c_k(x,y) = \|x-y\|_2^2`.
531
+
ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
471
532
"""
472
-
nx = get_backend(X_init, Y_list[0])
473
-
K = len(Y_list)
533
+
nx = get_backend(X_init, measure_locations[0])
534
+
K = len(measure_locations)
474
535
n = X_init.shape[0]
475
536
a = nx.ones(n) / n
476
537
X_list = [X_init] if log else []  # store the iterations
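The fixed-point iteration described in the docstring (solve one OT problem per measure, then apply the ground barycenter function B to the matched points) can be sketched in a simplified setting. This is an illustrative sketch, not the code in this PR. It assumes uniform weights and measures with as many points as the barycentre (m_k = n), so that an optimal plan can be taken to be a permutation, and it replaces the exact OT solver with a brute-force search over permutations that is only viable for very small n. The function names and the squared-displacement stopping rule are hypothetical.

```python
import itertools

import numpy as np


def sq_euclidean_cost(x, y):
    """Pairwise squared Euclidean cost matrix of shape (n, m)."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def optimal_assignment(C):
    """Brute-force optimal permutation for a square cost matrix (tiny n only)."""
    n = C.shape[0]
    rows = np.arange(n)
    best = min(
        itertools.permutations(range(n)),
        key=lambda p: C[rows, list(p)].sum(),
    )
    return np.array(best)


def fixed_point_barycenter_sketch(
    X_init, Y_list, cost_list, B, num_iter_max=5, stop_thr=1e-5
):
    """Iterate X <- B(matched points), assuming uniform weights and m_k = n."""
    X = X_init
    for _ in range(num_iter_max):
        matched = []
        for Y, cost in zip(Y_list, cost_list):
            # With uniform weights and equal sizes, the barycentric projection
            # of a permutation plan is simply the permuted support.
            sigma = optimal_assignment(cost(X, Y))
            matched.append(Y[sigma])
        X_next = B(matched)
        # Assumed stopping rule: total squared displacement of the support.
        moved = np.sum((X_next - X) ** 2)
        X = X_next
        if moved < stop_thr:
            break
    return X


Y1 = np.array([[0.0, 0.0], [1.0, 0.0]])
Y2 = np.array([[0.0, 2.0], [1.0, 2.0]])
X0 = np.array([[0.9, 1.0], [0.1, 1.0]])
# For squared Euclidean costs with equal weights, B is the arithmetic mean.
X_bar = fixed_point_barycenter_sketch(
    X0,
    [Y1, Y2],
    [sq_euclidean_cost, sq_euclidean_cost],
    B=lambda ys: np.mean(ys, axis=0),
)
```

In this toy case the iteration stabilises after one update, with the barycentre points at the midpoints of the matched pairs. For non-uniform weights or m_k different from n, the permutation step would need to be replaced by a general OT plan and its barycentric projection.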