
[MRG] Bures-Wasserstein Gradient Descent for Bures-Wasserstein Barycenters #680

Merged: 37 commits, Mar 12, 2025

Conversation

@clbonet (Contributor) commented Oct 19, 2024:

Types of changes

This PR aims to add the Bures-Wasserstein gradient descent solver to compute Bures-Wasserstein barycenters (see e.g. Gradient descent algorithms for Bures-Wasserstein barycenters or Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent).

  • Restructured ot.gaussian.bures_wasserstein_barycenter to allow the use of different methods
  • Added the previous fixed-point algorithm in ot.gaussian.bures_barycenter_fixpoint
  • Added the Bures-Wasserstein gradient descent in ot.gaussian.bures_barycenter_gradient_descent
  • Added an iteration over the methods in the test test_bures_wasserstein_barycenter
  • Added a test test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter
  • Added a batched version of ot.gaussian.bures_wasserstein_distance
  • The trace can now be computed for batches of matrices. The backend implementations of the trace were chosen based on the runtimes reported by the following benchmarks (on CPU):
import torch
import jax
import numpy as np
import tensorflow as tf
import jax.numpy as jnp

A = np.random.rand(1000, 100, 100)

%timeit np.einsum("...ii", A) # 109 μs ± 1.71 μs per loop
%timeit np.trace(A, axis1=-2, axis2=-1) # 116 μs ± 1.79 μs
%timeit A.diagonal(axis1=-2, axis2=-1).sum(-1) # 114 μs ± 2.87 μs per loop

A = torch.rand(1000, 100, 100)

%timeit torch.einsum("...ii", A)  # 3.17 ms ± 1.1 ms per loop 
%timeit A.diagonal(dim1=-2, dim2=-1).sum(-1) # 3.1 ms ± 879 μs per loop

A = tf.random.uniform((1000, 100, 100))

@tf.function
def trace_sum(A):
    return tf.einsum("...ii", A)

@tf.function
def trace_sum_v2(A):
    return tf.reduce_sum(tf.linalg.diag_part(A), axis=-1)

# Warm-up execution
trace_sum(A)  
trace_sum_v2(A)

# Benchmarking
%timeit trace_sum(A) # 486 μs ± 21.1 μs per loop
%timeit trace_sum_v2(A) # 430 μs ± 36.1 μs per loop
%timeit tf.linalg.trace(A) # 404 μs ± 18.2 μs per loop

# For jax, the results might look different using jit
A = jnp.ones((1000, 100, 100)) 

%timeit jnp.einsum("...ii", A) # 13.6 ms ± 324 μs per loop
%timeit jax.vmap(jnp.trace)(A) # 12.1 ms ± 457 μs per loop
%timeit A.diagonal(axis1=-2, axis2=-1).sum(-1) # 1.64 ms ± 3.62 ms per loop
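
For context, the batched trace feeds directly into the closed-form squared Bures-Wasserstein distance, W2²(N(ms, Σs), N(mt, Σt)) = ||ms − mt||² + Tr(Σs + Σt − 2(Σs^{1/2} Σt Σs^{1/2})^{1/2}). A minimal numpy sketch of a batched version (helper names are illustrative, not the PR's actual API):

```python
import numpy as np

def sqrtm_batch(C):
    # Batched symmetric square root of SPD matrices via eigendecomposition:
    # C = V diag(w) V^T  =>  C^{1/2} = V diag(sqrt(w)) V^T
    w, V = np.linalg.eigh(C)
    return (V * np.sqrt(np.clip(w, 0.0, None))[..., None, :]) @ np.swapaxes(V, -2, -1)

def bures_wasserstein_dist2(ms, Cs, mt, Ct):
    # Squared Bures-Wasserstein distance, broadcasting over leading batch dims
    S = sqrtm_batch(Cs)
    cross = sqrtm_batch(S @ Ct @ S)
    # batched trace with einsum, the fastest numpy option benchmarked above
    bures = np.einsum("...ii", Cs + Ct - 2.0 * cross)
    return np.sum((ms - mt) ** 2, axis=-1) + bures

# W2^2(N(0, I_2), N(0, 4 I_2)) = tr(I + 4I - 2*2I) = 2, batched over 3 pairs
d2 = bures_wasserstein_dist2(np.zeros((3, 2)), np.tile(np.eye(2), (3, 1, 1)),
                             np.zeros((3, 2)), np.tile(4.0 * np.eye(2), (3, 1, 1)))
```
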

Motivation and context / Related issue

The Bures-Wasserstein gradient descent comes with convergence guarantees for computing Bures-Wasserstein barycenters. Moreover, it can also be used stochastically when there are too many Gaussians. Thus, it is a good alternative to the currently implemented fixed-point algorithm.
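
A minimal numpy sketch of the gradient descent on covariances (means average linearly, m* = Σ_k w_k m_k, so only covariances need iteration). The names and the initialization are illustrative, not the PR's implementation; the update Σ ← A Σ A with A = I + γ(Σ_k w_k T_k − I) follows the cited papers, where T_k is the OT map from N(0, Σ) to N(0, Σ_k):

```python
import numpy as np

def sqrtm_spd(C):
    # Symmetric square root of an SPD matrix via eigendecomposition
    w, V = np.linalg.eigh(C)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def bw_barycenter_gd(covs, weights, step=1.0, n_iter=100):
    # Bures-Wasserstein gradient descent for the barycenter of centered Gaussians
    d = covs[0].shape[0]
    Sigma = np.eye(d)  # illustrative initialization
    for _ in range(n_iter):
        S = sqrtm_spd(Sigma)
        S_inv = np.linalg.inv(S)
        # weighted average of the OT maps T_k sending N(0, Sigma) to N(0, covs[k]):
        # T_k = S^{-1} (S covs[k] S)^{1/2} S^{-1}
        T = sum(w * S_inv @ sqrtm_spd(S @ C @ S) @ S_inv
                for w, C in zip(weights, covs))
        A = np.eye(d) + step * (T - np.eye(d))  # step along minus the BW gradient
        Sigma = A @ Sigma @ A
    return Sigma

# Barycenter of N(0, I) and N(0, 4I): sqrt(Sigma*) = (1 + 2)/2, so Sigma* = 2.25 I
Sigma = bw_barycenter_gd([np.eye(2), 4.0 * np.eye(2)], [0.5, 0.5])
```

With step=1, the update coincides with the classical fixed-point map, which is why the two methods are expected to agree at convergence.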

How has this been tested (if it applies)

I added a test test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter to assert that both methods return the same barycenter. I also added an iteration over the methods (via itertools) in test_bures_wasserstein_barycenter.
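
As a numpy-only sketch of the idea behind that test (the PR's actual test uses `ot.gaussian`; the helpers below are illustrative), one can check numerically that the fixed-point map and the gradient-descent update converge to the same covariance:

```python
import numpy as np

def sqrtm_spd(C):
    # Symmetric square root of an SPD matrix via eigendecomposition
    w, V = np.linalg.eigh(C)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def fixed_point_step(Sigma, covs, weights):
    # Classical fixed-point map: Sigma <- S^{-1} (sum_k w_k (S C_k S)^{1/2})^2 S^{-1}
    S = sqrtm_spd(Sigma)
    S_inv = np.linalg.inv(S)
    M = sum(w * sqrtm_spd(S @ C @ S) for w, C in zip(weights, covs))
    return S_inv @ M @ M @ S_inv

def gd_step(Sigma, covs, weights, step=0.5):
    # One Bures-Wasserstein gradient step with a damped step size
    S = sqrtm_spd(Sigma)
    S_inv = np.linalg.inv(S)
    T = sum(w * S_inv @ sqrtm_spd(S @ C @ S) @ S_inv for w, C in zip(weights, covs))
    A = np.eye(len(Sigma)) + step * (T - np.eye(len(Sigma)))
    return A @ Sigma @ A

rng = np.random.default_rng(0)
covs = [B @ B.T + np.eye(2) for B in rng.standard_normal((3, 2, 2))]
weights = np.full(3, 1.0 / 3.0)

Sig_fp = Sig_gd = np.eye(2)
for _ in range(2000):
    Sig_fp = fixed_point_step(Sig_fp, covs, weights)
    Sig_gd = gd_step(Sig_gd, covs, weights)
```
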

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary (Collaborator) left a comment:

Small comments. I will let @antoinecollas do a proper review; he is the expert in Riemannian optimization.


codecov bot commented Oct 31, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.13%. Comparing base (79eb337) to head (1444648).
Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #680      +/-   ##
==========================================
+ Coverage   97.10%   97.13%   +0.03%     
==========================================
  Files         100      100              
  Lines       20115    20369     +254     
==========================================
+ Hits        19532    19786     +254     
  Misses        583      583              

@rflamary (Collaborator) left a comment:

This is great. A few tests, especially about errors, are missing.

@clbonet clbonet changed the title [WIP] Bures-Wasserstein Gradient Descent for Bures-Wasserstein Barycenters [MRG] Bures-Wasserstein Gradient Descent for Bures-Wasserstein Barycenters Mar 4, 2025
@rflamary (Collaborator) left a comment:

Just a few questions and then we can merge.

@@ -8,6 +8,10 @@
- Automatic PR labeling and release file update check (PR #704)
- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
- Fix documentation in the module `ot.gaussian` (PR #718)
- Refactored `ot.bregman._convolutional` to improve readability (PR #709)
Collaborator:

I don't see that in the PR.

@clbonet (Contributor, Author) replied:

Mmh, I think I made a mistake when merging with master at some point. (It was deleted from line 46 of RELEASES.md, and it seemed to be in the wrong release of POT.)

@@ -1363,7 +1363,8 @@ def solve(self, a, b):
         return np.linalg.solve(a, b)

     def trace(self, a):
-        return np.trace(a)
+        return np.einsum("...ii", a)
Collaborator:

Is that faster or slower? We need an idea.

ot/gaussian.py (outdated):

 Returns
 -------
-W : float
+W : float if ms and mt of shape (d,); array-like (n,) if ms of shape (n,d) and mt of shape (d,); array-like (m,) if ms of shape (d,) and mt of shape (m,d); array-like (n,m) if ms of shape (n,d) and mt of shape (m,d)
Collaborator:

Too complicated an API: return a float in the (d,) case, and for the rest use a parameter that selects paired or cross distances.

@rflamary rflamary merged commit d25770c into PythonOT:master Mar 12, 2025
18 checks passed