Skip to content

Commit 8490196

Browse files
authored
[MRG] Fix bug in regularized OTDA l1lp with log (#413)
* correct bug in DA l1lp with log * better tests and speedup with smaller dataset size * remove jax for log test * remove trndorflow for log test * pep8!
1 parent ac830dd commit 8490196

File tree

3 files changed

+53
-25
lines changed

3 files changed

+53
-25
lines changed

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ roughly 2^31) (PR #381)
2727
- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
2828
- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
2929
- Fixed weak optimal transport docstring (Issue #404, PR #410)
30-
30+
- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
31+
PR #413)
3132

3233
## 0.8.2
3334

ot/da.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
126126
W = nx.zeros(M.shape, type_as=M)
127127
for cpt in range(numItermax):
128128
Mreg = M + eta * W
129-
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
130-
stopThr=stopInnerThr)
129+
if log:
130+
transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
131+
stopThr=stopInnerThr, log=True)
132+
else:
133+
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
134+
stopThr=stopInnerThr)
131135
# the transport has been computed. Check if classes are really
132136
# separated
133137
W = nx.ones(M.shape, type_as=M)
@@ -136,7 +140,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
136140
majs = p * ((majs + epsilon) ** (p - 1))
137141
W[indices_labels[i]] = majs
138142

139-
return transp
143+
if log:
144+
return transp, log
145+
else:
146+
return transp
140147

141148

142149
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,

test/test_da.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,34 @@ def test_class_jax_tf():
4242
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
4343

4444

45+
@pytest.skip_backend("jax")
46+
@pytest.skip_backend("tf")
47+
@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
48+
def test_log_da(nx, class_to_test):
49+
50+
ns = 50
51+
nt = 50
52+
53+
Xs, ys = make_data_classif('3gauss', ns)
54+
Xt, yt = make_data_classif('3gauss2', nt)
55+
56+
Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
57+
58+
otda = class_to_test(log=True)
59+
60+
# test its computed
61+
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
62+
assert hasattr(otda, "log_")
63+
64+
4565
@pytest.skip_backend("jax")
4666
@pytest.skip_backend("tf")
4767
def test_sinkhorn_lpl1_transport_class(nx):
4868
"""test_sinkhorn_transport
4969
"""
5070

51-
ns = 150
52-
nt = 200
71+
ns = 50
72+
nt = 50
5373

5474
Xs, ys = make_data_classif('3gauss', ns)
5575
Xt, yt = make_data_classif('3gauss2', nt)
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
136156
"""
137157

138158
ns = 50
139-
nt = 100
159+
nt = 50
140160

141161
Xs, ys = make_data_classif('3gauss', ns)
142162
Xt, yt = make_data_classif('3gauss2', nt)
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
230250
"""test_sinkhorn_transport
231251
"""
232252

233-
ns = 150
234-
nt = 200
253+
ns = 50
254+
nt = 50
235255

236256
Xs, ys = make_data_classif('3gauss', ns)
237257
Xt, yt = make_data_classif('3gauss2', nt)
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
323343
"""test_sinkhorn_transport
324344
"""
325345

326-
ns = 150
327-
nt = 200
346+
ns = 50
347+
nt = 50
328348

329349
Xs, ys = make_data_classif('3gauss', ns)
330350
Xt, yt = make_data_classif('3gauss2', nt)
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
402422
"""test_sinkhorn_transport
403423
"""
404424

405-
ns = 150
406-
nt = 200
425+
ns = 50
426+
nt = 50
407427

408428
Xs, ys = make_data_classif('3gauss', ns)
409429
Xt, yt = make_data_classif('3gauss2', nt)
@@ -558,8 +578,8 @@ def test_mapping_transport_class_specific_seed(nx):
558578
@pytest.skip_backend("jax")
559579
@pytest.skip_backend("tf")
560580
def test_linear_mapping(nx):
561-
ns = 150
562-
nt = 200
581+
ns = 50
582+
nt = 50
563583

564584
Xs, ys = make_data_classif('3gauss', ns)
565585
Xt, yt = make_data_classif('3gauss2', nt)
@@ -579,8 +599,8 @@ def test_linear_mapping(nx):
579599
@pytest.skip_backend("jax")
580600
@pytest.skip_backend("tf")
581601
def test_linear_mapping_class(nx):
582-
ns = 150
583-
nt = 200
602+
ns = 50
603+
nt = 50
584604

585605
Xs, ys = make_data_classif('3gauss', ns)
586606
Xt, yt = make_data_classif('3gauss2', nt)
@@ -609,9 +629,9 @@ def test_jcpot_transport_class(nx):
609629
"""test_jcpot_transport
610630
"""
611631

612-
ns1 = 150
613-
ns2 = 150
614-
nt = 200
632+
ns1 = 50
633+
ns2 = 50
634+
nt = 50
615635

616636
Xs1, ys1 = make_data_classif('3gauss', ns1)
617637
Xs2, ys2 = make_data_classif('3gauss', ns2)
@@ -681,9 +701,9 @@ def test_jcpot_barycenter(nx):
681701
"""test_jcpot_barycenter
682702
"""
683703

684-
ns1 = 150
685-
ns2 = 150
686-
nt = 200
704+
ns1 = 50
705+
ns2 = 50
706+
nt = 50
687707

688708
sigma = 0.1
689709
np.random.seed(1985)
@@ -713,8 +733,8 @@ def test_jcpot_barycenter(nx):
713733
def test_emd_laplace_class(nx):
714734
"""test_emd_laplace_transport
715735
"""
716-
ns = 150
717-
nt = 200
736+
ns = 50
737+
nt = 50
718738

719739
Xs, ys = make_data_classif('3gauss', ns)
720740
Xt, yt = make_data_classif('3gauss2', nt)

0 commit comments

Comments
 (0)