Skip to content

Commit 42a501c

Browse files
committed
add test sinkhorn+log
1 parent 90d04e0 commit 42a501c

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

ot/bregman.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def sink():
120120
print('Warning : unknown method using classic Sinkhorn Knopp')
121121

122122
def sink():
123-
return sinkhorn_knopp(a, b, M, reg, **kwargs)
123+
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
124+
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
124125

125126
return sink()
126127

@@ -499,6 +500,15 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
499500
500501
"""
501502

503+
a = np.asarray(a, dtype=np.float64)
504+
b = np.asarray(b, dtype=np.float64)
505+
M = np.asarray(M, dtype=np.float64)
506+
507+
if len(a) == 0:
508+
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
509+
if len(b) == 0:
510+
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
511+
502512
n = a.shape[0]
503513
m = b.shape[0]
504514

@@ -514,7 +524,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
514524
viol = G.sum(1) - a
515525
viol_2 = G.sum(0) - b
516526
stopThr_val = 1
527+
517528
if log:
529+
log = dict()
518530
log['u'] = u
519531
log['v'] = v
520532

test/test_bregman.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,31 @@ def test_sinkhorn_variants():
8181
print(G0, G_green)
8282

8383

84+
def test_sinkhorn_variants_log():
85+
# test sinkhorn
86+
n = 100
87+
rng = np.random.RandomState(0)
88+
89+
x = rng.randn(n, 2)
90+
u = ot.utils.unif(n)
91+
92+
M = ot.dist(x, x)
93+
94+
G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
95+
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
96+
Ges, loges = ot.sinkhorn(
97+
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
98+
Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True)
99+
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
100+
101+
# check values
102+
np.testing.assert_allclose(G0, Gs, atol=1e-05)
103+
np.testing.assert_allclose(G0, Ges, atol=1e-05)
104+
np.testing.assert_allclose(G0, Gerr)
105+
np.testing.assert_allclose(G0, G_green, atol=1e-5)
106+
print(G0, G_green)
107+
108+
84109
def test_bary():
85110

86111
n_bins = 100 # nb bins

0 commit comments

Comments
 (0)