Skip to content

Commit ccf608a

Browse files
committed
doc fixes + ot bar coverage
1 parent 3e8421e commit ccf608a

File tree

3 files changed

+161
-41
lines changed

3 files changed

+161
-41
lines changed

examples/barycenters/plot_barycenter_generic_cost.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@
1010
projection onto a circle k. This is an example of the fixed-point barycenter
1111
solver introduced in [74] which generalises [20] and [43].
1212
13-
The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in
14-
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over
13+
The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
14+
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
1515
:math:`x` with Pytorch.
1616
17-
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
18-
Barycentres of Measures for Generic Transport
19-
Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
17+
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
18+
Barycentres of Measures for Generic Transport Costs.
19+
arXiv preprint 2501.04016 (2024)
2020
21-
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein
22-
Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
23-
Conference in Machine Learning
21+
[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein
22+
Barycenters. InternationalConference in Machine Learning
2423
25-
[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
24+
[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in
25+
Wasserstein space. Journal of Mathematical Analysis and Applications 441.2
26+
(2016): 744-762.
2627
2728
"""
2829

@@ -32,7 +33,8 @@
3233

3334
# sphinx_gallery_thumbnail_number = 1
3435

35-
# %% Generate data
36+
# %%
37+
# Generate data
3638
import torch
3739
from torch.optim import Adam
3840
from ot.utils import dist
@@ -43,7 +45,7 @@
4345

4446
torch.manual_seed(42)
4547

46-
n = 100 # number of points of the of the barycentre
48+
n = 200 # number of points of the of the barycentre
4749
d = 2 # dimensions of the original measure
4850
K = 4 # number of measures to barycentre
4951
m = 50 # number of points of the measures
@@ -82,7 +84,8 @@ def proj_circle(X, origin, radius):
8284
Y_list.append(P_list[k](X_temp))
8385

8486

85-
# %% Define costs and ground barycenter function
87+
# %%
88+
# Define costs and ground barycenter function
8689
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
8790
# (n, n_k) matrix of costs
8891
def c1(x, y):
@@ -140,25 +143,30 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
140143
return x
141144

142145

143-
# %% Compute the barycenter measure
144-
fixed_point_its = 10
146+
# %%
147+
# Compute the barycenter measure
148+
fixed_point_its = 3
145149
X_init = torch.rand(n, d)
146150
X_bar = free_support_barycenter_generic_costs(
147-
X_init,
148151
Y_list,
149152
b_list,
153+
X_init,
150154
cost_list,
151155
B,
152156
numItermax=fixed_point_its,
153157
stopThr=stop_threshold,
154158
)
155159

156-
# %% Plot Barycenter (Iteration 10)
157-
alpha = 0.5
160+
# %%
161+
# Plot Barycenter (Iteration 3)
162+
alpha = 0.4
163+
s = 80
158164
labels = ["circle 1", "circle 2", "circle 3", "circle 4"]
159165
for Y, label in zip(Y_list, labels):
160-
plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label)
161-
plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha)
166+
plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s)
167+
plt.scatter(
168+
*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s
169+
)
162170
plt.axis("equal")
163171
plt.xlim(-0.3, 1.3)
164172
plt.ylim(-0.3, 1.3)

ot/lp/_barycenter_solvers.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -429,19 +429,20 @@ class StoppingCriterionReached(Exception):
429429

430430

431431
def free_support_barycenter_generic_costs(
432-
X_init,
433432
measure_locations,
434433
measure_weights,
434+
X_init,
435435
cost_list,
436436
B,
437+
a=None,
437438
numItermax=5,
438439
stopThr=1e-5,
439440
log=False,
440441
):
441442
r"""
442443
Solves the OT barycenter problem for generic costs using the fixed point
443444
algorithm, iterating the ground barycenter function B on transport plans
444-
between the current barycentre and the measures.
445+
between the current barycenter and the measures.
445446
446447
The problem finds an optimal barycenter support `X` of given size (n, d)
447448
(enforced by the initialisation), minimising a sum of pairwise transport
@@ -452,12 +453,13 @@ def free_support_barycenter_generic_costs(
452453
453454
where:
454455
455-
- :math:`X` (n, d) is the barycentre support,
456-
- :math:`a` (n) is the (fixed) barycentre weights,
457-
- :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`),
456+
- :math:`X` (n, d) is the barycenter support,
457+
- :math:`a` (n) is the (fixed) barycenter weights,
458+
- :math:`Y_k` (m_k, d_k) is the k-th measure support
459+
(`measure_locations[k]`),
458460
- :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`),
459461
- :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix)
460-
- :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`:
462+
- :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`:
461463
462464
.. math::
463465
\mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F
@@ -471,9 +473,10 @@ def free_support_barycenter_generic_costs(
471473
in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k,
472474
c_k(X, Y_k))`.
473475
474-
The algorithm requires a given ground barycentre function `B` which computes
475-
a solution of the following minimisation problem given :math:`(y_1, \cdots,
476-
y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`:
476+
The algorithm requires a given ground barycenter function `B` which computes
477+
(broadcasted of `n`) solutions of the following minimisation problem given
478+
:math:`(Y_1, \cdots, Y_K) \in
479+
\mathbb{R}^{n\times d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`:
477480
478481
.. math::
479482
B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k),
@@ -482,23 +485,32 @@ def free_support_barycenter_generic_costs(
482485
:math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times
483486
\cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to
484487
this function, and for certain costs it can be computed explicitly of
485-
through a numerical solver.
488+
through a numerical solver. The input function B takes a list of K arrays of
489+
shape (n, d_k) and returns an array of shape (n, d).
486490
487491
This function implements [74] Algorithm 2, which generalises [20] and [43]
488-
to general costs and includes convergence guarantees, including for discrete measures.
492+
to general costs and includes convergence guarantees, including for discrete
493+
measures.
489494
490495
Parameters
491496
----------
492-
X_init : array-like
493-
Array of shape (n, d) representing initial barycentre points.
494497
measure_locations : list of array-like
495498
List of K arrays of measure positions, each of shape (m_k, d_k).
496499
measure_weights : list of array-like
497500
List of K arrays of measure weights, each of shape (m_k).
501+
X_init : array-like
502+
Array of shape (n, d) representing initial barycenter points.
498503
cost_list : list of callable
499-
List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`.
504+
List of K cost functions :math:`c_k: \mathbb{R}^{n\times
505+
d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times
506+
m_k}`.
500507
B : callable
501-
Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays of shape (n\times d_K), computing the ground barycentre.
508+
Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays
509+
of shape (n\times d_K), computing the ground barycenters (broadcasted
510+
over n).
511+
a : array-like, optional
512+
Array of shape (n,) representing weights of the barycenter
513+
measure.Defaults to uniform.
502514
numItermax : int, optional
503515
Maximum number of iterations (default is 5).
504516
stopThr : float, optional
@@ -509,7 +521,7 @@ def free_support_barycenter_generic_costs(
509521
Returns
510522
-------
511523
X : array-like
512-
Array of shape (n, d) representing barycentre points.
524+
Array of shape (n, d) representing barycenter points.
513525
log_dict : list of array-like, optional
514526
log containing the exit status, list of iterations and list of
515527
displacements if log is True.
@@ -518,22 +530,27 @@ def free_support_barycenter_generic_costs(
518530
519531
References
520532
----------
521-
.. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
533+
.. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
534+
barycenters of Measures for Generic Transport Costs. arXiv preprint
535+
2501.04016 (2024)
522536
523-
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
537+
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein
538+
barycenters." International Conference on Machine Learning. 2014.
524539
525-
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
540+
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to
541+
barycenters in Wasserstein space." Journal of Mathematical Analysis and
542+
Applications 441.2 (2016): 744-762.
526543
527544
See Also
528545
--------
529-
ot.lp.free_support_barycenter : Free support solver for the case where
530-
:math:`c_k(x,y) = \|x-y\|_2^2`.
546+
ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|x-y\|_2^2`.
531547
ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
532548
"""
533549
nx = get_backend(X_init, measure_locations[0])
534550
K = len(measure_locations)
535551
n = X_init.shape[0]
536-
a = nx.ones(n) / n
552+
if a is None:
553+
a = nx.ones(n, type_as=X_init) / n
537554
X_list = [X_init] if log else [] # store the iterations
538555
X = X_init
539556
dX_list = [] # store the displacement squared norms

test/test_ot.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from ot.datasets import make_1D_gauss as gauss
1414
from ot.backend import torch, tf
1515

16+
# import ot.lp._barycenter_solvers # TODO: remove this import
17+
1618

1719
def test_emd_dimension_and_mass_mismatch():
1820
# test emd and emd2 for dimension mismatch
@@ -395,6 +397,99 @@ def test_generalised_free_support_barycenter_backends(nx):
395397
np.testing.assert_allclose(Y, nx.to_numpy(Y2))
396398

397399

400+
def test_free_support_barycenter_generic_costs():
401+
measures_locations = [
402+
np.array([-1.0]).reshape((1, 1)),
403+
np.array([1.0]).reshape((1, 1)),
404+
]
405+
measures_weights = [np.array([1.0]), np.array([1.0])]
406+
407+
X_init = np.array([-12.0]).reshape((1, 1))
408+
409+
# obvious barycenter location between two Diracs
410+
bar_locations = np.array([0.0]).reshape((1, 1))
411+
412+
def cost(x, y):
413+
return ot.dist(x, y)
414+
415+
cost_list = [cost, cost]
416+
417+
def B(y):
418+
out = 0
419+
for yk in y:
420+
out += yk / len(y)
421+
return out
422+
423+
X = ot.lp.free_support_barycenter_generic_costs(
424+
measures_locations, measures_weights, X_init, cost_list, B
425+
)
426+
427+
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
428+
429+
# test with log and specific weights
430+
X2, log = ot.lp.free_support_barycenter_generic_costs(
431+
measures_locations,
432+
measures_weights,
433+
X_init,
434+
cost_list,
435+
B,
436+
a=ot.unif(1),
437+
log=True,
438+
)
439+
440+
assert "X_list" in log
441+
assert "exit_status" in log
442+
assert "dX_list" in log
443+
444+
np.testing.assert_allclose(X, X2, rtol=1e-5, atol=1e-7)
445+
446+
# test with one iteration for Max Iterations Reached
447+
X3, log2 = ot.lp.free_support_barycenter_generic_costs(
448+
measures_locations,
449+
measures_weights,
450+
X_init,
451+
cost_list,
452+
B,
453+
numItermax=1,
454+
log=True,
455+
)
456+
assert log2["exit_status"] == "Max iterations reached"
457+
458+
459+
def test_free_support_barycenter_generic_costs_backends(nx):
460+
measures_locations = [
461+
np.array([-1.0]).reshape((1, 1)),
462+
np.array([1.0]).reshape((1, 1)),
463+
]
464+
measures_weights = [np.array([1.0]), np.array([1.0])]
465+
X_init = np.array([-12.0]).reshape((1, 1))
466+
467+
def cost(x, y):
468+
return ot.dist(x, y)
469+
470+
cost_list = [cost, cost]
471+
472+
def B(y):
473+
out = 0
474+
for yk in y:
475+
out += yk / len(y)
476+
return out
477+
478+
X = ot.lp.free_support_barycenter_generic_costs(
479+
measures_locations, measures_weights, X_init, cost_list, B
480+
)
481+
482+
measures_locations2 = nx.from_numpy(*measures_locations)
483+
measures_weights2 = nx.from_numpy(*measures_weights)
484+
X_init2 = nx.from_numpy(X_init)
485+
486+
X2 = ot.lp.free_support_barycenter_generic_costs(
487+
measures_locations2, measures_weights2, X_init2, cost_list, B
488+
)
489+
490+
np.testing.assert_allclose(X, nx.to_numpy(X2))
491+
492+
398493
@pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available")
399494
def test_lp_barycenter_cvxopt():
400495
a1 = np.array([1.0, 0, 0])[:, None]

0 commit comments

Comments
 (0)