@@ -541,6 +541,7 @@ def bures_barycenter_gradient_descent(
541
541
log = False ,
542
542
step_size = 1 ,
543
543
batch_size = None ,
544
+ averaged = False ,
544
545
nx = None ,
545
546
):
546
547
r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.
@@ -570,6 +571,8 @@ def bures_barycenter_gradient_descent(
570
571
step size for the gradient descent, 1 by default
571
572
batch_size : int, optional
572
573
batch size to use when running stochastic gradient descent
574
+ averaged : bool, optional
575
+ if True, use the averaged procedure of :ref:`[74] <references-OT-bures-barycenter-gradient_descent>`
573
576
nx : module, optional
574
577
The numerical backend module to use. If not provided, the backend will
575
578
be fetched from the input matrices `C`.
@@ -607,7 +610,9 @@ def bures_barycenter_gradient_descent(
607
610
Cb = nx .mean (C * weights [:, None , None ], axis = 0 )
608
611
Id = nx .eye (C .shape [- 1 ], type_as = Cb )
609
612
610
- L_grads = []
613
+ L_diff = []
614
+
615
+ Cb_averaged = nx .copy (Cb )
611
616
612
617
for it in range (num_iter ):
613
618
Cb12 = nx .sqrtm (Cb )
@@ -627,40 +632,38 @@ def bures_barycenter_gradient_descent(
627
632
628
633
# step size from [74] (page 15)
629
634
step_size = 2 / (0.7 * (it + 2 / 0.7 + 1 ))
630
-
631
- # TODO: Add one where we take samples in order, + averaging? cf [74]
632
635
else : # gradient descent
633
636
M = nx .sqrtm (nx .einsum ("ij,njk,kl -> nil" , Cb12 , C , Cb12 ))
634
637
ot_maps = nx .einsum ("ij,njk,kl -> nil" , Cb12_ , M , Cb12_ )
635
638
grad_bw = Id - nx .sum (ot_maps * weights [:, None , None ], axis = 0 )
636
639
637
640
Cnew = exp_bures (Cb , - step_size * grad_bw , nx = nx )
638
641
642
+ if averaged :
643
+ # ot map between Cb_averaged and Cnew
644
+ Cb_averaged12 = nx .sqrtm (Cb_averaged )
645
+ Cb_averaged12inv = nx .inv (Cb_averaged12 )
646
+ M = nx .sqrtm (nx .einsum ("ij,jk,kl->il" , Cb_averaged12 , Cnew , Cb_averaged12 ))
647
+ ot_map = nx .einsum ("ij,jk,kl->il" , Cb_averaged12inv , M , Cb_averaged12inv )
648
+ map = Id * step_size / (step_size + 1 ) + ot_map / (step_size + 1 )
649
+ Cb_averaged = nx .einsum ("ij,jk,kl->il" , map , Cb_averaged , map )
650
+
639
651
# check convergence
640
- if batch_size is not None and batch_size < n :
641
- # TODO: criteria for SGD: on gradients? + test SGD
642
- # TOO slow, test with value? (but don't want to compute the full barycenter)
643
- # + need to make bures_wasserstein_distance batchable (TODO)
644
- L_grads .append (nx .sum (grad_bw ** 2 ))
645
- diff = np .mean (L_grads )
646
-
647
- # L_values.append(nx.norm(Cb - Cnew))
648
- # print(diff, np.mean(L_values))
649
- else :
650
- diff = nx .norm (Cb - Cnew )
652
+ L_diff .append (nx .norm (Cb - Cnew ))
651
653
652
- if diff <= eps :
654
+ # Stopping criterion: mean of the (up to) last 100 values of ||Cb - Cnew|| is at most eps
655
+ if np .mean (L_diff [- 100 :]) <= eps :
653
656
break
654
657
655
658
Cb = Cnew
656
659
657
- if diff > eps :
658
- print ( "Dit not converge." )
660
+ if averaged :
661
+ Cb = Cb_averaged
659
662
660
663
if log :
661
664
dict_log = {}
662
665
dict_log ["num_iter" ] = it
663
- dict_log ["final_diff" ] = diff
666
+ dict_log ["final_diff" ] = L_diff [ - 1 ]
664
667
return Cb , dict_log
665
668
else :
666
669
return Cb
@@ -708,7 +711,8 @@ def bures_wasserstein_barycenter(
708
711
weights : array-like (k), optional
709
712
weights for each distribution
710
713
method : str
711
- method used for the solver, either 'fixed_point' or 'gradient_descent'
714
+ method used for the solver, either 'fixed_point', 'gradient_descent', 'stochastic_gradient_descent' or
715
+ 'averaged_stochastic_gradient_descent'
712
716
num_iter : int, optional
713
717
number of iterations for the fixed point algorithm
714
718
eps : float, optional
@@ -756,15 +760,35 @@ def bures_wasserstein_barycenter(
756
760
# Compute the mean barycenter
757
761
mb = nx .sum (m * weights [:, None ], axis = 0 )
758
762
759
- if method == "gradient_descent" or batch_size is not None :
763
+ if method == "gradient_descent" :
760
764
out = bures_barycenter_gradient_descent (
761
765
C ,
762
766
weights = weights ,
763
767
num_iter = num_iter ,
764
768
eps = eps ,
765
769
log = log ,
766
770
step_size = step_size ,
767
- batch_size = batch_size ,
771
+ nx = nx ,
772
+ )
773
+ elif method == "stochastic_gradient_descent" :
774
+ out = bures_barycenter_gradient_descent (
775
+ C ,
776
+ weights = weights ,
777
+ num_iter = num_iter ,
778
+ eps = eps ,
779
+ log = log ,
780
+ batch_size = 1 if batch_size is None else batch_size ,
781
+ nx = nx ,
782
+ )
783
+ elif method == "averaged_stochastic_gradient_descent" :
784
+ out = bures_barycenter_gradient_descent (
785
+ C ,
786
+ weights = weights ,
787
+ num_iter = num_iter ,
788
+ eps = eps ,
789
+ log = log ,
790
+ batch_size = 1 if batch_size is None else batch_size ,
791
+ averaged = True ,
768
792
nx = nx ,
769
793
)
770
794
elif method == "fixed_point" :
0 commit comments