|
13 | 13 | from .lp import emd2, emd
|
14 | 14 | import numpy as np
|
15 | 15 | from .utils import dist
|
16 |
| -from .gaussian import bures_wasserstein_mapping |
| 16 | +from .gaussian import bures_wasserstein_mapping, bures_wasserstein_barycenter |
17 | 17 |
|
18 | 18 |
|
19 | 19 | def gaussian_logpdf(x, m, C):
|
@@ -440,3 +440,115 @@ def Tk0k1(k0, k1):
|
440 | 440 | ]
|
441 | 441 | )
|
442 | 442 | return nx.sum(mat, axis=(0, 1))
|
| 443 | + |
| 444 | + |
| 445 | +def solve_gmm_barycenter_fixed_point( |
| 446 | + means, |
| 447 | + covs, |
| 448 | + means_list, |
| 449 | + covs_list, |
| 450 | + b_list, |
| 451 | + weights, |
| 452 | + max_its=300, |
| 453 | + log=False, |
| 454 | + barycentric_proj_method="euclidean", |
| 455 | +): |
| 456 | + r""" |
| 457 | + Solves the GMM OT barycenter problem using the fixed point algorithm. |
| 458 | +
|
| 459 | + Parameters |
| 460 | + ---------- |
| 461 | + means : array-like |
| 462 | + Initial (n, d) GMM means. |
| 463 | + covs : array-like |
| 464 | + Initial (n, d, d) GMM covariances. |
| 465 | + means_list : list of array-like |
| 466 | + List of K (m_k, d) GMM means. |
| 467 | + covs_list : list of array-like |
| 468 | + List of K (m_k, d, d) GMM covariances. |
| 469 | + b_list : list of array-like |
| 470 | + List of K (m_k) arrays of weights. |
| 471 | + weights : array-like |
| 472 | + Array (K,) of the barycentre coefficients. |
| 473 | + max_its : int, optional |
| 474 | + Maximum number of iterations (default is 300). |
| 475 | + log : bool, optional |
| 476 | + Whether to return the list of iterations (default is False). |
| 477 | + barycentric_proj_method : str, optional |
| 478 | + Method to project the barycentre weights: 'euclidean' (default) or 'bures'. |
| 479 | +
|
| 480 | + Returns |
| 481 | + ------- |
| 482 | + means : array-like |
| 483 | + (n, d) barycentre GMM means. |
| 484 | + covs : array-like |
| 485 | + (n, d, d) barycentre GMM covariances. |
| 486 | + log_dict : dict, optional |
| 487 | + Dictionary containing the list of iterations if log is True. |
| 488 | + """ |
| 489 | + nx = get_backend(means, covs[0], means_list[0], covs_list[0]) |
| 490 | + K = len(means_list) |
| 491 | + n = means.shape[0] |
| 492 | + d = means.shape[1] |
| 493 | + means_its = [means.copy()] |
| 494 | + covs_its = [covs.copy()] |
| 495 | + a = nx.ones(n, type_as=means) / n |
| 496 | + |
| 497 | + for _ in range(max_its): |
| 498 | + pi_list = [ |
| 499 | + gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k]) |
| 500 | + for k in range(K) |
| 501 | + ] |
| 502 | + |
| 503 | + means_selection, covs_selection = None, None |
| 504 | + # in the euclidean case, the selection of Gaussians from each K sources |
| 505 | + # comes from a barycentric projection is a convex combination of the |
| 506 | + # selected means and covariances, which can be computed without a |
| 507 | + # for loop on i |
| 508 | + if barycentric_proj_method == "euclidean": |
| 509 | + means_selection = nx.zeros((n, K, d), type_as=means) |
| 510 | + covs_selection = nx.zeros((n, K, d, d), type_as=means) |
| 511 | + |
| 512 | + for k in range(K): |
| 513 | + means_selection[:, k, :] = n * pi_list[k] @ means_list[k] |
| 514 | + covs_selection[:, k, :, :] = ( |
| 515 | + nx.einsum("ij,jab->iab", pi_list[k], covs_list[k]) * n |
| 516 | + ) |
| 517 | + |
| 518 | + # each component i of the barycentre will be a Bures barycentre of the |
| 519 | + # selected components of the K GMMs. In the 'bures' barycentric |
| 520 | + # projection option, the selected components are also Bures barycentres. |
| 521 | + for i in range(n): |
| 522 | + # means_slice_i (K, d) is the selected means, each comes from a |
| 523 | + # Gaussian barycentre along the disintegration of pi_k at i |
| 524 | + # covs_slice_i (K, d, d) are the selected covariances |
| 525 | + means_selection_i = [] |
| 526 | + covs_selection_i = [] |
| 527 | + |
| 528 | + # use previous computation (convex combination) |
| 529 | + if barycentric_proj_method == "euclidean": |
| 530 | + means_selection_i = means_selection[i] |
| 531 | + covs_selection_i = covs_selection[i] |
| 532 | + |
| 533 | + # compute Bures barycentre of the selected components |
| 534 | + elif barycentric_proj_method == "bures": |
| 535 | + w = (1 / a[i]) * pi_list[k][i, :] |
| 536 | + for k in range(K): |
| 537 | + m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w) |
| 538 | + means_selection_i.append(m) |
| 539 | + covs_selection_i.append(C) |
| 540 | + |
| 541 | + else: |
| 542 | + raise ValueError("Unknown barycentric_proj_method") |
| 543 | + |
| 544 | + means[i], covs[i] = bures_wasserstein_barycenter( |
| 545 | + means_selection_i, covs_selection_i, weights |
| 546 | + ) |
| 547 | + |
| 548 | + if log: |
| 549 | + means_its.append(means.copy()) |
| 550 | + covs_its.append(covs.copy()) |
| 551 | + |
| 552 | + if log: |
| 553 | + return means, covs, {"means_its": means_its, "covs_its": covs_its} |
| 554 | + return means, covs |
0 commit comments