Skip to content

Commit 9d4b786

Browse files
committed
fixes for travis, added test, minor nits
1 parent 0928668 commit 9d4b786

File tree

4 files changed

+80
-4
lines changed

4 files changed

+80
-4
lines changed

.travis.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@ matrix:
1313
python: 3.5
1414
- os: linux
1515
sudo: required
16-
python: 3.6
16+
python: 3.6
1717
- os: linux
1818
sudo: required
1919
python: 2.7
2020
before_install:
2121
- ./.travis/before_install.sh
2222
before_script: # configure a headless display to test plot generation
2323
- "export DISPLAY=:99.0"
24-
- "sh -e /etc/init.d/xvfb start"
2524
- sleep 3 # give xvfb some time to start
2625
# command to install dependencies
2726
install:
@@ -30,6 +29,8 @@ install:
3029
- pip install flake8 pytest "pytest-cov<2.6"
3130
- pip install .
3231
# command to run tests + check syntax style
32+
services:
33+
- xvfb
3334
script:
3435
- python setup.py develop
3536
- flake8 examples/ ot/ test/

ot/da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1852,7 +1852,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
18521852
"""
18531853

18541854
def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
1855-
max_iter=10, tol=10e-9, verbose=False, log=False,
1855+
max_iter=10, tol=1e-9, verbose=False, log=False,
18561856
metric="sqeuclidean", norm=None,
18571857
distribution_estimation=distribution_estimation_uniform,
18581858
out_of_sample_map='ferradans', limit_max=10):

ot/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ def cost_normalization(C, norm=None):
178178
The input cost matrix normalized according to given norm.
179179
"""
180180

181-
if norm == "median":
181+
if norm is None:
182+
pass
183+
elif norm == "median":
182184
C /= float(np.median(C))
183185
elif norm == "max":
184186
C /= float(np.max(C))

test/test_da.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,79 @@ def test_sinkhorn_transport_class():
245245
assert len(otda.log_.keys()) != 0
246246

247247

248+
def test_unbalanced_sinkhorn_transport_class():
249+
"""test_sinkhorn_transport
250+
"""
251+
252+
ns = 150
253+
nt = 200
254+
255+
Xs, ys = make_data_classif('3gauss', ns)
256+
Xt, yt = make_data_classif('3gauss2', nt)
257+
258+
otda = ot.da.UnbalancedSinkhornTransport()
259+
260+
# test its computed
261+
otda.fit(Xs=Xs, Xt=Xt)
262+
assert hasattr(otda, "cost_")
263+
assert hasattr(otda, "coupling_")
264+
assert hasattr(otda, "log_")
265+
266+
# test dimensions of coupling
267+
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
268+
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
269+
270+
# test margin constraints
271+
mu_s = unif(ns)
272+
mu_t = unif(nt)
273+
assert_allclose(
274+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
275+
assert_allclose(
276+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
277+
278+
# test transform
279+
transp_Xs = otda.transform(Xs=Xs)
280+
assert_equal(transp_Xs.shape, Xs.shape)
281+
282+
Xs_new, _ = make_data_classif('3gauss', ns + 1)
283+
transp_Xs_new = otda.transform(Xs_new)
284+
285+
# check that the oos method is working
286+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
287+
288+
# test inverse transform
289+
transp_Xt = otda.inverse_transform(Xt=Xt)
290+
assert_equal(transp_Xt.shape, Xt.shape)
291+
292+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
293+
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
294+
295+
# check that the oos method is working
296+
assert_equal(transp_Xt_new.shape, Xt_new.shape)
297+
298+
# test fit_transform
299+
transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
300+
assert_equal(transp_Xs.shape, Xs.shape)
301+
302+
# test unsupervised vs semi-supervised mode
303+
otda_unsup = ot.da.SinkhornTransport()
304+
otda_unsup.fit(Xs=Xs, Xt=Xt)
305+
n_unsup = np.sum(otda_unsup.cost_)
306+
307+
otda_semi = ot.da.SinkhornTransport()
308+
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
309+
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
310+
n_semisup = np.sum(otda_semi.cost_)
311+
312+
# check that the cost matrix norms are indeed different
313+
assert n_unsup != n_semisup, "semisupervised mode not working"
314+
315+
# check everything runs well with log=True
316+
otda = ot.da.SinkhornTransport(log=True)
317+
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
318+
assert len(otda.log_.keys()) != 0
319+
320+
248321
def test_emd_transport_class():
249322
"""test_sinkhorn_transport
250323
"""

0 commit comments

Comments
 (0)