@@ -200,7 +200,7 @@ def empirical_bures_wasserstein_mapping(
200
200
return A , b
201
201
202
202
203
- def bures_distance (Cs , Ct , log = False , nx = None ):
203
+ def bures_distance (Cs , Ct , paired = False , log = False , nx = None ):
204
204
r"""Return Bures distance.
205
205
206
206
The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
@@ -215,6 +215,8 @@ def bures_distance(Cs, Ct, log=False, nx=None):
215
215
covariance of the source distribution
216
216
Ct : array-like (d,d) or (m,d,d)
217
217
covariance of the target distribution
218
+ paired: bool, optional
219
+ if True and n==m, return the paired distances and crossed distance otherwise
218
220
log : bool, optional
219
221
record log if True
220
222
nx : module, optional
@@ -223,7 +225,7 @@ def bures_distance(Cs, Ct, log=False, nx=None):
223
225
224
226
Returns
225
227
-------
226
- W : float if Cs and Cd of shape (d,d), array-like (n,) if Cs of shape (n,d,d), Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d), array-like (n,m ) if Cs of shape (n,d, d) and mt of shape (m,d,d)
228
+ W : float if Cs and Cd of shape (d,d), array-like (n,m ) if Cs of shape (n,d,d) and Ct of shape (m,d,d), array-like (n,) if Cs and Ct of shape (n, d, d) and paired is True
227
229
Bures Wasserstein distance
228
230
log : dict
229
231
log dictionary return only if log==True in parameters
@@ -247,18 +249,18 @@ def bures_distance(Cs, Ct, log=False, nx=None):
247
249
if len (Cs .shape ) == 2 and len (Ct .shape ) == 2 :
248
250
# Return float
249
251
bw2 = nx .trace (Cs + Ct - 2 * nx .sqrtm (dots (Cs12 , Ct , Cs12 )))
250
- elif len (Cs .shape ) == 2 :
251
- # Return shape (m,)
252
- M = nx .einsum ("ij, mjk, kl -> mil" , Cs12 , Ct , Cs12 )
253
- bw2 = nx .trace (Cs [None ] + Ct - 2 * nx .sqrtm (M ))
254
- elif len (Ct .shape ) == 2 :
255
- # Return shape (n,)
256
- M = nx .einsum ("nij, jk, nkl -> nil" , Cs12 , Ct , Cs12 )
257
- bw2 = nx .trace (Cs + Ct [None ] - 2 * nx .sqrtm (M ))
258
252
else :
259
- # Return shape (n,m)
260
- M = nx .einsum ("nij, mjk, nkl -> nmil" , Cs12 , Ct , Cs12 )
261
- bw2 = nx .trace (Cs [:, None ] + Ct [None ] - 2 * nx .sqrtm (M ))
253
+ assert (
254
+ len (Cs .shape ) == 3 and len (Ct .shape ) == 3
255
+ ), "Both Cs and Ct should be batched"
256
+ if paired and len (Cs ) == len (Ct ):
257
+ # Return shape (n,)
258
+ M = nx .einsum ("nij, njk, nkl -> nil" , Cs12 , Ct , Cs12 )
259
+ bw2 = nx .trace (Cs + Ct - 2 * nx .sqrtm (M ))
260
+ else :
261
+ # Return shape (n,m)
262
+ M = nx .einsum ("nij, mjk, nkl -> nmil" , Cs12 , Ct , Cs12 )
263
+ bw2 = nx .trace (Cs [:, None ] + Ct [None ] - 2 * nx .sqrtm (M ))
262
264
263
265
W = nx .sqrt (nx .maximum (bw2 , 0 ))
264
266
@@ -270,7 +272,7 @@ def bures_distance(Cs, Ct, log=False, nx=None):
270
272
return W
271
273
272
274
273
- def bures_wasserstein_distance (ms , mt , Cs , Ct , log = False ):
275
+ def bures_wasserstein_distance (ms , mt , Cs , Ct , paired = False , log = False ):
274
276
r"""Return Bures Wasserstein distance between samples.
275
277
276
278
The function computes the Bures-Wasserstein distance between :math:`\mu_s=\mathcal{N}(m_s,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(m_t,\Sigma_t)`,
@@ -294,12 +296,14 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
294
296
covariance of the source distribution
295
297
Ct : array-like (d,d) or (m,d,d)
296
298
covariance of the target distribution
299
+ paired: bool, optional
300
+ if True and n==m, return the paired distances and crossed distance otherwise
297
301
log : bool, optional
298
302
record log if True
299
303
300
304
Returns
301
305
-------
302
- W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d), mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d), array-like (n,m ) if ms of shape (n,d) and mt of shape (m,d)
306
+ W : float if ms and md of shape (d,), array-like (n,m ) if ms of shape (n,d) and mt of shape (m,d), array-like (n,) if ms and mt of shape (n,d) and paired is True
303
307
Bures Wasserstein distance
304
308
log : dict
305
309
log dictionary return only if log==True in parameters
@@ -328,23 +332,24 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
328
332
), "All Gaussian must have the same dimension"
329
333
330
334
if log :
331
- bw , log_dict = bures_distance (Cs , Ct , log = log , nx = nx )
335
+ bw , log_dict = bures_distance (Cs , Ct , paired = paired , log = log , nx = nx )
332
336
Cs12 = log_dict ["Cs12" ]
333
337
else :
334
- bw = bures_distance (Cs , Ct , nx = nx )
338
+ bw = bures_distance (Cs , Ct , paired = paired , nx = nx )
335
339
336
340
if len (ms .shape ) == 1 and len (mt .shape ) == 1 :
337
341
# Return float
338
342
squared_dist_m = nx .norm (ms - mt ) ** 2
339
- elif len (ms .shape ) == 1 :
340
- # Return shape (m,)
341
- squared_dist_m = nx .norm (ms [None ] - mt , axis = - 1 ) ** 2
342
- elif len (mt .shape ) == 1 :
343
- # Return shape (n,)
344
- squared_dist_m = nx .norm (ms - mt [None ], axis = - 1 ) ** 2
345
343
else :
346
- # Return shape (n,m)
347
- squared_dist_m = nx .norm (ms [:, None ] - mt [None ], axis = - 1 ) ** 2
344
+ assert (
345
+ len (ms .shape ) == 2 and len (mt .shape ) == 2
346
+ ), "Both ms and mt should be batched"
347
+ if paired and len (ms .shape ) == len (mt .shape ):
348
+ # Return shape (n,)
349
+ squared_dist_m = nx .norm (ms - mt , axis = - 1 ) ** 2
350
+ else :
351
+ # Return shape (n,m)
352
+ squared_dist_m = nx .norm (ms [:, None ] - mt [None ], axis = - 1 ) ** 2
348
353
349
354
W = nx .sqrt (nx .maximum (squared_dist_m + bw ** 2 , 0 ))
350
355
@@ -882,12 +887,14 @@ def empirical_bures_wasserstein_barycenter(
882
887
nx .dot ((X [i ] * w [i ]).T , X [i ]) / nx .sum (w [i ]) + reg * nx .eye (d [i ], type_as = X [i ])
883
888
for i in range (k )
884
889
]
885
- m = nx .stack (m , axis = 0 )
890
+ m = nx .stack (m , axis = 0 )[:, 0 ]
886
891
C = nx .stack (C , axis = 0 )
892
+
887
893
if log :
888
894
mb , Cb , log = bures_wasserstein_barycenter (
889
895
m , C , weights = weights , num_iter = num_iter , eps = eps , log = log
890
896
)
897
+
891
898
return mb , Cb , log
892
899
else :
893
900
mb , Cb = bures_wasserstein_barycenter (
0 commit comments