Skip to content

Commit 3a7effc

Browse files
committed
change API bw
1 parent 3a7af81 commit 3a7effc

File tree

3 files changed

+63
-51
lines changed

3 files changed

+63
-51
lines changed

ot/backend.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -1364,7 +1364,6 @@ def solve(self, a, b):
13641364

13651365
def trace(self, a):
13661366
return np.einsum("...ii", a)
1367-
# return np.trace(a)
13681367

13691368
def inv(self, a):
13701369
return scipy.linalg.inv(a)
@@ -1777,8 +1776,7 @@ def solve(self, a, b):
17771776
return jnp.linalg.solve(a, b)
17781777

17791778
def trace(self, a):
1780-
return jnp.einsum("...ii", a)
1781-
# return jnp.trace(a)
1779+
return jnp.diagonal(a, axis1=-2, axis2=-1).sum(-1)
17821780

17831781
def inv(self, a):
17841782
return jnp.linalg.inv(a)
@@ -2311,8 +2309,7 @@ def solve(self, a, b):
23112309
return torch.linalg.solve(a, b)
23122310

23132311
def trace(self, a):
2314-
return torch.einsum("...ii", a)
2315-
# return torch.trace(a)
2312+
return torch.diagonal(a, dim1=-2, dim2=-1).sum(-1)
23162313

23172314
def inv(self, a):
23182315
return torch.linalg.inv(a)
@@ -2726,8 +2723,7 @@ def solve(self, a, b):
27262723
return cp.linalg.solve(a, b)
27272724

27282725
def trace(self, a):
2729-
return cp.einsum("..ii", a)
2730-
# return cp.trace(a)
2726+
return cp.trace(a, axis1=-2, axis2=-1)
27312727

27322728
def inv(self, a):
27332729
return cp.linalg.inv(a)
@@ -3163,8 +3159,7 @@ def solve(self, a, b):
31633159
return tf.linalg.solve(a, b)
31643160

31653161
def trace(self, a):
3166-
return tf.einsum("...ii", a)
3167-
# return tf.linalg.trace(a)
3162+
return tf.linalg.trace(a)
31683163

31693164
def inv(self, a):
31703165
return tf.linalg.inv(a)

ot/gaussian.py

+33-26
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def empirical_bures_wasserstein_mapping(
200200
return A, b
201201

202202

203-
def bures_distance(Cs, Ct, log=False, nx=None):
203+
def bures_distance(Cs, Ct, paired=False, log=False, nx=None):
204204
r"""Return Bures distance.
205205
206206
The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
@@ -215,6 +215,8 @@ def bures_distance(Cs, Ct, log=False, nx=None):
215215
covariance of the source distribution
216216
Ct : array-like (d,d) or (m,d,d)
217217
covariance of the target distribution
218+
paired: bool, optional
219+
if True and n==m, return the paired distances of shape (n,); otherwise return the crossed distances of shape (n,m)
218220
log : bool, optional
219221
record log if True
220222
nx : module, optional
@@ -223,7 +225,7 @@ def bures_distance(Cs, Ct, log=False, nx=None):
223225
224226
Returns
225227
-------
226-
W : float if Cs and Cd of shape (d,d), array-like (n,) if Cs of shape (n,d,d), Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d), array-like (n,m) if Cs of shape (n,d,d) and mt of shape (m,d,d)
228+
W : float if Cs and Ct of shape (d,d), array-like (n,m) if Cs of shape (n,d,d) and Ct of shape (m,d,d), array-like (n,) if Cs and Ct of shape (n, d, d) and paired is True
227229
Bures Wasserstein distance
228230
log : dict
229231
log dictionary return only if log==True in parameters
@@ -247,18 +249,18 @@ def bures_distance(Cs, Ct, log=False, nx=None):
247249
if len(Cs.shape) == 2 and len(Ct.shape) == 2:
248250
# Return float
249251
bw2 = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
250-
elif len(Cs.shape) == 2:
251-
# Return shape (m,)
252-
M = nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12)
253-
bw2 = nx.trace(Cs[None] + Ct - 2 * nx.sqrtm(M))
254-
elif len(Ct.shape) == 2:
255-
# Return shape (n,)
256-
M = nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12)
257-
bw2 = nx.trace(Cs + Ct[None] - 2 * nx.sqrtm(M))
258252
else:
259-
# Return shape (n,m)
260-
M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12)
261-
bw2 = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M))
253+
assert (
254+
len(Cs.shape) == 3 and len(Ct.shape) == 3
255+
), "Both Cs and Ct should be batched"
256+
if paired and len(Cs) == len(Ct):
257+
# Return shape (n,)
258+
M = nx.einsum("nij, njk, nkl -> nil", Cs12, Ct, Cs12)
259+
bw2 = nx.trace(Cs + Ct - 2 * nx.sqrtm(M))
260+
else:
261+
# Return shape (n,m)
262+
M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12)
263+
bw2 = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M))
262264

263265
W = nx.sqrt(nx.maximum(bw2, 0))
264266

@@ -270,7 +272,7 @@ def bures_distance(Cs, Ct, log=False, nx=None):
270272
return W
271273

272274

273-
def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
275+
def bures_wasserstein_distance(ms, mt, Cs, Ct, paired=False, log=False):
274276
r"""Return Bures Wasserstein distance between samples.
275277
276278
The function computes the Bures-Wasserstein distance between :math:`\mu_s=\mathcal{N}(m_s,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(m_t,\Sigma_t)`,
@@ -294,12 +296,14 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
294296
covariance of the source distribution
295297
Ct : array-like (d,d) or (m,d,d)
296298
covariance of the target distribution
299+
paired: bool, optional
300+
if True and n==m, return the paired distances of shape (n,); otherwise return the crossed distances of shape (n,m)
297301
log : bool, optional
298302
record log if True
299303
300304
Returns
301305
-------
302-
W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d), mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d), array-like (n,m) if ms of shape (n,d) and mt of shape (m,d)
306+
W : float if ms and mt of shape (d,), array-like (n,m) if ms of shape (n,d) and mt of shape (m,d), array-like (n,) if ms and mt of shape (n,d) and paired is True
303307
Bures Wasserstein distance
304308
log : dict
305309
log dictionary return only if log==True in parameters
@@ -328,23 +332,24 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
328332
), "All Gaussian must have the same dimension"
329333

330334
if log:
331-
bw, log_dict = bures_distance(Cs, Ct, log=log, nx=nx)
335+
bw, log_dict = bures_distance(Cs, Ct, paired=paired, log=log, nx=nx)
332336
Cs12 = log_dict["Cs12"]
333337
else:
334-
bw = bures_distance(Cs, Ct, nx=nx)
338+
bw = bures_distance(Cs, Ct, paired=paired, nx=nx)
335339

336340
if len(ms.shape) == 1 and len(mt.shape) == 1:
337341
# Return float
338342
squared_dist_m = nx.norm(ms - mt) ** 2
339-
elif len(ms.shape) == 1:
340-
# Return shape (m,)
341-
squared_dist_m = nx.norm(ms[None] - mt, axis=-1) ** 2
342-
elif len(mt.shape) == 1:
343-
# Return shape (n,)
344-
squared_dist_m = nx.norm(ms - mt[None], axis=-1) ** 2
345343
else:
346-
# Return shape (n,m)
347-
squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2
344+
assert (
345+
len(ms.shape) == 2 and len(mt.shape) == 2
346+
), "Both ms and mt should be batched"
347+
if paired and len(ms.shape) == len(mt.shape):
348+
# Return shape (n,)
349+
squared_dist_m = nx.norm(ms - mt, axis=-1) ** 2
350+
else:
351+
# Return shape (n,m)
352+
squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2
348353

349354
W = nx.sqrt(nx.maximum(squared_dist_m + bw**2, 0))
350355

@@ -882,12 +887,14 @@ def empirical_bures_wasserstein_barycenter(
882887
nx.dot((X[i] * w[i]).T, X[i]) / nx.sum(w[i]) + reg * nx.eye(d[i], type_as=X[i])
883888
for i in range(k)
884889
]
885-
m = nx.stack(m, axis=0)
890+
m = nx.stack(m, axis=0)[:, 0]
886891
C = nx.stack(C, axis=0)
892+
887893
if log:
888894
mb, Cb, log = bures_wasserstein_barycenter(
889895
m, C, weights=weights, num_iter=num_iter, eps=eps, log=log
890896
)
897+
891898
return mb, Cb, log
892899
else:
893900
mb, Cb = bures_wasserstein_barycenter(

test/test_gaussian.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def test_bures_wasserstein_distance_batch(nx):
124124

125125
Wb = ot.gaussian.bures_wasserstein_distance(m[0, 0], m[1, 0], C[0], C[1], log=False)
126126

127+
# Test cross vs 1
127128
Wb2 = ot.gaussian.bures_wasserstein_distance(
128129
m[0, 0][None], m[1, 0][None], C[0][None], C[1][None]
129130
)
@@ -135,25 +136,34 @@ def test_bures_wasserstein_distance_batch(nx):
135136
np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 0]), atol=1e-5)
136137
np.testing.assert_equal(Wb2.shape, (2, 1))
137138

138-
Wb2 = ot.gaussian.bures_wasserstein_distance(
139-
m[0, 0][None], m[1, 0], C[0][None], C[1]
140-
)
141-
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5)
142-
np.testing.assert_equal(Wb2.shape, (1,))
143-
144-
Wb2 = ot.gaussian.bures_wasserstein_distance(
145-
m[0, 0], m[1, 0][None], C[0], C[1][None]
146-
)
147-
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5)
148-
np.testing.assert_equal(Wb2.shape, (1,))
149-
150139
Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[:, 0], C, C)
151140
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[1, 0]), atol=1e-5)
152141
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 1]), atol=1e-5)
153142
np.testing.assert_allclose(0, nx.to_numpy(Wb2[0, 0]), atol=1e-5)
154143
np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 1]), atol=1e-5)
155144
np.testing.assert_equal(Wb2.shape, (2, 2))
156145

146+
# Test paired
147+
Wb3 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[:, 0], C, C, paired=True)
148+
np.testing.assert_allclose(0, nx.to_numpy(Wb3[0]), atol=1e-5)
149+
np.testing.assert_allclose(0, nx.to_numpy(Wb3[1]), atol=1e-5)
150+
151+
m_rev = np.zeros((k, 2))
152+
C_rev = np.zeros((k, 2, 2))
153+
m_rev[0] = m[1, 0]
154+
m_rev[1] = m[0, 0]
155+
C_rev[0] = C[1]
156+
C_rev[1] = C[0]
157+
158+
Wb3 = ot.gaussian.bures_wasserstein_distance(m_rev, m[:, 0], C_rev, C, paired=True)
159+
np.testing.assert_allclose(nx.to_numpy(Wb2)[0, 1], nx.to_numpy(Wb3)[0], atol=1e-5)
160+
np.testing.assert_allclose(nx.to_numpy(Wb2)[0, 1], nx.to_numpy(Wb3)[0], atol=1e-5)
161+
162+
with pytest.raises(AssertionError):
163+
Wb3 = ot.gaussian.bures_wasserstein_distance(m[0, 0], m[:, 0], C[0], C)
164+
np.testing.assert_allclose(0, nx.to_numpy(Wb3[0]), atol=1e-5)
165+
np.testing.assert_allclose(0, nx.to_numpy(Wb3[1]), atol=1e-5)
166+
157167

158168
@pytest.mark.parametrize("bias", [True, False])
159169
def test_empirical_bures_wasserstein_distance(nx, bias):
@@ -205,7 +215,7 @@ def test_bures_wasserstein_barycenter(nx, method):
205215
m = np.array(m)
206216
C = np.array(C)
207217
X = nx.from_numpy(*X)
208-
m = nx.from_numpy(m)
218+
m = nx.from_numpy(m)[:, 0]
209219
C = nx.from_numpy(C)
210220

211221
mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(
@@ -256,7 +266,7 @@ def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx):
256266
m = np.array(m)
257267
C = np.array(C)
258268
X = nx.from_numpy(*X)
259-
m = nx.from_numpy(m)
269+
m = nx.from_numpy(m)[:, 0]
260270
C = nx.from_numpy(C)
261271

262272
mb, Cb = ot.gaussian.bures_wasserstein_barycenter(
@@ -297,7 +307,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx, method):
297307
m = np.array(m)
298308
C = np.array(C)
299309
X = nx.from_numpy(*X)
300-
m = nx.from_numpy(m)
310+
m = nx.from_numpy(m)[:, 0]
301311
C = nx.from_numpy(C)
302312

303313
mb, Cb = ot.gaussian.bures_wasserstein_barycenter(
@@ -346,7 +356,7 @@ def test_not_implemented_method(nx):
346356
m = np.array(m)
347357
C = np.array(C)
348358
X = nx.from_numpy(*X)
349-
m = nx.from_numpy(m)
359+
m = nx.from_numpy(m)[:, 0]
350360
C = nx.from_numpy(C)
351361

352362
not_implemented = "new_method"

0 commit comments

Comments
 (0)