[WIP] OT barycenters for generic transport costs #715
Open: eloitanguy wants to merge 28 commits into PythonOT:master from eloitanguy:dev
590e4d7 ot.lp reorganise to avoid def in __init__ (eloitanguy)
109edb7 pr number + enabled pre-commit (eloitanguy)
0957904 added barycenter.py imports (eloitanguy)
818b3e7 fixed wrong import in ot.gmm (eloitanguy)
08c2285 ruff fix attempt (eloitanguy)
f268515 removed ot bar contribs -> only o.lp reorganisation in this PR (eloitanguy)
8f24cb9 add check_number_threads to ot/lp/__init__.py __all__ (eloitanguy)
3e3b444 update releases (eloitanguy)
566a0fc made barycenter_solvers and network_simplex hidden + deprecated ot.lp… (eloitanguy)
5c35d58 fix ref to lp.cvx in test (eloitanguy)
8ffb061 lp.cvx now imports barycenter and gives a warnings.warning (eloitanguy)
26748eb cvx import barycenter (eloitanguy)
d69bf97 Merge branch 'PythonOT:master' into dev (eloitanguy)
081e4eb added fixed-point barycenter function to ot.lp._barycenter_solvers_ (eloitanguy)
5952019 ot bar demo (eloitanguy)
6a3eab5 Merge branch 'master' into dev (rflamary)
3e8421e ot bar doc (eloitanguy)
ccf608a doc fixes + ot bar coverage (eloitanguy)
37b9c80 python 3.13 in test workflow + added ggmot barycenter (WIP) (eloitanguy)
a20d3f0 fixed github action file (eloitanguy)
0b6217b ot bar doc + test coverage (eloitanguy)
21bf86b examples: ot bar with projections onto circles + gmm ot bar (eloitanguy)
0820e51 releases + readme + docs update (eloitanguy)
d1510ee Merge branch 'master' into dev (eloitanguy)
391ad39 Merge branch 'master' into dev (eloitanguy)
6bd4af8 ref fix (eloitanguy)
51722bf implementation comments (eloitanguy)
371e3e7 Merge branch 'master' into dev (eloitanguy)
177 changes: 177 additions & 0 deletions
examples/barycenters/plot_free_support_barycenter_generic_cost.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
===================================== | ||
OT Barycenter with Generic Costs Demo | ||
===================================== | ||
|
||
This example illustrates the computation of an Optimal Transport Barycenter for | ||
a ground cost that is not a power of a norm. We take the example of ground costs | ||
:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) | ||
projection onto a circle k. This is an example of the fixed-point barycenter | ||
solver introduced in [76] which generalises [20] and [43]. | ||
|
||
The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in | ||
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over | ||
:math:`x` with PyTorch. | ||
|
||
[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing | ||
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 | ||
(2024) | ||
|
||
[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein | ||
Barycenters. International Conference on Machine Learning. | ||
|
||
[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in | ||
Wasserstein space. Journal of Mathematical Analysis and Applications 441.2 | ||
(2016): 744-762. | ||
|
||
""" | ||
|
||
# Author: Eloi Tanguy <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
# sphinx_gallery_thumbnail_number = 1 | ||
|
||
# %% | ||
# Generate data | ||
import torch | ||
from torch.optim import Adam | ||
from ot.utils import dist | ||
import numpy as np | ||
from ot.lp import free_support_barycenter_generic_costs | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
torch.manual_seed(42) | ||
|
||
n = 200 # number of points of the barycentre | ||
d = 2 # dimensions of the original measure | ||
K = 4 # number of measures to barycentre | ||
m = 50 # number of points of the measures | ||
b_list = [torch.ones(m) / m] * K # weights of the 4 measures | ||
weights = torch.ones(K) / K # weights for the barycentre | ||
stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo | ||
|
||
|
||
# map R^2 -> R^2 projection onto circle | ||
def proj_circle(X, origin, radius): | ||
    diffs = X - origin[None, :] | ||
    norms = torch.norm(diffs, dim=1) | ||
    return origin[None, :] + radius * diffs / norms[:, None] | ||
|
||
|
||
# circles on which to project | ||
origin1 = torch.tensor([-1.0, -1.0]) | ||
origin2 = torch.tensor([-1.0, 2.0]) | ||
origin3 = torch.tensor([2.0, 2.0]) | ||
origin4 = torch.tensor([2.0, -1.0]) | ||
r = np.sqrt(2) | ||
P_list = [ | ||
    lambda X: proj_circle(X, origin1, r), | ||
    lambda X: proj_circle(X, origin2, r), | ||
    lambda X: proj_circle(X, origin3, r), | ||
    lambda X: proj_circle(X, origin4, r), | ||
] | ||
|
||
# measures to barycentre are projections of different random circles | ||
# onto the K circles | ||
Y_list = [] | ||
for k in range(K): | ||
    t = torch.rand(m) * 2 * np.pi | ||
    X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], dim=1) | ||
    X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :] | ||
    Y_list.append(P_list[k](X_temp)) | ||
|
||
|
||
# %% | ||
# Define costs and ground barycenter function | ||
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a | ||
# (n, n_k) matrix of costs | ||
def c1(x, y): | ||
    return dist(P_list[0](x), y) | ||
|
||
|
||
def c2(x, y): | ||
    return dist(P_list[1](x), y) | ||
|
||
|
||
def c3(x, y): | ||
    return dist(P_list[2](x), y) | ||
|
||
|
||
def c4(x, y): | ||
    return dist(P_list[3](x), y) | ||
|
||
|
||
cost_list = [c1, c2, c3, c4] | ||
|
||
|
||
# batched total ground cost function for candidate points x (n, d) | ||
# for computation of the ground barycenter B with gradient descent | ||
def C(x, y): | ||
    """ | ||
    Computes the barycenter cost for candidate points x (n, d) and | ||
    measure supports y: List(n, d_k). | ||
    """ | ||
    n = x.shape[0] | ||
    K = len(y) | ||
    out = torch.zeros(n) | ||
    for k in range(K): | ||
        out += (1 / K) * torch.sum((P_list[k](x) - y[k]) ** 2, axis=1) | ||
    return out | ||
|
||
|
||
# ground barycenter function | ||
def B(y, its=150, lr=1, stop_threshold=stop_threshold): | ||
    """ | ||
    Computes the ground barycenter for measure supports y: List(n, d_k). | ||
    Output: (n, d) array | ||
    """ | ||
    x = torch.randn(n, d) | ||
    x.requires_grad_(True) | ||
    opt = Adam([x], lr=lr) | ||
    for _ in range(its): | ||
        x_prev = x.data.clone() | ||
        opt.zero_grad() | ||
        loss = torch.sum(C(x, y)) | ||
        loss.backward() | ||
        opt.step() | ||
        diff = torch.sum((x.data - x_prev) ** 2) | ||
        if diff < stop_threshold: | ||
            break | ||
    return x | ||
|
||
|
||
# %% | ||
# Compute the barycenter measure | ||
fixed_point_its = 3 | ||
X_init = torch.rand(n, d) | ||
X_bar = free_support_barycenter_generic_costs( | ||
    Y_list, | ||
    b_list, | ||
    X_init, | ||
    cost_list, | ||
    B, | ||
    numItermax=fixed_point_its, | ||
    stopThr=stop_threshold, | ||
) | ||
|
||
# %% | ||
# Plot Barycenter (Iteration 3) | ||
alpha = 0.4 | ||
s = 80 | ||
labels = ["circle 1", "circle 2", "circle 3", "circle 4"] | ||
for Y, label in zip(Y_list, labels): | ||
    plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) | ||
plt.scatter( | ||
    *(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s | ||
) | ||
plt.axis("equal") | ||
plt.xlim(-0.3, 1.3) | ||
plt.ylim(-0.3, 1.3) | ||
plt.axis("off") | ||
plt.legend() | ||
plt.tight_layout() | ||
|
||
# %% |
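The ground cost used in this example, c_k(x, y) = ||P_k(x) - y||_2^2, can be checked independently of POT and PyTorch. The following is a minimal NumPy sketch, not the code from the PR: `circle_cost` is a hypothetical helper name, and the random inputs are illustrative only.

```python
import numpy as np


def proj_circle(X, origin, radius):
    """Project each row of X (n, 2) onto the circle with the given origin and radius."""
    diffs = X - origin[None, :]
    norms = np.linalg.norm(diffs, axis=1, keepdims=True)
    return origin[None, :] + radius * diffs / norms


def circle_cost(X, Y, origin, radius):
    """Hypothetical helper: cost matrix M[i, j] = ||P(x_i) - y_j||_2^2, shape (n, m)."""
    PX = proj_circle(X, origin, radius)
    return np.sum((PX[:, None, :] - Y[None, :, :]) ** 2, axis=-1)


rng = np.random.default_rng(0)
origin = np.array([-1.0, -1.0])
radius = np.sqrt(2)

X = rng.normal(size=(5, 2))  # candidate barycenter points
Y = proj_circle(rng.normal(size=(3, 2)), origin, radius)  # points already on the circle

M = circle_cost(X, Y, origin, radius)  # (5, 3) matrix of ground costs
```

The projection is undefined at the circle's centre (division by a zero norm), which this sketch does not guard against. Because the cost only sees x through P_k(x), it is constant along each ray from the centre, so it is not a power of a norm.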
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
===================================== | ||
Gaussian Mixture Model OT Barycenters | ||
===================================== | ||
|
||
This example illustrates the computation of a barycenter between Gaussian | ||
Mixtures in the sense of GMM-OT [69]. This computation is done using the | ||
fixed-point method for OT barycenters with generic costs [76], for which POT | ||
provides a general solver, and a specific GMM solver. Note that this is a | ||
'free-support' method, implying that the number of components of the barycenter | ||
GMM and their weights are fixed. | ||
|
||
The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over | ||
the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the | ||
Bures-Wasserstein manifold), and to compute barycenters with respect to the | ||
2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a | ||
Gaussian mixture is a finite combination of Diracs on specific Gaussians, and | ||
two mixtures are compared with the 2-Wasserstein distance on this space, where | ||
the ground cost is the squared Bures distance between Gaussians. | ||
|
||
[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space | ||
of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. | ||
|
||
[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing | ||
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 | ||
(2024) | ||
|
||
""" | ||
|
||
# Author: Eloi Tanguy <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
# sphinx_gallery_thumbnail_number = 1 | ||
|
||
# %% | ||
# Generate data | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from matplotlib.patches import Ellipse | ||
import ot | ||
from ot.gmm import gmm_barycenter_fixed_point | ||
|
||
|
||
K = 3 # number of GMMs | ||
d = 2 # dimension | ||
n = 6 # number of components of the desired barycenter | ||
|
||
|
||
def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2): | ||
    rng = np.random.RandomState(seed=seed) | ||
    means = rng.randn(K, d) | ||
    P = rng.randn(K, d, d) * cov_scale | ||
    # C[k] = P[k] @ P[k]^T + min_cov_eig * I | ||
    covariances = np.einsum("kab,kcb->kac", P, P) | ||
    covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)]) | ||
    weights = rng.random(K) | ||
    weights /= np.sum(weights) | ||
    return means, covariances, weights | ||
|
||
|
||
m_list = [5, 6, 7] # number of components in each GMM | ||
offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])] | ||
means_list = [] # list of means for each GMM | ||
covs_list = [] # list of covariances for each GMM | ||
w_list = [] # list of weights for each GMM | ||
|
||
# generate GMMs | ||
for k in range(K): | ||
    means, covs, b = get_random_gmm( | ||
        m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5 | ||
    ) | ||
    means = means / 2 + offsets[k][None, :] | ||
    means_list.append(means) | ||
    covs_list.append(covs) | ||
    w_list.append(b) | ||
|
||
# %% | ||
# Compute the barycenter using the fixed-point method | ||
init_means, init_covs, _ = get_random_gmm(n, d, seed=0) | ||
weights = ot.unif(K) # barycenter coefficients | ||
means_bar, covs_bar, log = gmm_barycenter_fixed_point( | ||
    means_list, | ||
    covs_list, | ||
    w_list, | ||
    init_means, | ||
    init_covs, | ||
    weights, | ||
    iterations=3, | ||
    log=True, | ||
) | ||
|
||
|
||
# %% | ||
# Define plotting functions | ||
|
||
|
||
# draw a covariance ellipse | ||
def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None): | ||
    def eigsorted(cov): | ||
        vals, vecs = np.linalg.eigh(cov) | ||
        order = vals.argsort()[::-1].copy() | ||
        return vals[order], vecs[:, order] | ||
|
||
    vals, vecs = eigsorted(C) | ||
    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1])) | ||
    w, h = 2 * nstd * np.sqrt(vals) | ||
    ell = Ellipse( | ||
        xy=(mu[0], mu[1]), | ||
        width=w, | ||
        height=h, | ||
        alpha=alpha, | ||
        angle=theta, | ||
        facecolor=color, | ||
        edgecolor=color, | ||
        label=label, | ||
        fill=True, | ||
    ) | ||
    if ax is None: | ||
        ax = plt.gca() | ||
    ax.add_artist(ell) | ||
|
||
|
||
# draw a gmm as a set of ellipses with weights shown in alpha value | ||
def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): | ||
    for k in range(ms.shape[0]): | ||
        draw_cov( | ||
            ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax | ||
        ) | ||
|
||
|
||
# %% | ||
# Plot the results | ||
fig, ax = plt.subplots(figsize=(6, 6)) | ||
axis = [-4, 4, -2, 6] | ||
ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16) | ||
for k in range(K): | ||
    draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax) | ||
Review comment: Could we have different colors for the different source distributions (and maybe black/gray for the barycenter)? |
||
draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax) | ||
ax.axis(axis) | ||
ax.axis("off") |
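The docstring of this example describes the ground cost between two Gaussian components as a squared Bures-type distance. As a hedged illustration, the sketch below computes the squared 2-Wasserstein distance between two Gaussians in plain NumPy, i.e. the squared distance between means plus the squared Bures distance between covariances. The helper names (`sqrtm_psd`, `bures_wasserstein_sq`, `gmm_ground_cost`) are hypothetical and are not the functions used by the PR.

```python
import numpy as np


def sqrtm_psd(S):
    """Symmetric square root of a symmetric positive semi-definite matrix."""
    S = (S + S.T) / 2.0  # remove tiny asymmetries from floating-point products
    vals, vecs = np.linalg.eigh(S)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T


def bures_wasserstein_sq(m1, S1, m2, S2):
    """Squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2)."""
    S1_half = sqrtm_psd(S1)
    cross = sqrtm_psd(S1_half @ S2 @ S1_half)
    bures_sq = np.trace(S1 + S2 - 2.0 * cross)
    return float(np.sum((m1 - m2) ** 2) + bures_sq)


def gmm_ground_cost(means_a, covs_a, means_b, covs_b):
    """Cost matrix between the components of two GMMs, shape (n_a, n_b)."""
    return np.array(
        [
            [bures_wasserstein_sq(ma, Ca, mb, Cb) for mb, Cb in zip(means_b, covs_b)]
            for ma, Ca in zip(means_a, covs_a)
        ]
    )


# two small GMMs with diagonal covariances, for illustration only
means_a = np.array([[0.0, 0.0], [1.0, 2.0]])
covs_a = np.array([np.diag([1.0, 4.0]), np.diag([9.0, 16.0])])
cost_aa = gmm_ground_cost(means_a, covs_a, means_a, covs_a)
```

For commuting covariances the Bures term reduces to the squared difference of the square roots of the eigenvalues, which gives a simple hand check for the diagonal case above.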
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
|
||
""" | ||
|
||
-# Author: Eloi Tanguy <eloi.tanguy@u-paris> | ||
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> | ||
# Remi Flamary <[email protected]> | ||
# Julie Delon <[email protected]> | ||
# | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
|
||
""" | ||
|
||
-# Author: Eloi Tanguy <eloi.tanguy@u-paris> | ||
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> | ||
# Remi Flamary <[email protected]> | ||
# Julie Delon <[email protected]> | ||
# | ||
|
Could you please add a function for that (torch only) in POT that can be used if B is not passed (default=None)? This will simplify the call for torch costs and keep the possibility of computing a closed form.
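One way to read this request is a dispatch of the following shape: use the user-supplied B when it is given (for example a closed form), and otherwise fall back to a generic numerical minimiser of the summed ground cost. The sketch below is only an illustration of that pattern under stated assumptions. It uses NumPy with central finite differences as a stand-in for the PyTorch autograd loop the reviewer asks for, and the names `default_ground_barycenter` and `ground_barycenter` are hypothetical, not POT functions.

```python
import numpy as np


def default_ground_barycenter(cost_fn, y_list, x0, lr=0.1, its=500, eps=1e-6, tol=1e-12):
    """Hypothetical fallback: gradient descent on cost_fn(x, y_list), row by row.

    cost_fn(x, y_list) must return one cost per row of x, shape (n,), so the
    gradient of each row can be estimated by perturbing one coordinate at a time.
    """
    x = np.array(x0, dtype=float)
    d = x.shape[1]
    for _ in range(its):
        grad = np.zeros_like(x)
        for j in range(d):
            step = np.zeros(d)
            step[j] = eps
            # central finite difference (stand-in for autograd)
            grad[:, j] = (cost_fn(x + step, y_list) - cost_fn(x - step, y_list)) / (2 * eps)
        x_new = x - lr * grad
        converged = np.sum((x_new - x) ** 2) < tol
        x = x_new
        if converged:
            break
    return x


def ground_barycenter(y_list, x0, cost_fn, B=None):
    """Use the supplied B if any, else the numerical fallback."""
    if B is not None:
        return B(y_list)
    return default_ground_barycenter(cost_fn, y_list, x0)


# illustration with the squared Euclidean cost, whose closed form is the mean
def sq_cost(x, y_list):
    return sum(np.sum((x - y) ** 2, axis=1) for y in y_list) / len(y_list)


rng = np.random.default_rng(1)
ys = [rng.normal(size=(4, 2)) for _ in range(3)]
x_start = np.zeros((4, 2))

numeric = ground_barycenter(ys, x_start, sq_cost)
closed = ground_barycenter(ys, x_start, sq_cost, B=lambda y: np.mean(np.stack(y), axis=0))
```

Keeping B as an optional argument preserves the closed-form path, while the fallback only needs the cost itself.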
Note that it would also be nice for B to accept a warm start X0.
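A warm start could look roughly like the sketch below, in which the ground barycenter routine takes an optional `x0` and only draws a random initial point when none is given. This is an assumption-laden illustration rather than the PR's B: it uses plain gradient descent in NumPy on the squared Euclidean cost (whose gradient is known analytically) instead of Adam in PyTorch, and `B_warm` is a hypothetical name.

```python
import numpy as np


def B_warm(y_list, weights, x0=None, lr=0.1, its=1000, tol=1e-16, rng=None):
    """Gradient descent on sum_k w_k ||x - y_k||^2, optionally warm-started at x0.

    Returns the approximate minimiser (n, d) and the number of iterations used.
    """
    y = np.stack(y_list)  # (K, n, d)
    w = np.asarray(weights, dtype=float)[:, None, None]
    if x0 is None:
        rng = np.random.default_rng(0) if rng is None else rng
        x = rng.normal(size=y.shape[1:])  # cold start
    else:
        x = np.array(x0, dtype=float)  # warm start
    for it in range(1, its + 1):
        grad = 2.0 * np.sum(w * (x[None] - y), axis=0)
        x_new = x - lr * grad
        converged = np.sum((x_new - x) ** 2) < tol
        x = x_new
        if converged:
            break
    return x, it


rng = np.random.default_rng(2)
y_list = [rng.normal(size=(6, 2)) for _ in range(4)]
weights = np.ones(4) / 4

x_cold, its_cold = B_warm(y_list, weights)  # random initial point
x_warm, its_warm = B_warm(y_list, weights, x0=x_cold)  # restart from previous solution
```

In a fixed-point loop, the previous iterate of the barycenter support would be the natural candidate for `x0`, since successive calls to B are expected to have nearby solutions.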