@@ -442,36 +442,50 @@ def Tk0k1(k0, k1):
442
442
return nx .sum (mat , axis = (0 , 1 ))
443
443
444
444
445
- def solve_gmm_barycenter_fixed_point (
446
- means ,
447
- covs ,
445
+ def gmm_barycenter_fixed_point (
448
446
means_list ,
449
447
covs_list ,
450
- b_list ,
448
+ w_list ,
449
+ means_init ,
450
+ covs_init ,
451
451
weights ,
452
- max_its = 300 ,
452
+ w_bar = None ,
453
+ iterations = 100 ,
453
454
log = False ,
454
455
barycentric_proj_method = "euclidean" ,
455
456
):
456
457
r"""
457
- Solves the GMM OT barycenter problem using the fixed point algorithm.
458
+ Solves the Gaussian Mixture Model OT barycenter problem (defined in [69])
459
+ using the fixed point algorithm (proposed in [74]). The
460
+ weights of the barycenter are not optimized, and stay the same as the input
461
+ `w_bar` or are initialized to uniform.
462
+
463
+ The algorithm uses barycentric projections of GMM-OT plans, and these can be
464
+ computed either through Bures Barycenters (slow but accurate,
465
+ barycentric_proj_method='bures') or by convex combination (fast,
466
+ barycentric_proj_method='euclidean', default).
467
+
468
+ This is a special case of the generic free-support barycenter solver
469
+ `ot.lp.free_support_barycenter_generic_costs`.
458
470
459
471
Parameters
460
472
----------
461
- means : array-like
462
- Initial (n, d) GMM means.
463
- covs : array-like
464
- Initial (n, d, d) GMM covariances.
465
473
means_list : list of array-like
466
474
List of K (m_k, d) GMM means.
467
475
covs_list : list of array-like
468
476
List of K (m_k, d, d) GMM covariances.
469
- b_list : list of array-like
477
+ w_list : list of array-like
470
478
List of K (m_k) arrays of weights.
479
+ means_init : array-like
480
+ Initial (n, d) GMM means.
481
+ covs_init : array-like
482
+ Initial (n, d, d) GMM covariances.
471
483
weights : array-like
472
484
Array (K,) of the barycentre coefficients.
473
- max_its : int, optional
474
- Maximum number of iterations (default is 300).
485
+ w_bar : array-like, optional
486
+ Weights (n) of the barycentre GMM, which are not optimized. If None, initialized to uniform.
487
+ iterations : int, optional
488
+ Number of iterations (default is 100).
475
489
log : bool, optional
476
490
Whether to return the list of iterations (default is False).
477
491
barycentric_proj_method : str, optional
@@ -485,30 +499,46 @@ def solve_gmm_barycenter_fixed_point(
485
499
(n, d, d) barycentre GMM covariances.
486
500
log_dict : dict, optional
487
501
Dictionary containing the list of iterations if log is True.
502
+
503
+ References
504
+ ----------
505
+ .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
506
+
507
+ .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024)
508
+
509
+ See Also
510
+ --------
511
+ ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs.
488
512
"""
489
- nx = get_backend (means , covs [0 ], means_list [0 ], covs_list [0 ])
513
+ nx = get_backend (
514
+ means_init , covs_init , means_list [0 ], covs_list [0 ], w_list [0 ], weights
515
+ )
490
516
K = len (means_list )
491
- n = means .shape [0 ]
492
- d = means .shape [1 ]
493
- means_its = [means .copy ()]
494
- covs_its = [covs .copy ()]
495
- a = nx .ones (n , type_as = means ) / n
517
+ n = means_init .shape [0 ]
518
+ d = means_init .shape [1 ]
519
+ means_its = [nx .copy (means_init )]
520
+ covs_its = [nx .copy (covs_init )]
521
+ means , covs = means_init , covs_init
522
+
523
+ if w_bar is None :
524
+ w_bar = nx .ones (n , type_as = means ) / n
496
525
497
- for _ in range (max_its ):
526
+ for _ in range (iterations ):
498
527
pi_list = [
499
- gmm_ot_plan (means , means_list [k ], covs , covs_list [k ], a , b_list [k ])
528
+ gmm_ot_plan (means , means_list [k ], covs , covs_list [k ], w_bar , w_list [k ])
500
529
for k in range (K )
501
530
]
502
531
532
+ # filled in the euclidean case
503
533
means_selection , covs_selection = None , None
534
+
504
535
# in the euclidean case, the selection of Gaussians from each of the K sources
505
- # comes from a barycentric projection is a convex combination of the
506
- # selected means and covariances, which can be computed without a
507
- # for loop on i
536
+ # comes from a barycentric projection: it is a convex combination of the
537
+ # selected means and covariances, which can be computed without a
538
+ # for loop on i = 0, ..., n - 1
508
539
if barycentric_proj_method == "euclidean" :
509
540
means_selection = nx .zeros ((n , K , d ), type_as = means )
510
541
covs_selection = nx .zeros ((n , K , d , d ), type_as = means )
511
-
512
542
for k in range (K ):
513
543
means_selection [:, k , :] = n * pi_list [k ] @ means_list [k ]
514
544
covs_selection [:, k , :, :] = (
@@ -519,24 +549,27 @@ def solve_gmm_barycenter_fixed_point(
519
549
# selected components of the K GMMs. In the 'bures' barycentric
520
550
# projection option, the selected components are also Bures barycentres.
521
551
for i in range (n ):
522
- # means_slice_i (K, d) is the selected means, each comes from a
552
+ # means_selection_i (K, d) are the selected means; each comes from a
523
553
# Gaussian barycentre along the disintegration of pi_k at i
524
- # covs_slice_i (K, d, d) are the selected covariances
525
- means_selection_i = []
526
- covs_selection_i = []
554
+ # covs_selection_i (K, d, d) are the selected covariances
555
+ means_selection_i = None
556
+ covs_selection_i = None
527
557
528
558
# use previous computation (convex combination)
529
559
if barycentric_proj_method == "euclidean" :
530
560
means_selection_i = means_selection [i ]
531
561
covs_selection_i = covs_selection [i ]
532
562
533
- # compute Bures barycentre of the selected components
563
+ # compute Bures barycentre of certain components to get the
564
+ # selection at i
534
565
elif barycentric_proj_method == "bures" :
535
- w = (1 / a [i ]) * pi_list [k ][i , :]
566
+ means_selection_i = nx .zeros ((K , d ), type_as = means )
567
+ covs_selection_i = nx .zeros ((K , d , d ), type_as = means )
536
568
for k in range (K ):
569
+ w = (1 / w_bar [i ]) * pi_list [k ][i , :]
537
570
m , C = bures_wasserstein_barycenter (means_list [k ], covs_list [k ], w )
538
- means_selection_i . append ( m )
539
- covs_selection_i . append ( C )
571
+ means_selection_i [ k ] = m
572
+ covs_selection_i [ k ] = C
540
573
541
574
else :
542
575
raise ValueError ("Unknown barycentric_proj_method" )
@@ -546,8 +579,8 @@ def solve_gmm_barycenter_fixed_point(
546
579
)
547
580
548
581
if log :
549
- means_its .append (means .copy ())
550
- covs_its .append (covs .copy ())
582
+ means_its .append (nx .copy (means ))
583
+ covs_its .append (nx .copy (covs ))
551
584
552
585
if log :
553
586
return means , covs , {"means_its" : means_its , "covs_its" : covs_its }
0 commit comments