Skip to content

Commit 2d4d0b4

Browse files
committed
solving log issues to avoid errors and adding further tests
1 parent 0930223 commit 2d4d0b4

File tree

2 files changed

+75
-21
lines changed

2 files changed

+75
-21
lines changed

ot/da.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,7 +1315,10 @@ class SinkhornTransport(BaseTransport):
13151315
13161316
Attributes
13171317
----------
1318-
coupling_ : the optimal coupling
1318+
coupling_ : array-like, shape (n_source_samples, n_target_samples)
1319+
The optimal coupling
1320+
log_ : dictionary
1321+
The dictionary of log, empty dic if parameter log is not True
13191322
13201323
References
13211324
----------
@@ -1367,11 +1370,18 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
13671370
super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
13681371

13691372
# coupling estimation
1370-
self.coupling_ = sinkhorn(
1373+
returned_ = sinkhorn(
13711374
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
13721375
numItermax=self.max_iter, stopThr=self.tol,
13731376
verbose=self.verbose, log=self.log)
13741377

1378+
# deal with the value of log
1379+
if self.log:
1380+
self.coupling_, self.log_ = returned_
1381+
else:
1382+
self.coupling_ = returned_
1383+
self.log_ = dict()
1384+
13751385
return self
13761386

13771387

@@ -1400,7 +1410,8 @@ class EMDTransport(BaseTransport):
14001410
14011411
Attributes
14021412
----------
1403-
coupling_ : the optimal coupling
1413+
coupling_ : array-like, shape (n_source_samples, n_target_samples)
1414+
The optimal coupling
14041415
14051416
References
14061417
----------
@@ -1475,15 +1486,14 @@ class SinkhornLpl1Transport(BaseTransport):
14751486
The number of iteration in the inner loop
14761487
verbose : int, optional (default=0)
14771488
Controls the verbosity of the optimization algorithm
1478-
log : int, optional (default=0)
1479-
Controls the logs of the optimization algorithm
14801489
limit_max: float, optional (defaul=np.infty)
14811490
Controls the semi supervised mode. Transport between labeled source
14821491
and target samples of different classes will exhibit an infinite cost
14831492
14841493
Attributes
14851494
----------
1486-
coupling_ : the optimal coupling
1495+
coupling_ : array-like, shape (n_source_samples, n_target_samples)
1496+
The optimal coupling
14871497
14881498
References
14891499
----------
@@ -1500,7 +1510,7 @@ class SinkhornLpl1Transport(BaseTransport):
15001510

15011511
def __init__(self, reg_e=1., reg_cl=0.1,
15021512
max_iter=10, max_inner_iter=200,
1503-
tol=10e-9, verbose=False, log=False,
1513+
tol=10e-9, verbose=False,
15041514
metric="sqeuclidean",
15051515
distribution_estimation=distribution_estimation_uniform,
15061516
out_of_sample_map='ferradans', limit_max=np.infty):
@@ -1511,7 +1521,6 @@ def __init__(self, reg_e=1., reg_cl=0.1,
15111521
self.max_inner_iter = max_inner_iter
15121522
self.tol = tol
15131523
self.verbose = verbose
1514-
self.log = log
15151524
self.metric = metric
15161525
self.distribution_estimation = distribution_estimation
15171526
self.out_of_sample_map = out_of_sample_map
@@ -1544,7 +1553,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
15441553
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
15451554
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
15461555
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
1547-
verbose=self.verbose, log=self.log)
1556+
verbose=self.verbose)
15481557

15491558
return self
15501559

@@ -1584,7 +1593,10 @@ class SinkhornL1l2Transport(BaseTransport):
15841593
15851594
Attributes
15861595
----------
1587-
coupling_ : the optimal coupling
1596+
coupling_ : array-like, shape (n_source_samples, n_target_samples)
1597+
The optimal coupling
1598+
log_ : dictionary
1599+
The dictionary of log, empty dic if parameter log is not True
15881600
15891601
References
15901602
----------
@@ -1641,12 +1653,19 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
16411653

16421654
super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)
16431655

1644-
self.coupling_ = sinkhorn_l1l2_gl(
1656+
returned_ = sinkhorn_l1l2_gl(
16451657
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
16461658
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
16471659
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
16481660
verbose=self.verbose, log=self.log)
16491661

1662+
# deal with the value of log
1663+
if self.log:
1664+
self.coupling_, self.log_ = returned_
1665+
else:
1666+
self.coupling_ = returned_
1667+
self.log_ = dict()
1668+
16501669
return self
16511670

16521671

@@ -1683,14 +1702,15 @@ class MappingTransport(BaseEstimator):
16831702
16841703
Attributes
16851704
----------
1686-
coupling_ : array-like, shape (n_source_samples, n_features)
1705+
coupling_ : array-like, shape (n_source_samples, n_target_samples)
16871706
The optimal coupling
16881707
mapping_ : array-like, shape (n_features (+ 1), n_features)
16891708
(if bias) for kernel == linear
16901709
The associated mapping
1691-
16921710
array-like, shape (n_source_samples (+ 1), n_features)
16931711
(if bias) for kernel == gaussian
1712+
log_ : dictionary
1713+
The dictionary of log, empty dic if parameter log is not True
16941714
16951715
References
16961716
----------
@@ -1745,19 +1765,26 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
17451765
self.Xt = Xt
17461766

17471767
if self.kernel == "linear":
1748-
self.coupling_, self.mapping_ = joint_OT_mapping_linear(
1768+
returned_ = joint_OT_mapping_linear(
17491769
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
17501770
verbose=self.verbose, verbose2=self.verbose2,
17511771
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
17521772
stopThr=self.tol, stopInnerThr=self.inner_tol, log=self.log)
17531773

17541774
elif self.kernel == "gaussian":
1755-
self.coupling_, self.mapping_ = joint_OT_mapping_kernel(
1775+
returned_ = joint_OT_mapping_kernel(
17561776
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
17571777
sigma=self.sigma, verbose=self.verbose, verbose2=self.verbose,
17581778
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
17591779
stopInnerThr=self.inner_tol, stopThr=self.tol, log=self.log)
17601780

1781+
# deal with the value of log
1782+
if self.log:
1783+
self.coupling_, self.mapping_, self.log_ = returned_
1784+
else:
1785+
self.coupling_, self.mapping_ = returned_
1786+
self.log_ = dict()
1787+
17611788
return self
17621789

17631790
def transform(self, Xs):

test/test_da.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def test_sinkhorn_lpl1_transport_class():
2626

2727
# test its computed
2828
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
29+
assert hasattr(clf, "cost_")
30+
assert hasattr(clf, "coupling_")
2931

3032
# test dimensions of coupling
3133
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -89,6 +91,9 @@ def test_sinkhorn_l1l2_transport_class():
8991

9092
# test its computed
9193
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
94+
assert hasattr(clf, "cost_")
95+
assert hasattr(clf, "coupling_")
96+
assert hasattr(clf, "log_")
9297

9398
# test dimensions of coupling
9499
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -137,6 +142,11 @@ def test_sinkhorn_l1l2_transport_class():
137142

138143
assert n_unsup != n_semisup, "semisupervised mode not working"
139144

145+
# check everything runs well with log=True
146+
clf = ot.da.SinkhornL1l2Transport(log=True)
147+
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
148+
assert len(clf.log_.keys()) != 0
149+
140150

141151
def test_sinkhorn_transport_class():
142152
"""test_sinkhorn_transport
@@ -152,6 +162,9 @@ def test_sinkhorn_transport_class():
152162

153163
# test its computed
154164
clf.fit(Xs=Xs, Xt=Xt)
165+
assert hasattr(clf, "cost_")
166+
assert hasattr(clf, "coupling_")
167+
assert hasattr(clf, "log_")
155168

156169
# test dimensions of coupling
157170
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -200,6 +213,11 @@ def test_sinkhorn_transport_class():
200213

201214
assert n_unsup != n_semisup, "semisupervised mode not working"
202215

216+
# check everything runs well with log=True
217+
clf = ot.da.SinkhornTransport(log=True)
218+
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
219+
assert len(clf.log_.keys()) != 0
220+
203221

204222
def test_emd_transport_class():
205223
"""test_sinkhorn_transport
@@ -215,6 +233,8 @@ def test_emd_transport_class():
215233

216234
# test its computed
217235
clf.fit(Xs=Xs, Xt=Xt)
236+
assert hasattr(clf, "cost_")
237+
assert hasattr(clf, "coupling_")
218238

219239
# test dimensions of coupling
220240
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -282,6 +302,9 @@ def test_mapping_transport_class():
282302
# check computation and dimensions if bias == False
283303
clf = ot.da.MappingTransport(kernel="linear", bias=False)
284304
clf.fit(Xs=Xs, Xt=Xt)
305+
assert hasattr(clf, "coupling_")
306+
assert hasattr(clf, "mapping_")
307+
assert hasattr(clf, "log_")
285308

286309
assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
287310
assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
@@ -369,6 +392,11 @@ def test_mapping_transport_class():
369392
# check that the oos method is working
370393
assert_equal(transp_Xs_new.shape, Xs_new.shape)
371394

395+
# check everything runs well with log=True
396+
clf = ot.da.MappingTransport(kernel="gaussian", log=True)
397+
clf.fit(Xs=Xs, Xt=Xt)
398+
assert len(clf.log_.keys()) != 0
399+
372400

373401
def test_otda():
374402

@@ -434,9 +462,8 @@ def test_otda():
434462

435463
# if __name__ == "__main__":
436464

437-
# test_otda()
438-
# test_sinkhorn_transport_class()
439-
# test_emd_transport_class()
440-
# test_sinkhorn_l1l2_transport_class()
441-
# test_sinkhorn_lpl1_transport_class()
442-
# test_mapping_transport_class()
465+
# test_sinkhorn_transport_class()
466+
# test_emd_transport_class()
467+
# test_sinkhorn_l1l2_transport_class()
468+
# test_sinkhorn_lpl1_transport_class()
469+
# test_mapping_transport_class()

0 commit comments

Comments
 (0)