@@ -200,7 +200,7 @@ def empirical_bures_wasserstein_mapping(
200
200
return A , b
201
201
202
202
203
- def bures_distance (Cs , Ct , log = False ):
203
+ def bures_distance (Cs , Ct , log = False , nx = None ):
204
204
r"""Return Bures distance.
205
205
206
206
The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
@@ -217,7 +217,9 @@ def bures_distance(Cs, Ct, log=False):
217
217
covariance of the target distribution
218
218
log : bool, optional
219
219
record log if True
220
-
220
+ nx : module, optional
221
+ The numerical backend module to use. If not provided, the backend will
222
+ be fetched from the input matrices `Cs, Ct`.
221
223
222
224
Returns
223
225
-------
@@ -236,7 +238,11 @@ def bures_distance(Cs, Ct, log=False):
236
238
Transport", 2018.
237
239
"""
238
240
Cs , Ct = list_to_array (Cs , Ct )
239
- nx = get_backend (Cs , Ct )
241
+
242
+ if nx is None :
243
+ nx = get_backend (Cs , Ct )
244
+
245
+ assert Cs .shape [- 1 ] == Ct .shape [- 1 ], "All Gaussian must have the same dimension"
240
246
241
247
Cs12 = nx .sqrtm (Cs )
242
248
@@ -326,10 +332,10 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
326
332
), "All Gaussian must have the same dimension"
327
333
328
334
if log :
329
- bw , log_dict = bures_distance (Cs , Ct , log )
335
+ bw , log_dict = bures_distance (Cs , Ct , log = log , nx = nx )
330
336
Cs12 = log_dict ["Cs12" ]
331
337
else :
332
- bw = bures_distance (Cs , Ct )
338
+ bw = bures_distance (Cs , Ct , nx = nx )
333
339
334
340
if len (ms .shape ) == 1 and len (mt .shape ) == 1 :
335
341
# Return float
@@ -440,7 +446,9 @@ def empirical_bures_wasserstein_distance(
440
446
return W
441
447
442
448
443
- def bures_barycenter_fixpoint (C , weights = None , num_iter = 1000 , eps = 1e-7 , log = False ):
449
+ def bures_barycenter_fixpoint (
450
+ C , weights = None , num_iter = 1000 , eps = 1e-7 , log = False , nx = None
451
+ ):
444
452
r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.
445
453
446
454
The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n`
@@ -469,6 +477,9 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals
469
477
tolerance for the fixed point algorithm
470
478
log : bool, optional
471
479
record log if True
480
+ nx : module, optional
481
+ The numerical backend module to use. If not provided, the backend will
482
+ be fetched from the input matrices `C`.
472
483
473
484
Returns
474
485
-------
@@ -485,9 +496,10 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals
485
496
SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
486
497
2011.
487
498
"""
488
- nx = get_backend (
489
- * C ,
490
- )
499
+ if nx is None :
500
+ nx = get_backend (
501
+ * C ,
502
+ )
491
503
492
504
if weights is None :
493
505
weights = nx .ones (C .shape [0 ], type_as = C [0 ]) / C .shape [0 ]
@@ -522,7 +534,14 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals
522
534
523
535
524
536
def bures_barycenter_gradient_descent (
525
- C , weights = None , num_iter = 1000 , eps = 1e-7 , log = False , step_size = 1 , batch_size = None
537
+ C ,
538
+ weights = None ,
539
+ num_iter = 1000 ,
540
+ eps = 1e-7 ,
541
+ log = False ,
542
+ step_size = 1 ,
543
+ batch_size = None ,
544
+ nx = None ,
526
545
):
527
546
r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.
528
547
@@ -551,6 +570,9 @@ def bures_barycenter_gradient_descent(
551
570
step size for the gradient descent, 1 by default
552
571
batch_size : int, optional
553
572
batch size if use a stochastic gradient descent
573
+ nx : module, optional
574
+ The numerical backend module to use. If not provided, the backend will
575
+ be fetched from the input matrices `C`.
554
576
555
577
Returns
556
578
-------
@@ -571,9 +593,10 @@ def bures_barycenter_gradient_descent(
571
593
Averaging on the Bures-Wasserstein manifold: dimension-free convergence
572
594
of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145.
573
595
"""
574
- nx = get_backend (
575
- * C ,
576
- )
596
+ if nx is None :
597
+ nx = get_backend (
598
+ * C ,
599
+ )
577
600
578
601
n = C .shape [0 ]
579
602
@@ -742,10 +765,11 @@ def bures_wasserstein_barycenter(
742
765
log = log ,
743
766
step_size = step_size ,
744
767
batch_size = batch_size ,
768
+ nx = nx ,
745
769
)
746
770
elif method == "fixed_point" :
747
771
out = bures_barycenter_fixpoint (
748
- C , weights = weights , num_iter = num_iter , eps = eps , log = log
772
+ C , weights = weights , num_iter = num_iter , eps = eps , log = log , nx = nx
749
773
)
750
774
else :
751
775
raise ValueError ("Unknown method '%s'." % method )
0 commit comments