Commit d25770c

Authored by clbonet, cedricvincentcuaz, and rflamary
[MRG] Bures-Wasserstein Gradient Descent for Bures-Wasserstein Barycenters (#680)
* bw barycenter with batched sqrtm
* BWGD for barycenters
* sbwgd for barycenters
* Test fixed_point vs gradient_descent
* fix test bwgd
* nx exp_bures
* update doc
* fix merge
* doc exp bw
* First tests stochastic + exp
* exp_bures with einsum
* type Id test
* up test stochastic
* test weights
* Add BW distance with batchs
* step size SGD BW Barycenter
* batchable BW distance
* RELEASES.md
* precommit
* Add ot.gaussian.bures
* Add arg backend
* up stop criteria sgd Gaussian barycenter
* Fix release
* fix doc
* change API bw
* up test bures_wasserstein_distance
* up test bures_wasserstein_distance
* up test bures_wasserstein_distance

---------

Co-authored-by: Cédric Vincent-Cuaz <[email protected]>
Co-authored-by: Rémi Flamary <[email protected]>
1 parent 79eb337 commit d25770c

8 files changed: +710 −80 lines

README.md (+4)
@@ -390,3 +390,7 @@ Artificial Intelligence.
 [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).

 [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
+
+[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR.
+
+[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145.
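The two new references describe the algorithm this PR implements: gradient descent directly on the Bures-Wasserstein manifold. A minimal NumPy sketch of the update from [74] follows; the function name `bw_barycenter_gd`, the step size `gamma`, and the eigendecomposition-based square root are illustrative choices, not the API added to `ot.gaussian` by this PR:

```python
import numpy as np

def _sqrtm(A):
    # symmetric PSD matrix square root via eigendecomposition
    w, V = np.linalg.eigh(A)
    return (V * np.sqrt(np.clip(w, 0, None))) @ V.T

def bw_barycenter_gd(Cs, weights=None, gamma=1.0, n_iter=100, tol=1e-10):
    # Bures-Wasserstein gradient descent for the barycenter of the
    # centered Gaussians N(0, Cs[i]); gamma=1 recovers the classical
    # fixed-point iteration as a special case.
    k, d, _ = Cs.shape
    w = np.full(k, 1.0 / k) if weights is None else weights
    C = Cs.mean(axis=0)  # warm start at the Euclidean mean
    Id = np.eye(d)
    for _ in range(n_iter):
        S = _sqrtm(C)
        S_inv = np.linalg.inv(S)
        # average of the Monge maps pushing the current iterate onto each Cs[i]
        T = sum(w[i] * S_inv @ _sqrtm(S @ Cs[i] @ S) @ S_inv for i in range(k))
        M = Id + gamma * (T - Id)  # gradient step along the manifold
        C_next = M @ C @ M
        done = np.linalg.norm(C_next - C) < tol
        C = C_next
        if done:
            break
    return C
```

For commuting (e.g. diagonal) covariances the barycenter has the closed form `(sum_i w_i sqrt(Cs[i]))**2`, which gives a quick sanity check of the iteration.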

RELEASES.md (+4 −1)
@@ -9,6 +9,10 @@
 - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
 - Implement projected gradient descent solvers for entropic partial FGW (PR #702)
 - Fix documentation in the module `ot.gaussian` (PR #718)
+- Refactored `ot.bregman._convolutional` to improve readability (PR #709)
+- Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680)
+- Added `ot.gaussian.bures_wasserstein_distance` (PR #680)
+- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)

 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

@@ -44,7 +48,6 @@ This release also contains few bug fixes, concerning the support of any metric i
 - Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663)
 - Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663)
 - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
-- Refactored `ot.bregman._convolutional` to improve readability (PR #709)

 #### Closed issues
 - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
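The new batchable distance rests on the closed form for the 2-Wasserstein distance between Gaussians. A plain-NumPy sketch of that formula, vectorized over a leading batch axis; `bures_wasserstein_dist` and `_sqrtm_batch` are illustrative stand-ins, not the signature of `ot.gaussian.bures_wasserstein_distance`:

```python
import numpy as np

def _sqrtm_batch(A):
    # batched symmetric PSD square root via eigendecomposition
    w, V = np.linalg.eigh(A)
    return (V * np.sqrt(np.clip(w, 0, None))[..., None, :]) @ np.swapaxes(V, -2, -1)

def bures_wasserstein_dist(ms, mt, Cs, Ct):
    # W2 distance between N(ms, Cs) and N(mt, Ct):
    #   ||ms - mt||^2 + Tr(Cs + Ct - 2 (Cs^{1/2} Ct Cs^{1/2})^{1/2})
    Cs12 = _sqrtm_batch(Cs)
    cross = _sqrtm_batch(Cs12 @ Ct @ Cs12)
    # batched traces via einsum("...ii"), as in the backend change of this PR
    bures2 = np.einsum("...ii", Cs + Ct - 2 * cross)
    d2 = np.sum((ms - mt) ** 2, axis=-1) + bures2
    return np.sqrt(np.clip(d2, 0, None))  # clip guards tiny negative round-off
```

With identity covariances the Bures term vanishes and the distance reduces to the Euclidean distance between the means, which makes the batched output easy to sanity-check.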

ot/backend.py (+4 −4)
@@ -1363,7 +1363,7 @@ def solve(self, a, b):
         return np.linalg.solve(a, b)

     def trace(self, a):
-        return np.trace(a)
+        return np.einsum("...ii", a)

     def inv(self, a):
         return scipy.linalg.inv(a)

@@ -1776,7 +1776,7 @@ def solve(self, a, b):
         return jnp.linalg.solve(a, b)

     def trace(self, a):
-        return jnp.trace(a)
+        return jnp.diagonal(a, axis1=-2, axis2=-1).sum(-1)

     def inv(self, a):
         return jnp.linalg.inv(a)

@@ -2309,7 +2309,7 @@ def solve(self, a, b):
         return torch.linalg.solve(a, b)

     def trace(self, a):
-        return torch.trace(a)
+        return torch.diagonal(a, dim1=-2, dim2=-1).sum(-1)

     def inv(self, a):
         return torch.linalg.inv(a)

@@ -2723,7 +2723,7 @@ def solve(self, a, b):
         return cp.linalg.solve(a, b)

     def trace(self, a):
-        return cp.trace(a)
+        return cp.trace(a, axis1=-2, axis2=-1)

     def inv(self, a):
         return cp.linalg.inv(a)
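The four backend edits share one rationale: `np.trace` defaults to `axis1=0, axis2=1` (and `torch.trace` only accepts 2-D input), so on a stacked `(batch, d, d)` array the old code would trace the wrong plane or fail outright. A quick NumPy check of the replacement:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3, 3))  # a batch of five 3x3 matrices

batched = np.einsum("...ii", A)                # trace of each trailing 3x3 block
looped = np.array([np.trace(Ai) for Ai in A])  # reference, one matrix at a time

print(np.allclose(batched, looped))  # True
print(np.trace(A).shape)             # (3,): default axes trace the wrong plane
```

The jax, torch, and cupy variants in the diff compute the same thing through `diagonal(..., -2, -1).sum(-1)` or explicit trailing `axis1`/`axis2` arguments.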
