@@ -200,14 +200,77 @@ def empirical_bures_wasserstein_mapping(
200
200
return A , b
201
201
202
202
203
+ def bures_distance (Cs , Ct , log = False ):
204
+ r"""Return Bures distance.
205
+
206
+ The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
207
+ given by:
208
+
209
+ .. math::
210
+ \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
211
+
212
+ Parameters
213
+ ----------
214
+ Cs : array-like (d,d) or (n,d,d)
215
+ covariance of the source distribution
216
+ Ct : array-like (d,d) or (m,d,d)
217
+ covariance of the target distribution
218
+ log : bool, optional
219
+ record log if True
220
+
221
+
222
+ Returns
223
+ -------
224
+ W : float if Cs and Cd of shape (d,d), array-like (n,) if Cs of shape (n,d,d),
225
+ Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d),
226
+ array-like (n,m) if Cs of shape (n,d,d) and mt of shape (m,d,d)
227
+ Bures Wasserstein distance
228
+ log : dict
229
+ log dictionary return only if log==True in parameters
230
+
231
+ .. _references-bures-wasserstein-distance:
232
+ References
233
+ ----------
234
+
235
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
236
+ Transport", 2018.
237
+ """
238
+ Cs , Ct = list_to_array (Cs , Ct )
239
+ nx = get_backend (Cs , Ct )
240
+
241
+ Cs12 = nx .sqrtm (Cs )
242
+
243
+ if len (Cs .shape ) == 2 and len (Ct .shape ) == 2 :
244
+ # Return float
245
+ bw2 = nx .trace (Cs + Ct - 2 * nx .sqrtm (dots (Cs12 , Ct , Cs12 )))
246
+ elif len (Cs .shape ) == 2 :
247
+ # Return shape (m,)
248
+ M = nx .einsum ("ij, mjk, kl -> mil" , Cs12 , Ct , Cs12 )
249
+ bw2 = nx .trace (Cs [None ] + Ct - 2 * nx .sqrtm (M ))
250
+ elif len (Ct .shape ) == 2 :
251
+ # Return shape (n,)
252
+ M = nx .einsum ("nij, jk, nkl -> nil" , Cs12 , Ct , Cs12 )
253
+ bw2 = nx .trace (Cs + Ct [None ] - 2 * nx .sqrtm (M ))
254
+ else :
255
+ # Return shape (n,m)
256
+ M = nx .einsum ("nij, mjk, nkl -> nmil" , Cs12 , Ct , Cs12 )
257
+ bw2 = nx .trace (Cs [:, None ] + Ct [None ] - 2 * nx .sqrtm (M ))
258
+
259
+ W = nx .sqrt (nx .maximum (bw2 , 0 ))
260
+
261
+ if log :
262
+ log = {}
263
+ log ["Cs12" ] = Cs12
264
+ return W , log
265
+ else :
266
+ return W
267
+
268
+
203
269
def bures_wasserstein_distance (ms , mt , Cs , Ct , log = False ):
204
270
r"""Return Bures Wasserstein distance between samples.
205
271
206
- The function estimates the Bures-Wasserstein distance between two
207
- empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
208
- discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
209
-
210
- The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}_2`
272
+ The function computes the Bures-Wasserstein distance between :math:`\mu_s=\mathcal{N}(m_s,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(m_t,\Sigma_t)`,
273
+ as discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
211
274
212
275
.. math::
213
276
\mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
@@ -230,7 +293,6 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
230
293
log : bool, optional
231
294
record log if True
232
295
233
-
234
296
Returns
235
297
-------
236
298
W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d),
@@ -251,29 +313,38 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
251
313
ms , mt , Cs , Ct = list_to_array (ms , mt , Cs , Ct )
252
314
nx = get_backend (ms , mt , Cs , Ct )
253
315
254
- Cs12 = nx .sqrtm (Cs )
316
+ assert (
317
+ ms .shape [0 ] == Cs .shape [0 ]
318
+ ), "Source Gaussians has different amount of components"
319
+
320
+ assert (
321
+ mt .shape [0 ] == Ct .shape [0 ]
322
+ ), "Target Gaussians has different amount of components"
323
+
324
+ assert (
325
+ ms .shape [- 1 ] == mt .shape [- 1 ] == Cs .shape [- 1 ] == Ct .shape [- 1 ]
326
+ ), "All Gaussian must have the same dimension"
327
+
328
+ if log :
329
+ bw , log_dict = bures_distance (Cs , Ct , log )
330
+ Cs12 = log_dict ["Cs12" ]
331
+ else :
332
+ bw = bures_distance (Cs , Ct )
255
333
256
334
if len (ms .shape ) == 1 and len (mt .shape ) == 1 :
257
335
# Return float
258
336
squared_dist_m = nx .norm (ms - mt ) ** 2
259
- B = nx .trace (Cs + Ct - 2 * nx .sqrtm (dots (Cs12 , Ct , Cs12 )))
260
337
elif len (ms .shape ) == 1 :
261
338
# Return shape (m,)
262
- M = nx .einsum ("ij, mjk, kl -> mil" , Cs12 , Ct , Cs12 )
263
- B = nx .trace (Cs [None ] + Ct - 2 * nx .sqrtm (M ))
264
339
squared_dist_m = nx .norm (ms [None ] - mt , axis = - 1 ) ** 2
265
340
elif len (mt .shape ) == 1 :
266
341
# Return shape (n,)
267
- M = nx .einsum ("nij, jk, nkl -> nil" , Cs12 , Ct , Cs12 )
268
- B = nx .trace (Cs + Ct [None ] - 2 * nx .sqrtm (M ))
269
342
squared_dist_m = nx .norm (ms - mt [None ], axis = - 1 ) ** 2
270
343
else :
271
344
# Return shape (n,m)
272
- M = nx .einsum ("nij, mjk, nkl -> nmil" , Cs12 , Ct , Cs12 )
273
- B = nx .trace (Cs [:, None ] + Ct [None ] - 2 * nx .sqrtm (M ))
274
345
squared_dist_m = nx .norm (ms [:, None ] - mt [None ], axis = - 1 ) ** 2
275
346
276
- W = nx .sqrt (nx .maximum (squared_dist_m + B , 0 ))
347
+ W = nx .sqrt (nx .maximum (squared_dist_m + bw ** 2 , 0 ))
277
348
278
349
if log :
279
350
log = {}
0 commit comments