Commit 3e8421e

ot bar doc
1 parent 6a3eab5 commit 3e8421e


2 files changed, 87 additions and 23 deletions


examples/barycenters/plot_barycenter_generic_cost.py

Lines changed: 6 additions & 4 deletions
@@ -6,9 +6,9 @@
66
77
This example illustrates the computation of an Optimal Transport for a ground
88
cost that is not a power of a norm. We take the example of ground costs
9-
:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear)
9+
:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear)
1010
projection onto a circle k. This is an example of the fixed-point barycenter
11-
solver introduced in [74] which generalises [20].
11+
solver introduced in [74] which generalises [20] and [43].
1212
1313
The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
1414
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
@@ -22,6 +22,8 @@
2222
Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
2323
Conference in Machine Learning
2424
25+
[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
26+
2527
"""
2628

2729
# Author: Eloi Tanguy <[email protected]>
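
The docstring above describes the ground barycentre function B as the minimiser of a weighted sum of costs, obtained by gradient descent. The sketch below illustrates that idea in plain Python. It is not the example's `B`: the squared Euclidean cost stands in for the circle-projection cost so that the result can be checked against the closed-form minimiser (the weighted mean), and the name `ground_barycenter_gd`, the step size `lr=0.1` and the starting point are assumptions made for illustration.

```python
def ground_barycenter_gd(ys, weights, its=150, lr=0.1, stop_threshold=1e-10):
    """Minimise sum_k weights[k] * ||x - ys[k]||^2 over x by gradient descent.

    Hypothetical helper. The squared Euclidean cost is an assumption chosen
    so the answer can be compared with the weighted mean of the points.
    """
    dim = len(ys[0])
    x = list(ys[0])  # start from the first point (arbitrary choice)
    for _ in range(its):
        # Gradient of sum_k w_k ||x - y_k||^2 is 2 * sum_k w_k (x - y_k).
        grad = [
            2.0 * sum(w * (x[i] - y[i]) for w, y in zip(weights, ys))
            for i in range(dim)
        ]
        x_new = [x[i] - lr * grad[i] for i in range(dim)]
        # Stop once the iterate moves less than the threshold.
        step = sum((x_new[i] - x[i]) ** 2 for i in range(dim)) ** 0.5
        x = x_new
        if step < stop_threshold:
            break
    return x


points = [[0.0, 0.0], [2.0, 0.0], [1.0, 3.0]]
lambdas = [0.5, 0.25, 0.25]
x_star = ground_barycenter_gd(points, lambdas)
# x_star is close to the weighted mean (0.75, 0.75)
```

For a cost without a closed-form minimiser, such as the circle-projection cost of this example, only the gradient line would change; the loop and the stopping rule keep the same shape.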
@@ -147,8 +149,8 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
147149
b_list,
148150
cost_list,
149151
B,
150-
max_its=fixed_point_its,
151-
stop_threshold=stop_threshold,
152+
numItermax=fixed_point_its,
153+
stopThr=stop_threshold,
152154
)
153155

154156
# %% Plot Barycenter (Iteration 10)
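
The hunk above updates the example to the renamed keywords (`max_its` to `numItermax`, `stop_threshold` to `stopThr`); the solver's first two arguments are renamed as well (`Y_list` to `measure_locations`, `b_list` to `measure_weights`). Callers that still pass the old keyword names would fail after this commit. One possible way to bridge the gap on the caller's side is sketched below; `translate_legacy_kwargs` is a hypothetical helper written for illustration and is not part of the library.

```python
import warnings

# Old keyword -> new keyword, as renamed in this commit.
_RENAMED = {
    "Y_list": "measure_locations",
    "b_list": "measure_weights",
    "max_its": "numItermax",
    "stop_threshold": "stopThr",
}


def translate_legacy_kwargs(kwargs):
    """Return a copy of kwargs with pre-commit keyword names replaced.

    Hypothetical helper, not a library function.
    """
    out = {}
    for name, value in kwargs.items():
        new_name = _RENAMED.get(name, name)
        if new_name != name:
            if new_name in kwargs:
                # Both spellings were given: refuse to guess which one wins.
                raise TypeError(
                    f"{name!r} and {new_name!r} refer to the same argument"
                )
            warnings.warn(
                f"{name!r} is now {new_name!r}", DeprecationWarning, stacklevel=2
            )
        out[new_name] = value
    return out


legacy_call = {"max_its": 10, "stop_threshold": 1e-6, "log": True}
new_call = translate_legacy_kwargs(legacy_call)
# new_call == {"numItermax": 10, "stopThr": 1e-06, "log": True}
```

The translated dictionary can then be unpacked into the call with `**new_call`. Positional callers, such as the example script itself, are unaffected by the rename of the first two arguments.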

ot/lp/_barycenter_solvers.py

Lines changed: 81 additions & 19 deletions
@@ -430,33 +430,78 @@ class StoppingCriterionReached(Exception):
430430

431431
def free_support_barycenter_generic_costs(
432432
X_init,
433-
Y_list,
434-
b_list,
433+
measure_locations,
434+
measure_weights,
435435
cost_list,
436436
B,
437-
max_its=5,
438-
stop_threshold=1e-5,
437+
numItermax=5,
438+
stopThr=1e-5,
439439
log=False,
440440
):
441-
"""
442-
Solves the OT barycenter problem using the fixed point algorithm, iterating
443-
the function B on plans between the current barycentre and the measures.
441+
r"""
442+
Solves the OT barycenter problem for generic costs using the fixed point
443+
algorithm, iterating the ground barycenter function B on transport plans
444+
between the current barycentre and the measures.
445+
446+
The problem is to find an optimal barycentre support `X` of given size (n, d)
447+
(enforced by the initialisation), minimising a sum of pairwise transport
448+
costs for the costs :math:`c_k`:
449+
450+
.. math::
451+
\min_{X} \sum_{k=1}^K \mathcal{T}_{c_k}(X, a, Y_k, b_k),
452+
453+
where:
454+
455+
- :math:`X` (n, d) is the barycentre support,
456+
- :math:`a` (n) is the (fixed) barycentre weights,
457+
- :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`),
458+
- :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`),
459+
- :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix),
460+
- :math:`\mathcal{T}_{c_k}(X, a, Y_k, b_k)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`:
461+
462+
.. math::
463+
\mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F
464+
465+
s.t. \ \pi \mathbf{1} = \mathbf{a}
466+
467+
\pi^T \mathbf{1} = \mathbf{b_k}
468+
469+
\pi \geq 0
470+
471+
in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b_k)` is `ot.emd2(a, b_k,
472+
c_k(X, Y_k))`.
473+
474+
The algorithm requires a given ground barycentre function `B` which computes
475+
a solution of the following minimisation problem given :math:`(y_1, \cdots,
476+
y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`:
477+
478+
.. math::
479+
B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k),
480+
481+
where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points
482+
:math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times
483+
\cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to
484+
this function, and for certain costs it can be computed explicitly or
485+
through a numerical solver.
486+
487+
This function implements [74] Algorithm 2, which generalises [20] and [43]
488+
to general costs and comes with convergence guarantees, including for discrete measures.
444489
445490
Parameters
446491
----------
447492
X_init : array-like
448493
Array of shape (n, d) representing initial barycentre points.
449-
Y_list : list of array-like
494+
measure_locations : list of array-like
450495
List of K arrays of measure positions, each of shape (m_k, d_k).
451-
b_list : list of array-like
496+
measure_weights : list of array-like
452497
List of K arrays of measure weights, each of shape (m_k).
453498
cost_list : list of callable
454-
List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k).
499+
List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`.
455500
B : callable
456-
Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre.
457-
max_its : int, optional
501+
Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays, the k-th of shape (n, d_k), and computing the ground barycentre.
502+
numItermax : int, optional
458503
Maximum number of iterations (default is 5).
459-
stop_threshold : float, optional
504+
stopThr : float, optional
460505
If the iterations move less than this, terminate (default is 1e-5).
461506
log : bool, optional
462507
Whether to return the log dictionary (default is False).
@@ -468,9 +513,25 @@ def free_support_barycenter_generic_costs(
468513
log_dict : list of array-like, optional
469514
log containing the exit status, list of iterations and list of
470515
displacements if log is True.
516+
517+
.. _references-free-support-barycenter-generic-costs:
518+
519+
References
520+
----------
521+
.. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
522+
523+
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
524+
525+
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
526+
527+
See Also
528+
--------
529+
ot.lp.free_support_barycenter : Free support solver for the case where
530+
:math:`c_k(x,y) = \|x-y\|_2^2`.
531+
ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
471532
"""
472-
nx = get_backend(X_init, Y_list[0])
473-
K = len(Y_list)
533+
nx = get_backend(X_init, measure_locations[0])
534+
K = len(measure_locations)
474535
n = X_init.shape[0]
475536
a = nx.ones(n) / n
476537
X_list = [X_init] if log else [] # store the iterations
@@ -479,13 +540,14 @@ def free_support_barycenter_generic_costs(
479540
exit_status = "Unknown"
480541

481542
try:
482-
for _ in range(max_its):
543+
for _ in range(numItermax):
483544
pi_list = [ # compute the pairwise transport plans
484-
emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K)
545+
emd(a, measure_weights[k], cost_list[k](X, measure_locations[k]))
546+
for k in range(K)
485547
]
486548
Y_perm = []
487549
for k in range(K): # compute barycentric projections
488-
Y_perm.append(n * pi_list[k] @ Y_list[k])
550+
Y_perm.append(n * pi_list[k] @ measure_locations[k])
489551
X_next = B(Y_perm)
490552

491553
if log:
@@ -498,7 +560,7 @@ def free_support_barycenter_generic_costs(
498560
if log:
499561
dX_list.append(dX)
500562

501-
if dX < stop_threshold:
563+
if dX < stopThr:
502564
exit_status = "Stationary Point"
503565
raise StoppingCriterionReached
504566
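
The loop in this diff alternates three steps: compute one transport plan per measure, take the barycentric projection `n * pi_list[k] @ measure_locations[k]`, and apply `B` to the projected points. The sketch below mirrors those steps in plain Python under strong simplifying assumptions: every measure has exactly `n` points with uniform weights, so an optimal plan can be taken to be a permutation, found here by brute force instead of `emd`. The barycentric projection then just reorders the rows of each measure. The function names, the pointwise signature of `B`, and the definition of the displacement `dX` as a summed squared distance are all assumptions, not the library's implementation.

```python
from itertools import permutations


def sq_dist(x, y):
    """Squared Euclidean distance between two points."""
    return sum((xi - yi) ** 2 for xi, yi in zip(x, y))


def optimal_assignment(X, Y, cost):
    """Brute-force the permutation sigma minimising sum_i cost(X[i], Y[sigma[i]]).

    Stand-in for an OT solver; only valid for uniform weights, equal sizes
    and very small n.
    """
    n = len(X)
    return min(
        permutations(range(n)),
        key=lambda sigma: sum(cost(X[i], Y[sigma[i]]) for i in range(n)),
    )


def mean_barycenter(points):
    """Ground barycentre for equal-weight squared Euclidean costs: the mean."""
    K = len(points)
    return [sum(p[i] for p in points) / K for i in range(len(points[0]))]


def fixed_point_barycenter(X_init, measure_locations, cost_list, B,
                           numItermax=5, stopThr=1e-5):
    X = [list(x) for x in X_init]
    n = len(X)
    for _ in range(numItermax):
        # Step 1: one optimal plan per measure (a permutation here).
        sigmas = [
            optimal_assignment(X, Y, c)
            for Y, c in zip(measure_locations, cost_list)
        ]
        # Steps 2 and 3: reorder each measure by its plan, then apply B
        # to the K points matched with the i-th barycentre point.
        X_next = [
            B([Y[sigma[i]] for Y, sigma in zip(measure_locations, sigmas)])
            for i in range(n)
        ]
        dX = sum(sq_dist(x, xn) for x, xn in zip(X, X_next))  # assumed metric
        X = X_next
        if dX < stopThr:
            break
    return X


Y1 = [[0.0, 0.0], [4.0, 0.0]]
Y2 = [[0.0, 2.0], [4.0, 2.0]]
X0 = [[1.0, 0.5], [3.0, 1.5]]
bary = fixed_point_barycenter(X0, [Y1, Y2], [sq_dist, sq_dist], mean_barycenter)
# bary == [[0.0, 1.0], [4.0, 1.0]]
```

On this toy input the iterate stops moving after the first update, so the second pass triggers the `stopThr` exit, which corresponds to the "Stationary Point" status in the diff.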
