Skip to content

Commit 21bf86b

Browse files
committed
examples: ot bar with projections onto circles + gmm ot bar
1 parent 0b6217b commit 21bf86b

File tree

3 files changed

+149
-7
lines changed

3 files changed

+149
-7
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,4 @@ Artificial Intelligence.
392392

393393
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
394394

395-
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
396-
Barycentres of Measures for Generic Transport
397-
Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
395+
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)

examples/barycenters/plot_barycenter_generic_cost.py renamed to examples/barycenters/plot_free_support_barycenter_generic_cost.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
OT Barycenter with Generic Costs Demo
55
=====================================
66
7-
This example illustrates the computation of an Optimal Transport for a ground
8-
cost that is not a power of a norm. We take the example of ground costs
7+
This example illustrates the computation of an Optimal Transport Barycenter for
8+
a ground cost that is not a power of a norm. We take the example of ground costs
99
:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear)
1010
projection onto a circle k. This is an example of the fixed-point barycenter
1111
solver introduced in [74] which generalises [20] and [43].
@@ -15,8 +15,8 @@
1515
:math:`x` with Pytorch.
1616
1717
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
18-
Barycentres of Measures for Generic Transport Costs.
19-
arXiv preprint 2501.04016 (2024)
18+
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
19+
(2024)
2020
2121
[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein
2222
Barycenters. InternationalConference in Machine Learning
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=====================================
4+
Gaussian Mixture Model OT Barycenters
5+
=====================================
6+
7+
This example illustrates the computation of a barycenter between Gaussian
8+
Mixtures in the sense of GMM-OT [69]. This computation is done using the
9+
fixed-point method for OT barycenters with generic costs [74], for which POT
10+
provides a general solver, and a specific GMM solver. Note that this is a
11+
'free-support' method, implying that the number of components of the barycenter
12+
GMM and their weights are fixed.
13+
14+
The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over
15+
the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the
16+
Bures-Wasserstein manifold), and to compute barycenters with respect to the
17+
2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a
18+
gaussian mixture is a finite combination of Diracs on specific gaussians, and
19+
two mixtures are compared with the 2-Wasserstein distance on this space with
20+
ground cost the squared Bures distance between gaussians.
21+
22+
[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space
23+
of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
24+
25+
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
26+
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
27+
(2024)
28+
29+
"""
30+
31+
# Author: Eloi Tanguy <[email protected]>
32+
#
33+
# License: MIT License
34+
35+
# sphinx_gallery_thumbnail_number = 1
36+
37+
# %%
38+
# Generate data
39+
import numpy as np
40+
import matplotlib.pyplot as plt
41+
from matplotlib.patches import Ellipse
42+
import ot
43+
from ot.gmm import gmm_barycenter_fixed_point
44+
45+
46+
K = 3 # number of GMMs
47+
d = 2 # dimension
48+
n = 6 # number of components of the desired barycenter
49+
50+
51+
def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2):
52+
rng = np.random.RandomState(seed=seed)
53+
means = rng.randn(K, d)
54+
P = rng.randn(K, d, d) * cov_scale
55+
# C[k] = P[k] @ P[k]^T + min_cov_eig * I
56+
covariances = np.einsum("kab,kcb->kac", P, P)
57+
covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)])
58+
weights = rng.random(K)
59+
weights /= np.sum(weights)
60+
return means, covariances, weights
61+
62+
63+
m_list = [5, 6, 7] # number of components in each GMM
64+
offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])]
65+
means_list = [] # list of means for each GMM
66+
covs_list = [] # list of covariances for each GMM
67+
w_list = [] # list of weights for each GMM
68+
69+
# generate GMMs
70+
for k in range(K):
71+
means, covs, b = get_random_gmm(
72+
m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5
73+
)
74+
means = means / 2 + offsets[k][None, :]
75+
means_list.append(means)
76+
covs_list.append(covs)
77+
w_list.append(b)
78+
79+
# %%
80+
# Compute the barycenter using the fixed-point method
81+
init_means, init_covs, _ = get_random_gmm(n, d, seed=0)
82+
weights = ot.unif(K) # barycenter coefficients
83+
means_bar, covs_bar, log = gmm_barycenter_fixed_point(
84+
means_list,
85+
covs_list,
86+
w_list,
87+
init_means,
88+
init_covs,
89+
weights,
90+
iterations=3,
91+
log=True,
92+
)
93+
94+
95+
# %%
96+
# Define plotting functions
97+
98+
99+
# draw a covariance ellipse
100+
def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None):
101+
def eigsorted(cov):
102+
vals, vecs = np.linalg.eigh(cov)
103+
order = vals.argsort()[::-1].copy()
104+
return vals[order], vecs[:, order]
105+
106+
vals, vecs = eigsorted(C)
107+
theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
108+
w, h = 2 * nstd * np.sqrt(vals)
109+
ell = Ellipse(
110+
xy=(mu[0], mu[1]),
111+
width=w,
112+
height=h,
113+
alpha=alpha,
114+
angle=theta,
115+
facecolor=color,
116+
edgecolor=color,
117+
label=label,
118+
fill=True,
119+
)
120+
if ax is None:
121+
ax = plt.gca()
122+
ax.add_artist(ell)
123+
124+
125+
# draw a gmm as a set of ellipses with weights shown in alpha value
126+
def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):
127+
for k in range(ms.shape[0]):
128+
draw_cov(
129+
ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax
130+
)
131+
132+
133+
# %%
134+
# Plot the results
135+
fig, ax = plt.subplots(figsize=(6, 6))
136+
axis = [-4, 4, -2, 6]
137+
ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16)
138+
for k in range(K):
139+
draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax)
140+
draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax)
141+
ax.axis(axis)
142+
ax.axis("off")
143+
144+
# %%

0 commit comments

Comments
 (0)