Skip to content

Commit c64906b

Browse files
authored
Merge pull request #97 from hichamjanati/fix_mismatch_error_94
[MRG] Fix mismatch error in stabilized sinkhorn
2 parents 0063cb8 + a507556 commit c64906b

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

ot/bregman.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,10 +765,14 @@ def get_Gamma(alpha, beta, u, v):
765765

766766
cpt = cpt + 1
767767

768-
# print('err=',err,' cpt=',cpt)
769768
if log:
770-
log['logu'] = alpha / reg + np.log(u)
771-
log['logv'] = beta / reg + np.log(v)
769+
if nbb:
770+
alpha = alpha[:, None]
771+
beta = beta[:, None]
772+
logu = alpha / reg + np.log(u)
773+
logv = beta / reg + np.log(v)
774+
log['logu'] = logu
775+
log['logv'] = logv
772776
log['alpha'] = alpha + reg * np.log(u)
773777
log['beta'] = beta + reg * np.log(v)
774778
log['warmstart'] = (log['alpha'], log['beta'])

test/test_bregman.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,28 @@ def test_empirical_sinkhorn_divergence():
254254
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
255255
np.testing.assert_allclose(
256256
emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
257+
258+
259+
def test_stabilized_vs_sinkhorn_multidim():
260+
# test if stable version matches sinkhorn
261+
# for multidimensional inputs
262+
n = 100
263+
264+
# Gaussian distributions
265+
a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
266+
b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
267+
b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
268+
269+
# creating matrix A containing all distributions
270+
b = np.vstack((b1, b2)).T
271+
272+
M = ot.utils.dist0(n)
273+
M /= np.median(M)
274+
epsilon = 0.1
275+
G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
276+
method="sinkhorn_stabilized",
277+
log=True)
278+
G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
279+
method="sinkhorn", log=True)
280+
281+
np.testing.assert_allclose(G, G2)

0 commit comments

Comments
 (0)