Skip to content

Commit 081e4eb

Browse files
committed
added fixed-point barycenter function to ot.lp._barycenter_solvers_
1 parent d69bf97 commit 081e4eb

File tree

4 files changed

+94
-1
lines changed

4 files changed

+94
-1
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ The contributors to this library are:
4444
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW,
4545
semi-relaxed FGW, quantized FGW, partial FGW)
4646
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein
47-
Barycenters, GMMOT)
47+
Barycenters, GMMOT, Barycenters for General Transport Costs)
4848
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
4949
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
5050
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,7 @@ Artificial Intelligence.
391391
[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).
392392

393393
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
394+
395+
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
396+
Barycentres of Measures for Generic Transport
397+
Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
88
- Automatic PR labeling and release file update check (PR #704)
99
- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
10+
- Implement fixed-point solver for OT barycenters with generic cost functions
11+
(generalizes `ot.lp.free_support_barycenter`). (PR #715)
1012

1113
#### Closed issues
1214
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

ot/lp/_barycenter_solvers.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,90 @@ def generalized_free_support_barycenter(
422422
return Y, log_dict
423423
else:
424424
return Y
425+
426+
427+
class StoppingCriterionReached(Exception):
428+
pass
429+
430+
431+
def solve_OT_barycenter_fixed_point(
432+
X_init,
433+
Y_list,
434+
b_list,
435+
cost_list,
436+
B,
437+
max_its=5,
438+
stop_threshold=1e-5,
439+
log=False,
440+
):
441+
"""
442+
Solves the OT barycenter problem using the fixed point algorithm, iterating
443+
the function B on plans between the current barycentre and the measures.
444+
445+
Parameters
446+
----------
447+
X_init : array-like
448+
Array of shape (n, d) representing initial barycentre points.
449+
Y_list : list of array-like
450+
List of K arrays of measure positions, each of shape (m_k, d_k).
451+
b_list : list of array-like
452+
List of K arrays of measure weights, each of shape (m_k).
453+
cost_list : list of callable
454+
List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k).
455+
B : callable
456+
Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre.
457+
max_its : int, optional
458+
Maximum number of iterations (default is 5).
459+
stop_threshold : float, optional
460+
If the iterations move less than this, terminate (default is 1e-5).
461+
log : bool, optional
462+
Whether to return the log dictionary (default is False).
463+
464+
Returns
465+
-------
466+
X : array-like
467+
Array of shape (n, d) representing barycentre points.
468+
log_dict : list of array-like, optional
469+
log containing the exit status, list of iterations and list of
470+
displacements if log is True.
471+
"""
472+
nx = get_backend(X_init, Y_list[0])
473+
K = len(Y_list)
474+
n = X_init.shape[0]
475+
a = nx.ones(n) / n
476+
X_list = [X_init] if log else [] # store the iterations
477+
X = X_init
478+
dX_list = [] # store the displacement squared norms
479+
exit_status = "Unknown"
480+
481+
try:
482+
for _ in range(max_its):
483+
pi_list = [ # compute the pairwise transport plans
484+
emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K)
485+
]
486+
Y_perm = []
487+
for k in range(K): # compute barycentric projections
488+
Y_perm.append(n * pi_list[k] @ Y_list[k])
489+
X_next = B(Y_perm)
490+
491+
if log:
492+
X_list.append(X_next)
493+
494+
# stationary criterion: move less than the threshold
495+
dX = nx.sum((X - X_next) ** 2)
496+
X = X_next
497+
498+
if log:
499+
dX_list.append(dX)
500+
501+
if dX < stop_threshold:
502+
exit_status = "Stationary Point"
503+
raise StoppingCriterionReached
504+
505+
exit_status = "Max iterations reached"
506+
raise StoppingCriterionReached
507+
508+
except StoppingCriterionReached:
509+
if log:
510+
return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list}
511+
return X

0 commit comments

Comments
 (0)