|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +===================================== |
| 4 | +Gaussian Mixture Model OT Barycenters |
| 5 | +===================================== |
| 6 | +
|
| 7 | +This example illustrates the computation of a barycenter between Gaussian |
| 8 | +Mixtures in the sense of GMM-OT [69]. This computation is done using the |
| 9 | +fixed-point method for OT barycenters with generic costs [74], for which POT |
| 10 | +provides a general solver, and a specific GMM solver. Note that this is a |
| 11 | +'free-support' method, implying that the number of components of the barycenter |
| 12 | +GMM and their weights are fixed. |
| 13 | +
|
| 14 | +The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over |
| 15 | +the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the |
| 16 | +Bures-Wasserstein manifold), and to compute barycenters with respect to the |
| 17 | +2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a |
| 18 | +Gaussian mixture is a finite combination of Diracs on specific Gaussians, and |
| 19 | +two mixtures are compared with the 2-Wasserstein distance on this space, using |
| 20 | +the squared Bures distance between Gaussians as ground cost. |
| 21 | +
|
| 22 | +[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space |
| 23 | +of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. |
| 24 | +
|
| 25 | +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing |
| 26 | +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 |
| 27 | +(2024) |
| 28 | +
|
| 29 | +""" |
| 30 | + |
| 31 | +# Author: Eloi Tanguy <[email protected]> |
| 32 | +# |
| 33 | +# License: MIT License |
| 34 | + |
| 35 | +# sphinx_gallery_thumbnail_number = 1 |
| 36 | + |
| 37 | +# %% |
| 38 | +# Generate data |
| 39 | +import numpy as np |
| 40 | +import matplotlib.pyplot as plt |
| 41 | +from matplotlib.patches import Ellipse |
| 42 | +import ot |
| 43 | +from ot.gmm import gmm_barycenter_fixed_point |
| 44 | + |
| 45 | + |
# Problem sizes shared by the whole example.
K = 3  # how many input GMMs are averaged
d = 2  # dimension of the space the Gaussians live in
n = 6  # how many components the barycenter GMM has
| 49 | + |
| 50 | + |
def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2):
    """Draw a random Gaussian mixture with ``K`` components in dimension ``d``.

    Note that inside this function ``K`` is the number of *components* of the
    mixture being generated, not the number of mixtures.

    Parameters
    ----------
    K : int
        Number of Gaussian components.
    d : int
        Dimension of each component.
    seed : int, optional
        Seed of the ``numpy.random.RandomState`` used for all draws.
    min_cov_eig : float, optional
        Multiple of the identity added to every covariance, which bounds its
        eigenvalues from below by ``min_cov_eig``.
    cov_scale : float, optional
        Scale applied to the standard normal factor ``P`` of the covariances.

    Returns
    -------
    means : ndarray, shape (K, d)
        Standard normal component means.
    covariances : ndarray, shape (K, d, d)
        ``covariances[k] = P[k] @ P[k].T + min_cov_eig * I``.
    weights : ndarray, shape (K,)
        Uniform draws in [0, 1) normalised to sum to one.
    """
    rng = np.random.RandomState(seed=seed)
    # The order of the draws (means, P, weights) fixes the output for a seed.
    means = rng.randn(K, d)
    P = rng.randn(K, d, d) * cov_scale
    # C[k] = P[k] @ P[k]^T + min_cov_eig * I; the (d, d) identity broadcasts
    # over the leading component axis, so no stack of K identities is needed.
    covariances = np.einsum("kab,kcb->kac", P, P) + min_cov_eig * np.eye(d)
    weights = rng.random(K)
    weights /= np.sum(weights)
    return means, covariances, weights
| 61 | + |
| 62 | + |
m_list = [5, 6, 7]  # number of components in each GMM
offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])]

# Per-GMM parameters, filled in the loop below.
means_list = []  # component means of each GMM
covs_list = []  # component covariances of each GMM
w_list = []  # component weights of each GMM

# Build the K input GMMs: GMM number k is seeded with k, its means are
# shrunk by a factor 2 and then translated by its own offset.
for k, n_components, offset in zip(range(K), m_list, offsets):
    means, covs, b = get_random_gmm(
        n_components, d, seed=k, min_cov_eig=0.25, cov_scale=0.5
    )
    means = means / 2 + offset[None, :]
    means_list.append(means)
    covs_list.append(covs)
    w_list.append(b)
| 78 | + |
| 79 | +# %% |
| 80 | +# Compute the barycenter using the fixed-point method |
# Random starting point for the n barycenter components; the weights that
# get_random_gmm also returns are not needed and are discarded.
init_means, init_covs, _ = get_random_gmm(n, d, seed=0)

# Barycenter coefficients: one coefficient per input GMM.
weights = ot.unif(K)

# Run three fixed-point iterations and keep the solver log.
means_bar, covs_bar, log = gmm_barycenter_fixed_point(
    means_list,
    covs_list,
    w_list,
    init_means,
    init_covs,
    weights,
    iterations=3,
    log=True,
)
| 93 | + |
| 94 | + |
| 95 | +# %% |
| 96 | +# Define plotting functions |
| 97 | + |
| 98 | + |
| 99 | +# draw a covariance ellipse |
def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None):
    """Draw the ``nstd``-standard-deviation ellipse of a 2D Gaussian.

    Parameters
    ----------
    mu : array-like
        Centre of the ellipse; only ``mu[0]`` and ``mu[1]`` are used.
    C : array-like, shape (2, 2)
        Symmetric covariance matrix of the Gaussian.
    color : optional
        Matplotlib colour used for both the face and the edge.
    label : str, optional
        Label attached to the ellipse patch.
    nstd : float, optional
        Half-axis lengths are ``nstd`` times the standard deviations along the
        principal directions.
    alpha : float, optional
        Opacity of the patch.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes.
    """
    # eigh returns eigenvalues in ascending order; reverse both outputs so
    # that the principal (largest-variance) direction comes first.
    vals, vecs = np.linalg.eigh(C)
    vals, vecs = vals[::-1], vecs[:, ::-1]

    # Orientation of the principal axis, in degrees, as Ellipse expects.
    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    # Clamp at zero: round-off can make an eigenvalue of a nearly singular
    # covariance slightly negative, which would give NaN axis lengths.
    w, h = 2 * nstd * np.sqrt(np.maximum(vals, 0.0))
    ell = Ellipse(
        xy=(mu[0], mu[1]),
        width=w,
        height=h,
        alpha=alpha,
        angle=theta,
        facecolor=color,
        edgecolor=color,
        label=label,
        fill=True,
    )
    if ax is None:
        ax = plt.gca()
    # add_patch is the dedicated method for patches; unlike add_artist it
    # also updates the axes data limits.
    ax.add_patch(ell)
| 123 | + |
| 124 | + |
| 125 | +# draw a gmm as a set of ellipses with weights shown in alpha value |
def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):
    """Draw each component of a GMM as a covariance ellipse.

    Component ``k`` is drawn with opacity ``alpha * ws[k]``. Only the first
    component receives ``label``; the others are drawn unlabelled.
    """
    n_components = ms.shape[0]
    for idx in range(n_components):
        component_label = None if idx != 0 else label
        draw_cov(
            ms[idx],
            Cs[idx],
            color=color,
            label=component_label,
            nstd=nstd,
            alpha=alpha * ws[idx],
            ax=ax,
        )
| 131 | + |
| 132 | + |
| 133 | +# %% |
| 134 | +# Plot the results |
fig, ax = plt.subplots(figsize=(6, 6))
axis = [-4, 4, -2, 6]  # view window passed to ax.axis below
ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16)

# Input GMMs, all in colour "C0".
for gmm_means, gmm_covs, gmm_weights in zip(means_list, covs_list, w_list):
    draw_gmm(gmm_means, gmm_covs, gmm_weights, color="C0", ax=ax)

# Barycenter in colour "C1", drawn with ot.unif(n) as component weights.
draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax)

ax.axis(axis)
ax.axis("off")
| 143 | + |
| 144 | +# %% |
0 commit comments