Skip to content

Commit 0b6217b

Browse files
committed
ot bar doc + test coverage
1 parent a20d3f0 commit 0b6217b

File tree

4 files changed

+124
-39
lines changed

4 files changed

+124
-39
lines changed

.github/workflows/build_tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
strategy:
4848
max-parallel: 4
4949
matrix:
50-
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
50+
python-version: ["3.9", "3.10", "3.11", "3.12"]
5151

5252
steps:
5353
- uses: actions/checkout@v4

ot/gmm.py

+68-35
Original file line numberDiff line numberDiff line change
@@ -442,36 +442,50 @@ def Tk0k1(k0, k1):
442442
return nx.sum(mat, axis=(0, 1))
443443

444444

445-
def solve_gmm_barycenter_fixed_point(
446-
means,
447-
covs,
445+
def gmm_barycenter_fixed_point(
448446
means_list,
449447
covs_list,
450-
b_list,
448+
w_list,
449+
means_init,
450+
covs_init,
451451
weights,
452-
max_its=300,
452+
w_bar=None,
453+
iterations=100,
453454
log=False,
454455
barycentric_proj_method="euclidean",
455456
):
456457
r"""
457-
Solves the GMM OT barycenter problem using the fixed point algorithm.
458+
Solves the Gaussian Mixture Model OT barycenter problem (defined in [69])
459+
using the fixed point algorithm (proposed in [74]). The
460+
weights of the barycenter are not optimized, and stay the same as the input
461+
`w_list` or are initialized to uniform.
462+
463+
The algorithm uses barycentric projections of GMM-OT plans, and these can be
464+
computed either through Bures Barycenters (slow but accurate,
465+
barycentric_proj_method='bures') or by convex combination (fast,
466+
barycentric_proj_method='euclidean', default).
467+
468+
This is a special case of the generic free-support barycenter solver
469+
`ot.lp.free_support_barycenter_generic_costs`.
458470
459471
Parameters
460472
----------
461-
means : array-like
462-
Initial (n, d) GMM means.
463-
covs : array-like
464-
Initial (n, d, d) GMM covariances.
465473
means_list : list of array-like
466474
List of K (m_k, d) GMM means.
467475
covs_list : list of array-like
468476
List of K (m_k, d, d) GMM covariances.
469-
b_list : list of array-like
477+
w_list : list of array-like
470478
List of K (m_k) arrays of weights.
479+
means_init : array-like
480+
Initial (n, d) GMM means.
481+
covs_init : array-like
482+
Initial (n, d, d) GMM covariances.
471483
weights : array-like
472484
Array (K,) of the barycentre coefficients.
473-
max_its : int, optional
474-
Maximum number of iterations (default is 300).
485+
w_bar : array-like, optional
486+
Initial weights (n) of the barycentre GMM. If None, initialized to uniform.
487+
iterations : int, optional
488+
Number of iterations (default is 100).
475489
log : bool, optional
476490
Whether to return the list of iterations (default is False).
477491
barycentric_proj_method : str, optional
@@ -485,30 +499,46 @@ def solve_gmm_barycenter_fixed_point(
485499
(n, d, d) barycentre GMM covariances.
486500
log_dict : dict, optional
487501
Dictionary containing the list of iterations if log is True.
502+
503+
References
504+
----------
505+
.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
506+
507+
.. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024)
508+
509+
See Also
510+
--------
511+
ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs.
488512
"""
489-
nx = get_backend(means, covs[0], means_list[0], covs_list[0])
513+
nx = get_backend(
514+
means_init, covs_init, means_list[0], covs_list[0], w_list[0], weights
515+
)
490516
K = len(means_list)
491-
n = means.shape[0]
492-
d = means.shape[1]
493-
means_its = [means.copy()]
494-
covs_its = [covs.copy()]
495-
a = nx.ones(n, type_as=means) / n
517+
n = means_init.shape[0]
518+
d = means_init.shape[1]
519+
means_its = [nx.copy(means_init)]
520+
covs_its = [nx.copy(covs_init)]
521+
means, covs = means_init, covs_init
522+
523+
if w_bar is None:
524+
w_bar = nx.ones(n, type_as=means) / n
496525

497-
for _ in range(max_its):
526+
for _ in range(iterations):
498527
pi_list = [
499-
gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k])
528+
gmm_ot_plan(means, means_list[k], covs, covs_list[k], w_bar, w_list[k])
500529
for k in range(K)
501530
]
502531

532+
# filled in the euclidean case
503533
means_selection, covs_selection = None, None
534+
504535
# in the euclidean case, the selection of Gaussians from each K sources
505-
# comes from a barycentric projection is a convex combination of the
506-
# selected means and covariances, which can be computed without a
507-
# for loop on i
536+
# comes from a barycentric projection: it is a convex combination of the
537+
# selected means and covariances, which can be computed without a
538+
# for loop on i = 0, ..., n -1
508539
if barycentric_proj_method == "euclidean":
509540
means_selection = nx.zeros((n, K, d), type_as=means)
510541
covs_selection = nx.zeros((n, K, d, d), type_as=means)
511-
512542
for k in range(K):
513543
means_selection[:, k, :] = n * pi_list[k] @ means_list[k]
514544
covs_selection[:, k, :, :] = (
@@ -519,24 +549,27 @@ def solve_gmm_barycenter_fixed_point(
519549
# selected components of the K GMMs. In the 'bures' barycentric
520550
# projection option, the selected components are also Bures barycentres.
521551
for i in range(n):
522-
# means_slice_i (K, d) is the selected means, each comes from a
552+
# means_selection_i (K, d) is the selected means, each comes from a
523553
# Gaussian barycentre along the disintegration of pi_k at i
524-
# covs_slice_i (K, d, d) are the selected covariances
525-
means_selection_i = []
526-
covs_selection_i = []
554+
# covs_selection_i (K, d, d) are the selected covariances
555+
means_selection_i = None
556+
covs_selection_i = None
527557

528558
# use previous computation (convex combination)
529559
if barycentric_proj_method == "euclidean":
530560
means_selection_i = means_selection[i]
531561
covs_selection_i = covs_selection[i]
532562

533-
# compute Bures barycentre of the selected components
563+
# compute Bures barycentre of certain components to get the
564+
# selection at i
534565
elif barycentric_proj_method == "bures":
535-
w = (1 / a[i]) * pi_list[k][i, :]
566+
means_selection_i = nx.zeros((K, d), type_as=means)
567+
covs_selection_i = nx.zeros((K, d, d), type_as=means)
536568
for k in range(K):
569+
w = (1 / w_bar[i]) * pi_list[k][i, :]
537570
m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w)
538-
means_selection_i.append(m)
539-
covs_selection_i.append(C)
571+
means_selection_i[k] = m
572+
covs_selection_i[k] = C
540573

541574
else:
542575
raise ValueError("Unknown barycentric_proj_method")
@@ -546,8 +579,8 @@ def solve_gmm_barycenter_fixed_point(
546579
)
547580

548581
if log:
549-
means_its.append(means.copy())
550-
covs_its.append(covs.copy())
582+
means_its.append(nx.copy(means))
583+
covs_its.append(nx.copy(covs))
551584

552585
if log:
553586
return means, covs, {"means_its": means_its, "covs_its": covs_its}

ot/lp/_barycenter_solvers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def free_support_barycenter_generic_costs(
435435
cost_list,
436436
B,
437437
a=None,
438-
numItermax=5,
438+
numItermax=100,
439439
stopThr=1e-5,
440440
log=False,
441441
):
@@ -512,7 +512,7 @@ def free_support_barycenter_generic_costs(
512512
Array of shape (n,) representing weights of the barycenter
513513
measure.Defaults to uniform.
514514
numItermax : int, optional
515-
Maximum number of iterations (default is 5).
515+
Maximum number of iterations (default is 100).
516516
stopThr : float, optional
517517
If the iterations move less than this, terminate (default is 1e-5).
518518
log : bool, optional

test/test_gmm.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for module gaussian"""
22

3-
# Author: Eloi Tanguy <eloi.tanguy@u-paris>
3+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
44
# Remi Flamary <[email protected]>
55
# Julie Delon <[email protected]>
66
#
@@ -17,6 +17,7 @@
1717
gmm_ot_plan,
1818
gmm_ot_apply_map,
1919
gmm_ot_plan_density,
20+
gmm_barycenter_fixed_point,
2021
)
2122

2223
try:
@@ -193,3 +194,54 @@ def test_gmm_ot_plan_density(nx):
193194

194195
with pytest.raises(AssertionError):
195196
gmm_ot_plan_density(x[:, 1:], y, m_s, m_t, C_s, C_t, w_s, w_t)
197+
198+
199+
@pytest.skip_backend("tf") # skips because of array assignment
200+
@pytest.skip_backend("jax")
201+
def test_gmm_barycenter_fixed_point(nx):
202+
m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx)
203+
means_list = [m_s, m_t]
204+
covs_list = [C_s, C_t]
205+
w_list = [w_s, w_t]
206+
n_iter = 3
207+
n = m_s.shape[0] # number of components of barycenter
208+
means_init = m_s
209+
covs_init = C_s
210+
weights = nx.ones(2, type_as=m_s) / 2 # barycenter coefficients
211+
212+
# with euclidean barycentric projections
213+
means, covs = gmm_barycenter_fixed_point(
214+
means_list, covs_list, w_list, means_init, covs_init, weights, iterations=n_iter
215+
)
216+
217+
# with bures barycentric projections and assigned weights to uniform
218+
means_bures_proj, covs_bures_proj, log = gmm_barycenter_fixed_point(
219+
means_list,
220+
covs_list,
221+
w_list,
222+
means_init,
223+
covs_init,
224+
weights,
225+
iterations=n_iter,
226+
w_bar=nx.ones(n, type_as=m_s) / n,
227+
barycentric_proj_method="bures",
228+
log=True,
229+
)
230+
231+
assert "means_its" in log
232+
assert "covs_its" in log
233+
234+
assert np.allclose(means, means_bures_proj, atol=1e-6)
235+
assert np.allclose(covs, covs_bures_proj, atol=1e-6)
236+
237+
with pytest.raises(ValueError):
238+
gmm_barycenter_fixed_point(
239+
means_list,
240+
covs_list,
241+
w_list,
242+
means_init,
243+
covs_init,
244+
weights,
245+
iterations=n_iter,
246+
barycentric_proj_method="unknown",
247+
)

0 commit comments

Comments
 (0)