Skip to content

Commit fe3d9db

Browse files
committed
Add arg backend
1 parent 0b20759 commit fe3d9db

File tree

1 file changed

+38
-14
lines changed

1 file changed

+38
-14
lines changed

ot/gaussian.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def empirical_bures_wasserstein_mapping(
200200
return A, b
201201

202202

203-
def bures_distance(Cs, Ct, log=False):
203+
def bures_distance(Cs, Ct, log=False, nx=None):
204204
r"""Return Bures distance.
205205
206206
The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
@@ -217,7 +217,9 @@ def bures_distance(Cs, Ct, log=False):
217217
covariance of the target distribution
218218
log : bool, optional
219219
record log if True
220-
220+
nx : module, optional
221+
The numerical backend module to use. If not provided, the backend will
222+
be fetched from the input matrices `Cs, Ct`.
221223
222224
Returns
223225
-------
@@ -236,7 +238,11 @@ def bures_distance(Cs, Ct, log=False):
236238
Transport", 2018.
237239
"""
238240
Cs, Ct = list_to_array(Cs, Ct)
239-
nx = get_backend(Cs, Ct)
241+
242+
if nx is None:
243+
nx = get_backend(Cs, Ct)
244+
245+
assert Cs.shape[-1] == Ct.shape[-1], "All Gaussian must have the same dimension"
240246

241247
Cs12 = nx.sqrtm(Cs)
242248

@@ -326,10 +332,10 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
326332
), "All Gaussian must have the same dimension"
327333

328334
if log:
329-
bw, log_dict = bures_distance(Cs, Ct, log)
335+
bw, log_dict = bures_distance(Cs, Ct, log=log, nx=nx)
330336
Cs12 = log_dict["Cs12"]
331337
else:
332-
bw = bures_distance(Cs, Ct)
338+
bw = bures_distance(Cs, Ct, nx=nx)
333339

334340
if len(ms.shape) == 1 and len(mt.shape) == 1:
335341
# Return float
@@ -440,7 +446,9 @@ def empirical_bures_wasserstein_distance(
440446
return W
441447

442448

443-
def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=False):
449+
def bures_barycenter_fixpoint(
450+
C, weights=None, num_iter=1000, eps=1e-7, log=False, nx=None
451+
):
444452
r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.
445453
446454
The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n`
@@ -469,6 +477,9 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals
469477
tolerance for the fixed point algorithm
470478
log : bool, optional
471479
record log if True
480+
nx : module, optional
481+
The numerical backend module to use. If not provided, the backend will
482+
be fetched from the input matrices `C`.
472483
473484
Returns
474485
-------
@@ -485,9 +496,10 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals
485496
SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
486497
2011.
487498
"""
488-
nx = get_backend(
489-
*C,
490-
)
499+
if nx is None:
500+
nx = get_backend(
501+
*C,
502+
)
491503

492504
if weights is None:
493505
weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]
@@ -522,7 +534,14 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals
522534

523535

524536
def bures_barycenter_gradient_descent(
525-
C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None
537+
C,
538+
weights=None,
539+
num_iter=1000,
540+
eps=1e-7,
541+
log=False,
542+
step_size=1,
543+
batch_size=None,
544+
nx=None,
526545
):
527546
r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.
528547
@@ -551,6 +570,9 @@ def bures_barycenter_gradient_descent(
551570
step size for the gradient descent, 1 by default
552571
batch_size : int, optional
553572
batch size if use a stochastic gradient descent
573+
nx : module, optional
574+
The numerical backend module to use. If not provided, the backend will
575+
be fetched from the input matrices `C`.
554576
555577
Returns
556578
-------
@@ -571,9 +593,10 @@ def bures_barycenter_gradient_descent(
571593
Averaging on the Bures-Wasserstein manifold: dimension-free convergence
572594
of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145.
573595
"""
574-
nx = get_backend(
575-
*C,
576-
)
596+
if nx is None:
597+
nx = get_backend(
598+
*C,
599+
)
577600

578601
n = C.shape[0]
579602

@@ -742,10 +765,11 @@ def bures_wasserstein_barycenter(
742765
log=log,
743766
step_size=step_size,
744767
batch_size=batch_size,
768+
nx=nx,
745769
)
746770
elif method == "fixed_point":
747771
out = bures_barycenter_fixpoint(
748-
C, weights=weights, num_iter=num_iter, eps=eps, log=log
772+
C, weights=weights, num_iter=num_iter, eps=eps, log=log, nx=nx
749773
)
750774
else:
751775
raise ValueError("Unknown method '%s'." % method)

0 commit comments

Comments
 (0)