Skip to content

Commit 0b20759

Browse files
committed
Add ot.gaussian.bures
1 parent bad385f commit 0b20759

File tree

1 file changed

+86
-15
lines changed

1 file changed

+86
-15
lines changed

Diff for: ot/gaussian.py

+86-15
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,77 @@ def empirical_bures_wasserstein_mapping(
200200
return A, b
201201

202202

203+
def bures_distance(Cs, Ct, log=False):
    r"""Return the Bures distance between two centered Gaussians.

    The function computes the Bures distance between
    :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
    given by:

    .. math::
        \mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)

    Parameters
    ----------
    Cs : array-like (d,d) or (n,d,d)
        covariance of the source distribution
    Ct : array-like (d,d) or (m,d,d)
        covariance of the target distribution
    log : bool, optional
        record log if True

    Returns
    -------
    W : float if Cs and Ct of shape (d,d), array-like (n,) if Cs of shape (n,d,d)
        and Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of
        shape (m,d,d), array-like (n,m) if Cs of shape (n,d,d) and Ct of shape (m,d,d)
        Bures distance
    log : dict
        log dictionary returned only if log==True in parameters; it holds the
        matrix square root of `Cs` under the key ``"Cs12"``

    .. _references-bures-wasserstein-distance:
    References
    ----------

    .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
        Transport", 2018.
    """
    Cs, Ct = list_to_array(Cs, Ct)
    nx = get_backend(Cs, Ct)

    # Square root of the source covariance(s); reused in every branch and
    # exposed to the caller through the log dictionary.
    Cs12 = nx.sqrtm(Cs)

    single_source = len(Cs.shape) == 2
    single_target = len(Ct.shape) == 2

    # NOTE(review): the three batched branches rely on nx.trace reducing the
    # last two axes of a stacked input -- confirm against the backend.
    if single_source and single_target:
        # One source, one target: scalar result.
        cross = nx.sqrtm(dots(Cs12, Ct, Cs12))
        bw2 = nx.trace(Cs + Ct - 2 * cross)
    elif single_source:
        # One source against m targets: result of shape (m,).
        cross = nx.sqrtm(nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12))
        bw2 = nx.trace(Cs[None] + Ct - 2 * cross)
    elif single_target:
        # n sources against one target: result of shape (n,).
        cross = nx.sqrtm(nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12))
        bw2 = nx.trace(Cs + Ct[None] - 2 * cross)
    else:
        # n sources against m targets: result of shape (n,m).
        cross = nx.sqrtm(nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12))
        bw2 = nx.trace(Cs[:, None] + Ct[None] - 2 * cross)

    # Clamp at zero before the square root so a slightly negative squared
    # distance does not produce NaN.
    W = nx.sqrt(nx.maximum(bw2, 0))

    if not log:
        return W
    return W, {"Cs12": Cs12}
267+
268+
203269
def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
204270
r"""Return Bures Wasserstein distance between samples.
205271
206-
The function estimates the Bures-Wasserstein distance between two
207-
empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
208-
discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
209-
210-
The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}_2`
272+
The function computes the Bures-Wasserstein distance between :math:`\mu_s=\mathcal{N}(m_s,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(m_t,\Sigma_t)`,
273+
as discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
211274
212275
.. math::
213276
\mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
@@ -230,7 +293,6 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
230293
log : bool, optional
231294
record log if True
232295
233-
234296
Returns
235297
-------
236298
W : float if ms and mt of shape (d,), array-like (n,) if ms of shape (n,d),
@@ -251,29 +313,38 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
251313
ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
252314
nx = get_backend(ms, mt, Cs, Ct)
253315

254-
Cs12 = nx.sqrtm(Cs)
316+
assert (
317+
ms.shape[0] == Cs.shape[0]
318+
), "Source Gaussians has different amount of components"
319+
320+
assert (
321+
mt.shape[0] == Ct.shape[0]
322+
), "Target Gaussians has different amount of components"
323+
324+
assert (
325+
ms.shape[-1] == mt.shape[-1] == Cs.shape[-1] == Ct.shape[-1]
326+
), "All Gaussian must have the same dimension"
327+
328+
if log:
329+
bw, log_dict = bures_distance(Cs, Ct, log)
330+
Cs12 = log_dict["Cs12"]
331+
else:
332+
bw = bures_distance(Cs, Ct)
255333

256334
if len(ms.shape) == 1 and len(mt.shape) == 1:
257335
# Return float
258336
squared_dist_m = nx.norm(ms - mt) ** 2
259-
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
260337
elif len(ms.shape) == 1:
261338
# Return shape (m,)
262-
M = nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12)
263-
B = nx.trace(Cs[None] + Ct - 2 * nx.sqrtm(M))
264339
squared_dist_m = nx.norm(ms[None] - mt, axis=-1) ** 2
265340
elif len(mt.shape) == 1:
266341
# Return shape (n,)
267-
M = nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12)
268-
B = nx.trace(Cs + Ct[None] - 2 * nx.sqrtm(M))
269342
squared_dist_m = nx.norm(ms - mt[None], axis=-1) ** 2
270343
else:
271344
# Return shape (n,m)
272-
M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12)
273-
B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M))
274345
squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2
275346

276-
W = nx.sqrt(nx.maximum(squared_dist_m + B, 0))
347+
W = nx.sqrt(nx.maximum(squared_dist_m + bw**2, 0))
277348

278349
if log:
279350
log = {}

0 commit comments

Comments
 (0)