Skip to content

Commit 37b9c80

Browse files
committed
python 3.13 in test workflow + added ggmot barycenter (WIP)
1 parent ccf608a commit 37b9c80

File tree

2 files changed

+114
-2
lines changed

2 files changed

+114
-2
lines changed

.github/workflows/build_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
strategy:
4848
max-parallel: 4
4949
matrix:
50-
python-version: ["3.9", "3.10", "3.11", "3.12"]
50+
python-version: ["3.9", "3.10", "3.11", "3.12, "3.13"]
5151

5252
steps:
5353
- uses: actions/checkout@v4

ot/gmm.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .lp import emd2, emd
1414
import numpy as np
1515
from .utils import dist
16-
from .gaussian import bures_wasserstein_mapping
16+
from .gaussian import bures_wasserstein_mapping, bures_wasserstein_barycenter
1717

1818

1919
def gaussian_logpdf(x, m, C):
@@ -440,3 +440,115 @@ def Tk0k1(k0, k1):
440440
]
441441
)
442442
return nx.sum(mat, axis=(0, 1))
443+
444+
445+
def solve_gmm_barycenter_fixed_point(
446+
means,
447+
covs,
448+
means_list,
449+
covs_list,
450+
b_list,
451+
weights,
452+
max_its=300,
453+
log=False,
454+
barycentric_proj_method="euclidean",
455+
):
456+
r"""
457+
Solves the GMM OT barycenter problem using the fixed point algorithm.
458+
459+
Parameters
460+
----------
461+
means : array-like
462+
Initial (n, d) GMM means.
463+
covs : array-like
464+
Initial (n, d, d) GMM covariances.
465+
means_list : list of array-like
466+
List of K (m_k, d) GMM means.
467+
covs_list : list of array-like
468+
List of K (m_k, d, d) GMM covariances.
469+
b_list : list of array-like
470+
List of K (m_k) arrays of weights.
471+
weights : array-like
472+
Array (K,) of the barycentre coefficients.
473+
max_its : int, optional
474+
Maximum number of iterations (default is 300).
475+
log : bool, optional
476+
Whether to return the list of iterations (default is False).
477+
barycentric_proj_method : str, optional
478+
Method to project the barycentre weights: 'euclidean' (default) or 'bures'.
479+
480+
Returns
481+
-------
482+
means : array-like
483+
(n, d) barycentre GMM means.
484+
covs : array-like
485+
(n, d, d) barycentre GMM covariances.
486+
log_dict : dict, optional
487+
Dictionary containing the list of iterations if log is True.
488+
"""
489+
nx = get_backend(means, covs[0], means_list[0], covs_list[0])
490+
K = len(means_list)
491+
n = means.shape[0]
492+
d = means.shape[1]
493+
means_its = [means.copy()]
494+
covs_its = [covs.copy()]
495+
a = nx.ones(n, type_as=means) / n
496+
497+
for _ in range(max_its):
498+
pi_list = [
499+
gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k])
500+
for k in range(K)
501+
]
502+
503+
means_selection, covs_selection = None, None
504+
# in the euclidean case, the selection of Gaussians from each K sources
505+
# comes from a barycentric projection is a convex combination of the
506+
# selected means and covariances, which can be computed without a
507+
# for loop on i
508+
if barycentric_proj_method == "euclidean":
509+
means_selection = nx.zeros((n, K, d), type_as=means)
510+
covs_selection = nx.zeros((n, K, d, d), type_as=means)
511+
512+
for k in range(K):
513+
means_selection[:, k, :] = n * pi_list[k] @ means_list[k]
514+
covs_selection[:, k, :, :] = (
515+
nx.einsum("ij,jab->iab", pi_list[k], covs_list[k]) * n
516+
)
517+
518+
# each component i of the barycentre will be a Bures barycentre of the
519+
# selected components of the K GMMs. In the 'bures' barycentric
520+
# projection option, the selected components are also Bures barycentres.
521+
for i in range(n):
522+
# means_slice_i (K, d) is the selected means, each comes from a
523+
# Gaussian barycentre along the disintegration of pi_k at i
524+
# covs_slice_i (K, d, d) are the selected covariances
525+
means_selection_i = []
526+
covs_selection_i = []
527+
528+
# use previous computation (convex combination)
529+
if barycentric_proj_method == "euclidean":
530+
means_selection_i = means_selection[i]
531+
covs_selection_i = covs_selection[i]
532+
533+
# compute Bures barycentre of the selected components
534+
elif barycentric_proj_method == "bures":
535+
w = (1 / a[i]) * pi_list[k][i, :]
536+
for k in range(K):
537+
m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w)
538+
means_selection_i.append(m)
539+
covs_selection_i.append(C)
540+
541+
else:
542+
raise ValueError("Unknown barycentric_proj_method")
543+
544+
means[i], covs[i] = bures_wasserstein_barycenter(
545+
means_selection_i, covs_selection_i, weights
546+
)
547+
548+
if log:
549+
means_its.append(means.copy())
550+
covs_its.append(covs.copy())
551+
552+
if log:
553+
return means, covs, {"means_its": means_its, "covs_its": covs_its}
554+
return means, covs

0 commit comments

Comments
 (0)