-
Notifications
You must be signed in to change notification settings - Fork 514
/
Copy pathplot_gmm_barycenter.py
142 lines (117 loc) · 4.32 KB
/
plot_gmm_barycenter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# -*- coding: utf-8 -*-
r"""
=====================================
Gaussian Mixture Model OT Barycenters
=====================================
This example illustrates the computation of a barycenter between Gaussian
Mixtures in the sense of GMM-OT [69]. This computation is done using the
fixed-point method for OT barycenters with generic costs [76], for which POT
provides a general solver, and a specific GMM solver. Note that this is a
'free-support' method, implying that the number of components of the barycenter
GMM and their weights are fixed.
The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over
the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the
Bures-Wasserstein manifold), and to compute barycenters with respect to the
2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a
Gaussian mixture is a finite combination of Diracs on specific Gaussians, and
two mixtures are compared with the 2-Wasserstein distance on this space, where
the ground cost is the squared Bures distance between Gaussians.
[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space
of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
(2024)
"""
# Author: Eloi Tanguy <[email protected]>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 1
# %%
# Generate data
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import ot
from ot.gmm import gmm_barycenter_fixed_point
# Problem sizes:
#   K -- number of input GMMs to average
#   d -- dimension of the space the Gaussians live in
#   n -- number of components of the desired barycenter
K, d, n = 3, 2, 6
def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2):
    """Draw a random Gaussian mixture model.

    Parameters
    ----------
    K : int
        Number of components of the mixture.
    d : int
        Dimension of the space.
    seed : int, optional
        Seed of the ``np.random.RandomState`` used for every draw, so the same
        arguments always produce the same mixture.
    min_cov_eig : float, optional
        Multiple of the identity added to every covariance. The covariances'
        eigenvalues are therefore at least ``min_cov_eig``.
    cov_scale : float, optional
        Scale of the random factors ``P[k]``. Their products
        ``P[k] @ P[k].T`` form the random part of the covariances.

    Returns
    -------
    means : ndarray, shape (K, d)
        Component means with i.i.d. standard normal entries.
    covariances : ndarray, shape (K, d, d)
        Symmetric matrices ``P[k] @ P[k].T + min_cov_eig * I``.
    weights : ndarray, shape (K,)
        Component weights, drawn uniformly in [0, 1) and then normalised
        to sum to 1.
    """
    rng = np.random.RandomState(seed=seed)
    means = rng.randn(K, d)
    P = rng.randn(K, d, d) * cov_scale
    # C[k] = P[k] @ P[k]^T + min_cov_eig * I
    covariances = np.einsum("kab,kcb->kac", P, P)
    # The (d, d) identity broadcasts over the K components, so there is no
    # need to stack K copies of it (stacking also broke for K == 0).
    covariances += min_cov_eig * np.eye(d)
    weights = rng.random(K)
    weights /= np.sum(weights)
    return means, covariances, weights
m_list = [5, 6, 7]  # number of components in each GMM
offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])]

# Means, covariances and weights of each input GMM, filled below.
means_list, covs_list, w_list = [], [], []

# generate GMMs: GMM k is drawn with seed k, then its means are halved and
# translated by offsets[k]
for k in range(K):
    means_k, covs_k, w_k = get_random_gmm(
        m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5
    )
    means_list.append(means_k / 2 + offsets[k][None, :])
    covs_list.append(covs_k)
    w_list.append(w_k)
# %%
# Compute the barycenter using the fixed-point method

# Initial barycenter: n random components in dimension d. The random weights
# returned by get_random_gmm are discarded.
init_means, init_covs, _ = get_random_gmm(n, d, seed=0)

# barycenter coefficients: the same weight for each of the K input GMMs
weights = ot.unif(K)

means_bar, covs_bar, log = gmm_barycenter_fixed_point(
    means_list, covs_list, w_list,  # input GMMs
    init_means, init_covs,  # initial barycenter components
    weights,  # barycenter coefficients
    iterations=3,
    log=True,
)
# %%
# Define plotting functions


# draw a covariance ellipse
def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None):
    """Draw the ``nstd``-standard-deviation ellipse of a 2D Gaussian.

    Parameters
    ----------
    mu : array-like
        Centre of the ellipse. Only ``mu[0]`` and ``mu[1]`` are used.
    C : ndarray, shape (2, 2)
        Symmetric covariance matrix. Its eigenvectors give the ellipse axes,
        and the square roots of its eigenvalues give the per-std half-lengths.
    color : matplotlib color, optional
        Used for both the face and the edge of the ellipse.
    label : str, optional
        Legend label of the ellipse.
    nstd : float, optional
        Number of standard deviations spanned by each semi-axis.
    alpha : float, optional
        Opacity of the ellipse.
    ax : matplotlib.axes.Axes, optional
        Target axes. Defaults to the current axes ``plt.gca()``.
    """
    # Sort the eigen-decomposition by decreasing eigenvalue so that
    # vecs[:, 0] is the major axis.
    vals, vecs = np.linalg.eigh(C)
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]
    # orientation of the major axis: arctan2(y, x) of its eigenvector
    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    # Clip round-off negatives so a (numerically) singular covariance gives
    # a flat ellipse instead of NaN sizes.
    w, h = 2 * nstd * np.sqrt(np.clip(vals, 0, None))
    ell = Ellipse(
        xy=(mu[0], mu[1]),
        width=w,
        height=h,
        alpha=alpha,
        angle=theta,
        facecolor=color,
        edgecolor=color,
        label=label,
        fill=True,
    )
    if ax is None:
        ax = plt.gca()
    # add_patch is the dedicated method for patches. Unlike add_artist, it
    # also updates the axes data limits, so autoscaling accounts for the
    # ellipse.
    ax.add_patch(ell)
# draw a gmm as a set of ellipses with weights shown in alpha value
def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):
    """Draw each component of a 2D Gaussian mixture as a covariance ellipse.

    Component ``i`` is drawn with opacity ``alpha * ws[i]``, so heavier
    components look more opaque. Only the first ellipse receives ``label``,
    so a legend shows a single entry for the whole mixture.

    Parameters
    ----------
    ms : ndarray, shape (n_components, 2)
        Component means. The number of components is read from ``ms.shape[0]``.
    Cs : sequence of ndarray, each of shape (2, 2)
        Component covariances.
    ws : sequence of float
        Component weights.
    color, nstd, label, ax
        Forwarded to ``draw_cov``.
    alpha : float, optional
        Opacity multiplier applied to each component weight.
    """
    n_components = ms.shape[0]
    for i in range(n_components):
        draw_cov(
            ms[i],
            Cs[i],
            color=color,
            label=None if i else label,
            nstd=nstd,
            alpha=alpha * ws[i],
            ax=ax,
        )
# %%
# Plot the results

fig, ax = plt.subplots(figsize=(6, 6))
axis = [-4, 4, -2, 6]
ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16)
# Input GMMs are drawn in color C0 and the barycenter in color C1. The
# barycenter components are drawn with uniform weights.
for means_k, covs_k, w_k in zip(means_list, covs_list, w_list):
    draw_gmm(means_k, covs_k, w_k, color="C0", ax=ax)
draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax)
ax.axis(axis)
ax.axis("off")
# Without this, running the example as a plain script exits without ever
# displaying the figure (sphinx-gallery captures figures either way).
plt.show()