Skip to content

Commit 41ebffc

Browse files
committed
fix doc
1 parent c640ecb commit 41ebffc

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

ot/gaussian.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def bures_distance(Cs, Ct, log=False, nx=None):
204204
r"""Return Bures distance.
205205
206206
The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
207-
given by:
207+
given by (see e.g. Remark 2.31 :ref:`[15] <references-bures-wasserstein-distance>`):
208208
209209
.. math::
210210
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
@@ -223,19 +223,17 @@ def bures_distance(Cs, Ct, log=False, nx=None):
223223
224224
Returns
225225
-------
226-
W : float if Cs and Cd of shape (d,d), array-like (n,) if Cs of shape (n,d,d),
227-
Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d),
228-
array-like (n,m) if Cs of shape (n,d,d) and mt of shape (m,d,d)
226+
W : float if Cs and Cd of shape (d,d), array-like (n,) if Cs of shape (n,d,d), Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d), array-like (n,m) if Cs of shape (n,d,d) and mt of shape (m,d,d)
229227
Bures Wasserstein distance
230228
log : dict
231229
log dictionary return only if log==True in parameters
232230
231+
233232
.. _references-bures-wasserstein-distance:
234233
References
235234
----------
236-
237-
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
238-
Transport", 2018.
235+
.. [15] Peyré, G., & Cuturi, M. (2019). Computational optimal transport: With applications to data science.
236+
Foundations and Trends® in Machine Learning, 11(5-6), 355-607.
239237
"""
240238
Cs, Ct = list_to_array(Cs, Ct)
241239

@@ -276,7 +274,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
276274
r"""Return Bures Wasserstein distance between samples.
277275
278276
The function computes the Bures-Wasserstein distance between :math:`\mu_s=\mathcal{N}(m_s,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(m_t,\Sigma_t)`,
279-
as discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
277+
as discussed in remark 2.31 :ref:`[15] <references-bures-wasserstein-distance>`.
280278
281279
.. math::
282280
\mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
@@ -301,9 +299,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
301299
302300
Returns
303301
-------
304-
W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d),
305-
mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d),
306-
array-like (n,m) if ms of shape (n,d) and mt of shape (m,d)
302+
W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d), mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d), array-like (n,m) if ms of shape (n,d) and mt of shape (m,d)
307303
Bures Wasserstein distance
308304
log : dict
309305
log dictionary return only if log==True in parameters
@@ -313,8 +309,8 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
313309
References
314310
----------
315311
316-
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
317-
Transport", 2018.
312+
.. [15] Peyré, G., & Cuturi, M. (2019). Computational optimal transport: With applications to data science.
313+
Foundations and Trends® in Machine Learning, 11(5-6), 355-607.
318314
"""
319315
ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
320316
nx = get_backend(ms, mt, Cs, Ct)
@@ -455,7 +451,7 @@ def bures_barycenter_fixpoint(
455451
:ref:`[16] <references-OT-bures-barycenter-fixed-point>` by solving
456452
457453
.. math::
458-
\Sigma_b = \mathrm{argmin}_{\Sigma \in S_d^{+}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big).
454+
\Sigma_b = \mathrm{argmin}_{\Sigma \in S_d^{++}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big).
459455
460456
The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)`
461457
where :math:`\Sigma_b` is solution of the following fixed-point algorithm:
@@ -699,7 +695,7 @@ def bures_wasserstein_barycenter(
699695
.. math::
700696
\Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2}
701697
702-
We propose two solvers: one based on solving the previous fixed-point problem [1]. Another based on
698+
We propose two solvers: one based on solving the previous fixed-point problem [16]. Another based on
703699
gradient descent in the Bures-Wasserstein space [74,75].
704700
705701
Parameters
@@ -926,9 +922,8 @@ def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False):
926922
.. _references-gaussien_gromov_wasserstein_distance:
927923
References
928924
----------
929-
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
930-
distances between Gaussian distributions. Journal of Applied Probability,
931-
59(4), 1178-1198.
925+
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein distances between Gaussian distributions.
926+
Journal of Applied Probability, 59(4), 1178-1198.
932927
"""
933928

934929
nx = get_backend(Cov_s, Cov_t)
@@ -990,9 +985,9 @@ def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, wt=None, log
990985
.. _references-gaussien_gromov_wasserstein:
991986
References
992987
----------
993-
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
994-
distances between Gaussian distributions. Journal of Applied Probability,
995-
59(4), 1178-1198.
988+
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022).
989+
Gromov–Wasserstein distances between Gaussian distributions.
990+
Journal of Applied Probability, 59(4), 1178-1198.
996991
"""
997992
xs, xt = list_to_array(xs, xt)
998993
nx = get_backend(xs, xt)
@@ -1058,9 +1053,9 @@ def gaussian_gromov_wasserstein_mapping(
10581053
.. _references-gaussien_gromov_wasserstein_mapping:
10591054
References
10601055
----------
1061-
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
1062-
distances between Gaussian distributions. Journal of Applied Probability,
1063-
59(4), 1178-1198.
1056+
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022).
1057+
Gromov–Wasserstein distances between Gaussian distributions.
1058+
Journal of Applied Probability, 59(4), 1178-1198.
10641059
"""
10651060

10661061
nx = get_backend(mu_s, mu_t, Cov_s, Cov_t)
@@ -1149,12 +1144,13 @@ def empirical_gaussian_gromov_wasserstein_mapping(
11491144
b : (1, dt) array-like
11501145
bias
11511146
1147+
11521148
.. _references-empirical_gaussian_gromov_wasserstein_mapping:
11531149
References
11541150
----------
1155-
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
1156-
distances between Gaussian distributions. Journal of Applied Probability,
1157-
59(4), 1178-1198.
1151+
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022).
1152+
Gromov–Wasserstein distances between Gaussian distributions.
1153+
Journal of Applied Probability, 59(4), 1178-1198.
11581154
"""
11591155

11601156
xs, xt = list_to_array(xs, xt)

test/test_gaussian.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,13 @@ def test_bures_wasserstein_distance(nx):
9292
msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
9393
Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)
9494
Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False)
95+
Wb2 = ot.gaussian.bures_distance(Csb, Ctb, log=False)
9596

9697
np.testing.assert_allclose(
9798
nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2
9899
)
99100
np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
101+
np.testing.assert_allclose(0, Wb2, rtol=1e-2, atol=1e-2)
100102

101103

102104
def test_bures_wasserstein_distance_batch(nx):

0 commit comments

Comments
 (0)